summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/api/ec2/__init__.py2
-rw-r--r--nova/api/ec2/cloud.py207
-rw-r--r--nova/auth/manager.py2
-rw-r--r--nova/compute/manager.py5
-rw-r--r--nova/db/api.py70
-rw-r--r--nova/db/sqlalchemy/api.py129
-rw-r--r--nova/db/sqlalchemy/models.py73
-rw-r--r--nova/db/sqlalchemy/session.py9
-rw-r--r--nova/exception.py3
-rw-r--r--nova/network/manager.py1
-rw-r--r--nova/process.py2
-rw-r--r--nova/test.py2
-rw-r--r--nova/tests/api_unittest.py188
-rw-r--r--nova/tests/objectstore_unittest.py2
-rw-r--r--nova/tests/virt_unittest.py142
-rw-r--r--nova/virt/interfaces.template1
-rw-r--r--nova/virt/libvirt.qemu.xml.template4
-rw-r--r--nova/virt/libvirt.uml.xml.template4
-rw-r--r--nova/virt/libvirt_conn.py176
-rw-r--r--run_tests.py1
20 files changed, 988 insertions, 35 deletions
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 7a958f841..8111ef023 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -142,6 +142,8 @@ class Authorizer(wsgi.Middleware):
'CreateKeyPair': ['all'],
'DeleteKeyPair': ['all'],
'DescribeSecurityGroups': ['all'],
+ 'AuthorizeSecurityGroupIngress': ['netadmin'],
+ 'RevokeSecurityGroupIngress': ['netadmin'],
'CreateSecurityGroup': ['netadmin'],
'DeleteSecurityGroup': ['netadmin'],
'GetConsoleOutput': ['projectmanager', 'sysadmin'],
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py
index d3f54367b..ca3f71036 100644
--- a/nova/api/ec2/cloud.py
+++ b/nova/api/ec2/cloud.py
@@ -28,6 +28,8 @@ import logging
import os
import time
+import IPy
+
from nova import crypto
from nova import db
from nova import exception
@@ -42,6 +44,7 @@ from nova.api.ec2 import images
FLAGS = flags.FLAGS
flags.DECLARE('storage_availability_zone', 'nova.volume.manager')
+InvalidInputException = exception.InvalidInputException
class QuotaError(exception.ApiError):
"""Quota Exceeeded"""
@@ -112,6 +115,15 @@ class CloudController(object):
result[key] = [line]
return result
+ def _trigger_refresh_security_group(self, security_group):
+ nodes = set([instance['host'] for instance in security_group.instances
+ if instance['host'] is not None])
+ for node in nodes:
+ rpc.call('%s.%s' % (FLAGS.compute_topic, node),
+ { "method": "refresh_security_group",
+ "args": { "context": None,
+ "security_group_id": security_group.id}})
+
def get_metadata(self, address):
instance_ref = db.fixed_ip_get_instance(None, address)
if instance_ref is None:
@@ -231,18 +243,171 @@ class CloudController(object):
pass
return True
- def describe_security_groups(self, context, group_names, **kwargs):
- groups = {'securityGroupSet': []}
+ def describe_security_groups(self, context, group_name=None, **kwargs):
+ self._ensure_default_security_group(context)
+ if context.user.is_admin():
+ groups = db.security_group_get_all(context)
+ else:
+ groups = db.security_group_get_by_project(context,
+ context.project.id)
+ groups = [self._format_security_group(context, g) for g in groups]
+ if not group_name is None:
+ groups = [g for g in groups if g.name in group_name]
+
+ return {'securityGroupInfo': groups }
+
+ def _format_security_group(self, context, group):
+ g = {}
+ g['groupDescription'] = group.description
+ g['groupName'] = group.name
+ g['ownerId'] = context.user.id
+ g['ipPermissions'] = []
+ for rule in group.rules:
+ r = {}
+ r['ipProtocol'] = rule.protocol
+ r['fromPort'] = rule.from_port
+ r['toPort'] = rule.to_port
+ r['groups'] = []
+ r['ipRanges'] = []
+ if rule.group_id:
+ source_group = db.security_group_get(context, rule.group_id)
+ r['groups'] += [{'groupName': source_group.name,
+ 'userId': source_group.user_id}]
+ else:
+ r['ipRanges'] += [{'cidrIp': rule.cidr}]
+ g['ipPermissions'] += [r]
+ return g
+
+
+ def _authorize_revoke_rule_args_to_dict(self, context,
+ to_port=None, from_port=None,
+ ip_protocol=None, cidr_ip=None,
+ user_id=None,
+ source_security_group_name=None,
+ source_security_group_owner_id=None):
+
+ values = {}
+
+ if source_security_group_name:
+ source_project_id = self._get_source_project_id(context,
+ source_security_group_owner_id)
+
+ source_security_group = \
+ db.security_group_get_by_name(context,
+ source_project_id,
+ source_security_group_name)
+ values['group_id'] = source_security_group.id
+ elif cidr_ip:
+ # If this fails, it throws an exception. This is what we want.
+ IPy.IP(cidr_ip)
+ values['cidr'] = cidr_ip
+ else:
+ values['cidr'] = '0.0.0.0/0'
+
+ if ip_protocol and from_port and to_port:
+ from_port = int(from_port)
+ to_port = int(to_port)
+ ip_protocol = str(ip_protocol)
+
+ if ip_protocol.upper() not in ['TCP','UDP','ICMP']:
+ raise InvalidInputException('%s is not a valid ipProtocol' %
+ (ip_protocol,))
+ if ((min(from_port, to_port) < -1) or
+ (max(from_port, to_port) > 65535)):
+ raise InvalidInputException('Invalid port range')
+
+ values['protocol'] = ip_protocol
+ values['from_port'] = from_port
+ values['to_port'] = to_port
+ else:
+ # If cidr based filtering, protocol and ports are mandatory
+ if 'cidr' in values:
+ return None
+
+ return values
+
+ def revoke_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+
+ for rule in security_group.rules:
+ for (k,v) in criteria.iteritems():
+ if getattr(rule, k, False) != v:
+ break
+ # If we make it here, we have a match
+ db.security_group_rule_destroy(context, rule.id)
+
+ self._trigger_refresh_security_group(security_group)
+
+ return True
+
+ # TODO(soren): Dupe detection. Adding the same rule twice actually
+ # adds the same rule twice to the rule set, which is
+ # pointless.
+ # TODO(soren): This has only been tested with Boto as the client.
+ # Unfortunately, it seems Boto is using an old API
+ # for these operations, so support for newer API versions
+ # is sketchy.
+ def authorize_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ values = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+ values['parent_group_id'] = security_group.id
- # Stubbed for now to unblock other things.
- return groups
+ security_group_rule = db.security_group_rule_create(context, values)
+
+ self._trigger_refresh_security_group(security_group)
- def create_security_group(self, context, group_name, **kwargs):
return True
+ def _get_source_project_id(self, context, source_security_group_owner_id):
+ if source_security_group_owner_id:
+ # Parse user:project for source group.
+ source_parts = source_security_group_owner_id.split(':')
+
+ # If no project name specified, assume it's same as user name.
+ # Since we're looking up by project name, the user name is not
+ # used here. It's only read for EC2 API compatibility.
+ if len(source_parts) == 2:
+ source_project_id = source_parts[1]
+ else:
+ source_project_id = source_parts[0]
+ else:
+ source_project_id = context.project.id
+
+ return source_project_id
+
+
+ def create_security_group(self, context, group_name, group_description):
+ self._ensure_default_security_group(context)
+ if db.securitygroup_exists(context, context.project.id, group_name):
+ raise exception.ApiError('group %s already exists' % group_name)
+
+ group = {'user_id' : context.user.id,
+ 'project_id': context.project.id,
+ 'name': group_name,
+ 'description': group_description}
+ group_ref = db.security_group_create(context, group)
+
+ return {'securityGroupSet': [self._format_security_group(context,
+ group_ref)]}
+
+
def delete_security_group(self, context, group_name, **kwargs):
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+ db.security_group_destroy(context, security_group.id)
return True
+
def get_console_output(self, context, instance_id, **kwargs):
# instance_id is passed in as a list of instances
instance_ref = db.instance_get_by_ec2_id(context, instance_id[0])
@@ -530,6 +695,18 @@ class CloudController(object):
"project_id": context.project.id}})
return db.queue_get_for(context, FLAGS.network_topic, host)
+ def _ensure_default_security_group(self, context):
+ try:
+ db.security_group_get_by_name(context,
+ context.project.id,
+ 'default')
+ except exception.NotFound:
+ values = { 'name' : 'default',
+ 'description' : 'default',
+ 'user_id' : context.user.id,
+ 'project_id' : context.project.id }
+ group = db.security_group_create({}, values)
+
def run_instances(self, context, **kwargs):
instance_type = kwargs.get('instance_type', 'm1.small')
if instance_type not in INSTANCE_TYPES:
@@ -577,8 +754,17 @@ class CloudController(object):
kwargs['key_name'])
key_data = key_pair_ref['public_key']
- # TODO: Get the real security group of launch in here
- security_group = "default"
+ security_group_arg = kwargs.get('security_group', ["default"])
+ if not type(security_group_arg) is list:
+ security_group_arg = [security_group_arg]
+
+ security_groups = []
+ self._ensure_default_security_group(context)
+ for security_group_name in security_group_arg:
+ group = db.security_group_get_by_name(context,
+ context.project.id,
+ security_group_name)
+ security_groups.append(group['id'])
reservation_id = utils.generate_uid('r')
base_options = {}
@@ -592,7 +778,8 @@ class CloudController(object):
base_options['user_id'] = context.user.id
base_options['project_id'] = context.project.id
base_options['user_data'] = kwargs.get('user_data', '')
- base_options['security_group'] = security_group
+
+ type_data = INSTANCE_TYPES[instance_type]
base_options['instance_type'] = instance_type
base_options['display_name'] = kwargs.get('display_name')
base_options['display_description'] = kwargs.get('display_description')
@@ -606,6 +793,10 @@ class CloudController(object):
instance_ref = db.instance_create(context, base_options)
inst_id = instance_ref['id']
+ for security_group_id in security_groups:
+ db.instance_add_security_group(context, inst_id,
+ security_group_id)
+
inst = {}
inst['mac_address'] = utils.generate_mac()
inst['launch_index'] = num
diff --git a/nova/auth/manager.py b/nova/auth/manager.py
index 0bc12c80f..e2e035d37 100644
--- a/nova/auth/manager.py
+++ b/nova/auth/manager.py
@@ -490,6 +490,7 @@ class AuthManager(object):
except:
drv.delete_project(project.id)
raise
+
return project
def modify_project(self, project, manager_user=None, description=None):
@@ -565,6 +566,7 @@ class AuthManager(object):
except:
logging.exception('Could not destroy network for %s',
project)
+
with self.driver() as drv:
drv.delete_project(Project.safe_id(project))
diff --git a/nova/compute/manager.py b/nova/compute/manager.py
index f370ede8b..02ac3cb4c 100644
--- a/nova/compute/manager.py
+++ b/nova/compute/manager.py
@@ -64,6 +64,11 @@ class ComputeManager(manager.Manager):
@defer.inlineCallbacks
@exception.wrap_exception
+ def refresh_security_group(self, context, security_group_id, **_kwargs):
+ yield self.driver.refresh_security_group(security_group_id)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
def run_instance(self, context, instance_id, **_kwargs):
"""Launch a new instance with specified options."""
instance_ref = self.db.instance_get(context, instance_id)
diff --git a/nova/db/api.py b/nova/db/api.py
index b68a0fe8f..9ce3dfb2a 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -304,6 +304,11 @@ def instance_update(context, instance_id, values):
return IMPL.instance_update(context, instance_id, values)
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ return IMPL.instance_add_security_group(context, instance_id, security_group_id)
+
+
###################
@@ -565,3 +570,68 @@ def volume_update(context, volume_id, values):
"""
return IMPL.volume_update(context, volume_id, values)
+
+
+####################
+
+
+def security_group_get_all(context):
+ """Get all security groups"""
+ return IMPL.security_group_get_all(context)
+
+
+def security_group_get(context, security_group_id):
+ """Get security group by its internal id"""
+ return IMPL.security_group_get(context, security_group_id)
+
+
+def security_group_get_by_name(context, project_id, group_name):
+ """Returns a security group with the specified name from a project"""
+ return IMPL.security_group_get_by_name(context, project_id, group_name)
+
+
+def security_group_get_by_project(context, project_id):
+ """Get all security groups belonging to a project"""
+ return IMPL.security_group_get_by_project(context, project_id)
+
+
+def security_group_get_by_instance(context, instance_id):
+ """Get security groups to which the instance is assigned"""
+ return IMPL.security_group_get_by_instance(context, instance_id)
+
+
+def securitygroup_exists(context, project_id, group_name):
+ """Indicates if a group name exists in a project"""
+ return IMPL.security_group_exists(context, project_id, group_name)
+
+
+def security_group_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_create(context, values)
+
+
+def security_group_destroy(context, security_group_id):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy(context, security_group_id)
+
+
+def security_group_destroy_all(context):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy_all(context)
+
+
+####################
+
+
+def security_group_rule_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_rule_create(context, values)
+
+
+def security_group_rule_get_by_security_group(context, security_group_id):
+ """Get all rules for a a given security group"""
+ return IMPL.security_group_rule_get_by_security_group(context, security_group_id)
+
+def security_group_rule_destroy(context, security_group_rule_id):
+ """Deletes a security group rule"""
+ return IMPL.security_group_rule_destroy(context, security_group_rule_id)
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 9c3caf9af..013e8ab16 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -29,8 +29,10 @@ from nova.db.sqlalchemy import models
from nova.db.sqlalchemy.session import get_session
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import eagerload
from sqlalchemy.orm import joinedload_all
-from sqlalchemy.sql import exists, func
+from sqlalchemy.sql import exists
+from sqlalchemy.sql import func
FLAGS = flags.FLAGS
@@ -410,7 +412,8 @@ def instance_destroy(_context, instance_id):
def instance_get(context, instance_id):
- return models.Instance.find(instance_id, deleted=_deleted(context))
+ return models.Instance().find(instance_id, deleted=_deleted(context),
+ options=eagerload('security_groups'))
def instance_get_all(context):
@@ -511,6 +514,17 @@ def instance_update(_context, instance_id, values):
instance_ref.save(session=session)
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ session = get_session()
+ with session.begin():
+ instance_ref = models.Instance.find(instance_id, session=session)
+ security_group_ref = models.SecurityGroup.find(security_group_id,
+ session=session)
+ instance_ref.security_groups += [security_group_ref]
+ instance_ref.save(session=session)
+
+
###################
@@ -925,3 +939,114 @@ def volume_update(_context, volume_id, values):
for (key, value) in values.iteritems():
volume_ref[key] = value
volume_ref.save(session=session)
+
+
+###################
+
+
+def security_group_get_all(_context):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).options(eagerload('rules')
+ ).filter_by(deleted=False
+ ).all()
+
+
+def security_group_get(_context, security_group_id):
+ session = get_session()
+ result = session.query(models.SecurityGroup
+ ).options(eagerload('rules')
+ ).get(security_group_id)
+ if not result:
+ raise exception.NotFound("No secuity group with id %s" %
+ security_group_id)
+ return result
+
+
+def security_group_get_by_name(context, project_id, group_name):
+ session = get_session()
+ group_ref = session.query(models.SecurityGroup
+ ).options(eagerload('rules')
+ ).options(eagerload('instances')
+ ).filter_by(project_id=project_id
+ ).filter_by(name=group_name
+ ).filter_by(deleted=False
+ ).first()
+ if not group_ref:
+ raise exception.NotFound(
+ 'No security group named %s for project: %s' \
+ % (group_name, project_id))
+ return group_ref
+
+
+def security_group_get_by_project(_context, project_id):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).options(eagerload('rules')
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+def security_group_get_by_instance(_context, instance_id):
+ session = get_session()
+ with session.begin():
+ return session.query(models.Instance
+ ).join(models.Instance.security_groups
+ ).filter_by(deleted=False
+ ).all()
+
+
+def security_group_exists(_context, project_id, group_name):
+ try:
+ group = security_group_get_by_name(_context, project_id, group_name)
+ return group != None
+ except exception.NotFound:
+ return False
+
+
+def security_group_create(_context, values):
+ security_group_ref = models.SecurityGroup()
+ # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception
+ # once save() is called. This will get cleaned up in next orm pass.
+ security_group_ref.rules
+ for (key, value) in values.iteritems():
+ security_group_ref[key] = value
+ security_group_ref.save()
+ return security_group_ref
+
+
+def security_group_destroy(_context, security_group_id):
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_group set deleted=1 where id=:id',
+ {'id': security_group_id})
+ session.execute('update security_group_rules set deleted=1 '
+ 'where group_id=:id',
+ {'id': security_group_id})
+
+def security_group_destroy_all(_context):
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_group set deleted=1')
+ session.execute('update security_group_rules set deleted=1')
+
+###################
+
+
+def security_group_rule_create(_context, values):
+ security_group_rule_ref = models.SecurityGroupIngressRule()
+ for (key, value) in values.iteritems():
+ security_group_rule_ref[key] = value
+ security_group_rule_ref.save()
+ return security_group_rule_ref
+
+def security_group_rule_destroy(_context, security_group_rule_id):
+ session = get_session()
+ with session.begin():
+ model = models.SecurityGroupIngressRule
+ security_group_rule = model.find(security_group_rule_id,
+ session=session)
+ security_group_rule.delete(session=session)
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index 01e58b05e..d4caf0b52 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -25,7 +25,7 @@ import datetime
# TODO(vish): clean up these imports
from sqlalchemy.orm import relationship, backref, exc, object_mapper
-from sqlalchemy import Column, Integer, String
+from sqlalchemy import Column, Integer, String, Table
from sqlalchemy import ForeignKey, DateTime, Boolean, Text
from sqlalchemy.ext.declarative import declarative_base
@@ -69,15 +69,19 @@ class NovaBase(object):
).count()
@classmethod
- def find(cls, obj_id, session=None, deleted=False):
+ def find(cls, obj_id, session=None, deleted=False, options=None):
"""Find object by id"""
if not session:
session = get_session()
try:
- return session.query(cls
+ query = session.query(cls
).filter_by(id=obj_id
- ).filter_by(deleted=deleted
- ).one()
+ ).filter_by(deleted=deleted)
+
+ if options:
+ query = query.options(options)
+
+ return query.one()
except exc.NoResultFound:
new_exc = exception.NotFound("No model for id %s" % obj_id)
raise new_exc.__class__, new_exc, sys.exc_info()[2]
@@ -230,7 +234,6 @@ class Instance(BASE, NovaBase):
launch_index = Column(Integer)
key_name = Column(String(255))
key_data = Column(Text)
- security_group = Column(String(255))
state = Column(Integer)
state_description = Column(String(255))
@@ -337,13 +340,65 @@ class ExportDevice(BASE, NovaBase):
uselist=False))
+security_group_instance_association = Table('security_group_instance_association',
+ BASE.metadata,
+ Column('security_group_id', Integer,
+ ForeignKey('security_group.id')),
+ Column('instance_id', Integer,
+ ForeignKey('instances.id')))
+
+class SecurityGroup(BASE, NovaBase):
+ """Represents a security group"""
+ __tablename__ = 'security_group'
+ id = Column(Integer, primary_key=True)
+
+ name = Column(String(255))
+ description = Column(String(255))
+ user_id = Column(String(255))
+ project_id = Column(String(255))
+
+ instances = relationship(Instance,
+ secondary=security_group_instance_association,
+ backref='security_groups')
+
+ @property
+ def user(self):
+ return auth.manager.AuthManager().get_user(self.user_id)
+
+ @property
+ def project(self):
+ return auth.manager.AuthManager().get_project(self.project_id)
+
+
+class SecurityGroupIngressRule(BASE, NovaBase):
+ """Represents a rule in a security group"""
+ __tablename__ = 'security_group_rules'
+ id = Column(Integer, primary_key=True)
+
+ parent_group_id = Column(Integer, ForeignKey('security_group.id'))
+ parent_group = relationship("SecurityGroup", backref="rules",
+ foreign_keys=parent_group_id,
+ primaryjoin=parent_group_id==SecurityGroup.id)
+
+ protocol = Column(String(5)) # "tcp", "udp", or "icmp"
+ from_port = Column(Integer)
+ to_port = Column(Integer)
+ cidr = Column(String(255))
+
+ # Note: This is not the parent SecurityGroup. It's SecurityGroup we're
+ # granting access for.
+ group_id = Column(Integer, ForeignKey('security_group.id'))
+
+
class KeyPair(BASE, NovaBase):
"""Represents a public key pair for ssh"""
__tablename__ = 'key_pairs'
id = Column(Integer, primary_key=True)
+
name = Column(String(255))
user_id = Column(String(255))
+ project_id = Column(String(255))
fingerprint = Column(String(255))
public_key = Column(Text)
@@ -484,9 +539,9 @@ class FloatingIp(BASE, NovaBase):
def register_models():
"""Register Models and create metadata"""
from sqlalchemy import create_engine
- models = (Service, Instance, Volume, ExportDevice,
- FixedIp, FloatingIp, Network, NetworkIndex,
- AuthToken) # , Image, Host)
+ models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp,
+ Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule,
+ AuthToken) # , Image, Host
engine = create_engine(FLAGS.sql_connection, echo=False)
for model in models:
model.metadata.create_all(engine)
diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py
index 69a205378..826754f6a 100644
--- a/nova/db/sqlalchemy/session.py
+++ b/nova/db/sqlalchemy/session.py
@@ -36,7 +36,8 @@ def get_session(autocommit=True, expire_on_commit=False):
if not _MAKER:
if not _ENGINE:
_ENGINE = create_engine(FLAGS.sql_connection, echo=False)
- _MAKER = sessionmaker(bind=_ENGINE,
- autocommit=autocommit,
- expire_on_commit=expire_on_commit)
- return _MAKER()
+ _MAKER = (sessionmaker(bind=_ENGINE,
+ autocommit=autocommit,
+ expire_on_commit=expire_on_commit))
+ session = _MAKER()
+ return session
diff --git a/nova/exception.py b/nova/exception.py
index b8894758f..f157fab2d 100644
--- a/nova/exception.py
+++ b/nova/exception.py
@@ -69,6 +69,9 @@ class NotEmpty(Error):
class Invalid(Error):
pass
+class InvalidInputException(Error):
+ pass
+
def wrap_exception(f):
def _wrap(*args, **kw):
diff --git a/nova/network/manager.py b/nova/network/manager.py
index a7126ea4f..8f1924ac9 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -201,7 +201,6 @@ class FlatManager(NetworkManager):
# in the datastore?
net = {}
net['injected'] = True
- net['network_str'] = FLAGS.flat_network_network
net['netmask'] = FLAGS.flat_network_netmask
net['bridge'] = FLAGS.flat_network_bridge
net['gateway'] = FLAGS.flat_network_gateway
diff --git a/nova/process.py b/nova/process.py
index b3cad894b..13cb90e82 100644
--- a/nova/process.py
+++ b/nova/process.py
@@ -113,7 +113,7 @@ class BackRelayWithInput(protocol.ProcessProtocol):
if self.started_deferred:
self.started_deferred.callback(self)
if self.process_input:
- self.transport.write(self.process_input)
+ self.transport.write(str(self.process_input))
self.transport.closeStdin()
def get_process_output(executable, args=None, env=None, path=None,
diff --git a/nova/test.py b/nova/test.py
index 1f4b33272..08e1dea2d 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -31,6 +31,7 @@ from tornado import ioloop
from twisted.internet import defer
from twisted.trial import unittest
+from nova import db
from nova import fakerabbit
from nova import flags
from nova import rpc
@@ -83,6 +84,7 @@ class TrialTestCase(unittest.TestCase):
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
+ db.security_group_destroy_all(None)
super(TrialTestCase, self).tearDown()
diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py
index c040cdad3..7ab27e000 100644
--- a/nova/tests/api_unittest.py
+++ b/nova/tests/api_unittest.py
@@ -91,6 +91,9 @@ class ApiEc2TestCase(test.BaseTestCase):
self.host = '127.0.0.1'
self.app = api.API()
+
+ def expect_http(self, host=None, is_secure=False):
+ """Returns a new EC2 connection"""
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
@@ -100,9 +103,6 @@ class ApiEc2TestCase(test.BaseTestCase):
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
-
- def expect_http(self, host=None, is_secure=False):
- """Returns a new EC2 connection"""
http = FakeHttplibConnection(
self.app, '%s:8773' % (self.host), False)
# pylint: disable-msg=E1103
@@ -138,3 +138,185 @@ class ApiEc2TestCase(test.BaseTestCase):
self.assertEquals(len(results), 1)
self.manager.delete_project(project)
self.manager.delete_user(user)
+
+ def test_get_all_security_groups(self):
+ """Test that we can retrieve security groups"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEquals(len(rv), 1)
+ self.assertEquals(rv[0].name, 'default')
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_create_delete_security_group(self):
+ """Test that we can create a security group"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ self.assertEquals(len(rv), 2)
+ self.assertTrue(security_group_name in [group.name for group in rv])
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_authorize_revoke_security_group_cidr(self):
+ """
+ Test that we can add and remove CIDR based rules
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake')
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(int(group.rules[0].from_port), 80)
+ self.assertEquals(int(group.rules[0].to_port), 81)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.revoke('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEqual(len(rv), 1)
+ self.assertEqual(rv[0].name, 'default')
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ return
+
+ def test_authorize_revoke_security_group_foreign_group(self):
+ """
+ Test that we can grant and revoke another security group access
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+ other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ other_group = self.ec2.create_security_group(other_security_group_name,
+ 'some other group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]),
+ '%s-%s' % (other_security_group_name, 'fake'))
+
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ for group in rv:
+ if group.name == security_group_name:
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+ group.revoke(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ return
diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py
index 5a599ff3a..1d6b9e826 100644
--- a/nova/tests/objectstore_unittest.py
+++ b/nova/tests/objectstore_unittest.py
@@ -191,7 +191,7 @@ class S3APITestCase(test.TrialTestCase):
"""Setup users, projects, and start a test server."""
super(S3APITestCase, self).setUp()
- FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver',
+ FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets')
self.auth_manager = manager.AuthManager()
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index 2aab16809..5e9505374 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -14,23 +14,31 @@
# License for the specific language governing permissions and limitations
# under the License.
+from xml.dom.minidom import parseString
+
+from nova import db
from nova import flags
from nova import test
+from nova.api.ec2 import cloud
from nova.virt import libvirt_conn
FLAGS = flags.FLAGS
class LibvirtConnTestCase(test.TrialTestCase):
- def test_get_uri_and_template(self):
+ def bitrot_test_get_uri_and_template(self):
class MockDataModel(object):
+ def __getitem__(self, name):
+ return self.datamodel[name]
+
def __init__(self):
self.datamodel = { 'name' : 'i-cafebabe',
'memory_kb' : '1024000',
'basepath' : '/some/path',
'bridge_name' : 'br100',
'mac_address' : '02:12:34:46:56:67',
- 'vcpus' : 2 }
+ 'vcpus' : 2,
+ 'project_id' : None }
type_uri_map = { 'qemu' : ('qemu:///system',
[lambda s: '<domain type=\'qemu\'>' in s,
@@ -53,7 +61,7 @@ class LibvirtConnTestCase(test.TrialTestCase):
self.assertEquals(uri, expected_uri)
for i, check in enumerate(checks):
- xml = conn.toXml(MockDataModel())
+ xml = conn.to_xml(MockDataModel())
self.assertTrue(check(xml), '%s failed check %d' % (xml, i))
# Deliberately not just assigning this string to FLAGS.libvirt_uri and
@@ -67,3 +75,131 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, testuri)
+
+class NWFilterTestCase(test.TrialTestCase):
+ def setUp(self):
+ super(NWFilterTestCase, self).setUp()
+
+ class Mock(object):
+ pass
+
+ self.context = Mock()
+ self.context.user = Mock()
+ self.context.user.id = 'fake'
+ self.context.user.is_superuser = lambda:True
+ self.context.project = Mock()
+ self.context.project.id = 'fake'
+
+ self.fake_libvirt_connection = Mock()
+
+ self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
+
+ def test_cidr_rule_nwfilter_xml(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+
+ security_group = db.security_group_get_by_name({}, 'fake', 'testgroup')
+
+ xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
+
+ dom = parseString(xml)
+ self.assertEqual(dom.firstChild.tagName, 'filter')
+
+ rules = dom.getElementsByTagName('rule')
+ self.assertEqual(len(rules), 1)
+
+ # It's supposed to allow inbound traffic.
+ self.assertEqual(rules[0].getAttribute('action'), 'accept')
+ self.assertEqual(rules[0].getAttribute('direction'), 'in')
+
+ # Must be lower priority than the base filter (which blocks everything)
+ self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)
+
+ ip_conditions = rules[0].getElementsByTagName('tcp')
+ self.assertEqual(len(ip_conditions), 1)
+ self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
+
+
+ self.teardown_security_group()
+
+ def teardown_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.delete_security_group(self.context, 'testgroup')
+
+
+ def setup_and_return_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+ return db.security_group_get_by_name({}, 'fake', 'testgroup')
+
+ def test_creates_base_rule_first(self):
+ # These come pre-defined by libvirt
+ self.defined_filters = ['no-mac-spoofing',
+ 'no-ip-spoofing',
+ 'no-arp-spoofing',
+ 'allow-dhcp-server']
+
+ self.recursive_depends = {}
+ for f in self.defined_filters:
+ self.recursive_depends[f] = []
+
+ def _filterDefineXMLMock(xml):
+ dom = parseString(xml)
+ name = dom.firstChild.getAttribute('name')
+ self.recursive_depends[name] = []
+ for f in dom.getElementsByTagName('filterref'):
+ ref = f.getAttribute('filter')
+ self.assertTrue(ref in self.defined_filters,
+ ('%s referenced filter that does ' +
+ 'not yet exist: %s') % (name, ref))
+ dependencies = [ref] + self.recursive_depends[ref]
+ self.recursive_depends[name] += dependencies
+
+ self.defined_filters.append(name)
+ return True
+
+ self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
+
+ instance_ref = db.instance_create({}, {'user_id': 'fake',
+ 'project_id': 'fake'})
+ inst_id = instance_ref['id']
+
+ def _ensure_all_called(_):
+ instance_filter = 'nova-instance-%s' % instance_ref['str_id']
+ secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
+ for required in [secgroup_filter, 'allow-dhcp-server',
+ 'no-arp-spoofing', 'no-ip-spoofing',
+ 'no-mac-spoofing']:
+ self.assertTrue(required in self.recursive_depends[instance_filter],
+ "Instance's filter does not include %s" % required)
+
+ self.security_group = self.setup_and_return_security_group()
+
+ db.instance_add_security_group({}, inst_id, self.security_group.id)
+ instance = db.instance_get({}, inst_id)
+
+ d = self.fw.setup_nwfilters_for_instance(instance)
+ d.addCallback(_ensure_all_called)
+ d.addCallback(lambda _:self.teardown_security_group())
+
+ return d
diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template
index 11df301f6..87b92b84a 100644
--- a/nova/virt/interfaces.template
+++ b/nova/virt/interfaces.template
@@ -10,7 +10,6 @@ auto eth0
iface eth0 inet static
address %(address)s
netmask %(netmask)s
- network %(network)s
broadcast %(broadcast)s
gateway %(gateway)s
dns-nameservers %(dns)s
diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template
index 17bd79b7c..2538b1ade 100644
--- a/nova/virt/libvirt.qemu.xml.template
+++ b/nova/virt/libvirt.qemu.xml.template
@@ -20,6 +20,10 @@
<source bridge='%(bridge_name)s'/>
<mac address='%(mac_address)s'/>
<!-- <model type='virtio'/> CANT RUN virtio network right now -->
+ <filterref filter="nova-instance-%(name)s">
+ <parameter name="IP" value="%(ip_address)s" />
+ <parameter name="DHCPSERVER" value="%(dhcp_server)s" />
+ </filterref>
</interface>
<serial type="file">
<source path='%(basepath)s/console.log'/>
diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template
index c039d6d90..bb8b47911 100644
--- a/nova/virt/libvirt.uml.xml.template
+++ b/nova/virt/libvirt.uml.xml.template
@@ -14,6 +14,10 @@
<interface type='bridge'>
<source bridge='%(bridge_name)s'/>
<mac address='%(mac_address)s'/>
+ <filterref filter="nova-instance-%(name)s">
+ <parameter name="IP" value="%(ip_address)s" />
+ <parameter name="DHCPSERVER" value="%(dhcp_server)s" />
+ </filterref>
</interface>
<console type="file">
<source path='%(basepath)s/console.log'/>
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index d868e083c..c86f3ffb7 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -27,6 +27,7 @@ import shutil
from twisted.internet import defer
from twisted.internet import task
+from twisted.internet import threads
from nova import db
from nova import exception
@@ -214,6 +215,7 @@ class LibvirtConnection(object):
instance['id'],
power_state.NOSTATE,
'launching')
+ yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance)
yield self._create_image(instance, xml)
yield self._conn.createXML(xml, 0)
# TODO(termie): this should actually register
@@ -285,7 +287,6 @@ class LibvirtConnection(object):
address = db.instance_get_fixed_address(None, inst['id'])
with open(FLAGS.injected_network_template) as f:
net = f.read() % {'address': address,
- 'network': network_ref['network'],
'netmask': network_ref['netmask'],
'gateway': network_ref['gateway'],
'broadcast': network_ref['broadcast'],
@@ -317,6 +318,9 @@ class LibvirtConnection(object):
network = db.project_get_network(None, instance['project_id'])
# FIXME(vish): stick this in db
instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']]
+ ip_address = db.instance_get_fixed_address({}, instance['id'])
+ # Assume that the gateway also acts as the dhcp server.
+ dhcp_server = network['gateway']
xml_info = {'type': FLAGS.libvirt_type,
'name': instance['name'],
'basepath': os.path.join(FLAGS.instances_path,
@@ -324,7 +328,9 @@ class LibvirtConnection(object):
'memory_kb': instance_type['memory_mb'] * 1024,
'vcpus': instance_type['vcpus'],
'bridge_name': network['bridge'],
- 'mac_address': instance['mac_address']}
+ 'mac_address': instance['mac_address'],
+ 'ip_address': ip_address,
+ 'dhcp_server': dhcp_server }
libvirt_xml = self.libvirt_xml % xml_info
logging.debug('instance %s: finished toXML method', instance['name'])
@@ -438,3 +444,169 @@ class LibvirtConnection(object):
"""
domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface)
+
+
+ def refresh_security_group(self, security_group_id):
+ fw = NWFilterFirewall(self._conn)
+ fw.ensure_security_group_filter(security_group_id)
+
+
+class NWFilterFirewall(object):
+ """
+ This class implements a network filtering mechanism versatile
+ enough for EC2 style Security Group filtering by leveraging
+ libvirt's nwfilter.
+
+ First, all instances get a filter ("nova-base-filter") applied.
+ This filter drops all incoming ipv4 and ipv6 connections.
+ Outgoing connections are never blocked.
+
+ Second, every security group maps to a nwfilter filter(*).
+ NWFilters can be updated at runtime and changes are applied
+ immediately, so changes to security groups can be applied at
+ runtime (as mandated by the spec).
+
+ Security group rules are named "nova-secgroup-<id>" where <id>
+ is the internal id of the security group. They're applied only on
+ hosts that have instances in the security group in question.
+
+ Updates to security groups are done by updating the data model
+ (in response to API calls) followed by a request sent to all
+ the nodes with instances in the security group to refresh the
+ security group.
+
+ Each instance has its own NWFilter, which references the above
+ mentioned security group NWFilters. This was done because
+ interfaces can only reference one filter while filters can
+ reference multiple other filters. This has the added benefit of
+ actually being able to add and remove security groups from an
+ instance at run time. This functionality is not exposed anywhere,
+ though.
+
+ Outstanding questions:
+
+ The name is unique, so would there be any good reason to sync
+ the uuid across the nodes (by assigning it from the datamodel)?
+
+
+ (*) This sentence brought to you by the redundancy department of
+ redundancy.
+ """
+
+ def __init__(self, get_connection):
+ self._conn = get_connection
+
+
+ nova_base_filter = '''<filter name='nova-base' chain='root'>
+ <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid>
+ <filterref filter='no-mac-spoofing'/>
+ <filterref filter='no-ip-spoofing'/>
+ <filterref filter='no-arp-spoofing'/>
+ <filterref filter='allow-dhcp-server'/>
+ <filterref filter='nova-allow-dhcp-server'/>
+ <filterref filter='nova-base-ipv4'/>
+ <filterref filter='nova-base-ipv6'/>
+ </filter>'''
+
+ nova_dhcp_filter = '''<filter name='nova-allow-dhcp-server' chain='ipv4'>
+ <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid>
+ <rule action='accept' direction='out'
+ priority='100'>
+ <udp srcipaddr='0.0.0.0'
+ dstipaddr='255.255.255.255'
+ srcportstart='68'
+ dstportstart='67'/>
+ </rule>
+ <rule action='accept' direction='in' priority='100'>
+ <udp srcipaddr='$DHCPSERVER'
+ srcportstart='67'
+ dstportstart='68'/>
+ </rule>
+ </filter>'''
+
+ def nova_base_ipv4_filter(self):
+ retval = "<filter name='nova-base-ipv4' chain='ipv4'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept', 400),
+ ('in','drop', 399)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def nova_base_ipv6_filter(self):
+ retval = "<filter name='nova-base-ipv6' chain='ipv6'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept',400),
+ ('in','drop',399)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s-ipv6 />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def _define_filter(self, xml):
+ if callable(xml):
+ xml = xml()
+ d = threads.deferToThread(self._conn.nwfilterDefineXML, xml)
+ return d
+
+
+ @defer.inlineCallbacks
+ def setup_nwfilters_for_instance(self, instance):
+ """
+ Creates an NWFilter for the given instance. In the process,
+ it makes sure the filters for the security groups as well as
+ the base filter are all in place.
+ """
+
+ yield self._define_filter(self.nova_base_ipv4_filter)
+ yield self._define_filter(self.nova_base_ipv6_filter)
+ yield self._define_filter(self.nova_dhcp_filter)
+ yield self._define_filter(self.nova_base_filter)
+
+ nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" +
+ " <filterref filter='nova-base' />\n"
+ ) % instance['name']
+
+ for security_group in instance.security_groups:
+ yield self.ensure_security_group_filter(security_group['id'])
+
+ nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n"
+ ) % security_group['id']
+ nwfilter_xml += "</filter>"
+
+ yield self._define_filter(nwfilter_xml)
+ return
+
+ def ensure_security_group_filter(self, security_group_id):
+ return self._define_filter(
+ self.security_group_to_nwfilter_xml(security_group_id))
+
+
+ def security_group_to_nwfilter_xml(self, security_group_id):
+ security_group = db.security_group_get({}, security_group_id)
+ rule_xml = ""
+ for rule in security_group.rules:
+ rule_xml += "<rule action='accept' direction='in' priority='300'>"
+ if rule.cidr:
+ rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr)
+ if rule.protocol in ['tcp', 'udp']:
+ rule_xml += "dstportstart='%s' dstportend='%s' " % \
+ (rule.from_port, rule.to_port)
+ elif rule.protocol == 'icmp':
+ logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port))
+ if rule.from_port != -1:
+ rule_xml += "type='%s' " % rule.from_port
+ if rule.to_port != -1:
+ rule_xml += "code='%s' " % rule.to_port
+
+ rule_xml += '/>\n'
+ rule_xml += "</rule>\n"
+ xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,)
+ return xml
diff --git a/run_tests.py b/run_tests.py
index 4121f4c06..8bb068ed1 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -64,6 +64,7 @@ from nova.tests.scheduler_unittest import *
from nova.tests.service_unittest import *
from nova.tests.validator_unittest import *
from nova.tests.volume_unittest import *
+from nova.tests.virt_unittest import *
FLAGS = flags.FLAGS