summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/api/ec2/__init__.py2
-rw-r--r--nova/api/ec2/cloud.py232
-rw-r--r--nova/auth/manager.py2
-rw-r--r--nova/compute/manager.py5
-rw-r--r--nova/db/api.py70
-rw-r--r--nova/db/sqlalchemy/api.py188
-rw-r--r--nova/db/sqlalchemy/models.py64
-rw-r--r--nova/db/sqlalchemy/session.py9
-rw-r--r--nova/exception.py3
-rw-r--r--nova/network/manager.py1
-rw-r--r--nova/process.py2
-rw-r--r--nova/test.py2
-rw-r--r--nova/tests/api_unittest.py188
-rw-r--r--nova/tests/objectstore_unittest.py2
-rw-r--r--nova/tests/virt_unittest.py184
-rw-r--r--nova/virt/interfaces.template1
-rw-r--r--nova/virt/libvirt.qemu.xml.template4
-rw-r--r--nova/virt/libvirt.uml.xml.template4
-rw-r--r--nova/virt/libvirt_conn.py209
-rw-r--r--run_tests.py1
20 files changed, 1132 insertions, 41 deletions
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 6b538a7f1..6e771f064 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -142,6 +142,8 @@ class Authorizer(wsgi.Middleware):
'CreateKeyPair': ['all'],
'DeleteKeyPair': ['all'],
'DescribeSecurityGroups': ['all'],
+ 'AuthorizeSecurityGroupIngress': ['netadmin'],
+ 'RevokeSecurityGroupIngress': ['netadmin'],
'CreateSecurityGroup': ['netadmin'],
'DeleteSecurityGroup': ['netadmin'],
'GetConsoleOutput': ['projectmanager', 'sysadmin'],
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py
index 619c1a4b0..7839dc92c 100644
--- a/nova/api/ec2/cloud.py
+++ b/nova/api/ec2/cloud.py
@@ -28,6 +28,8 @@ import logging
import os
import time
+import IPy
+
from nova import crypto
from nova import db
from nova import exception
@@ -43,6 +45,7 @@ from nova.api.ec2 import images
FLAGS = flags.FLAGS
flags.DECLARE('storage_availability_zone', 'nova.volume.manager')
+InvalidInputException = exception.InvalidInputException
class QuotaError(exception.ApiError):
"""Quota Exceeeded"""
@@ -127,6 +130,15 @@ class CloudController(object):
result[key] = [line]
return result
+ def _trigger_refresh_security_group(self, security_group):
+ nodes = set([instance['host'] for instance in security_group.instances
+ if instance['host'] is not None])
+ for node in nodes:
+ rpc.call('%s.%s' % (FLAGS.compute_topic, node),
+ { "method": "refresh_security_group",
+ "args": { "context": None,
+ "security_group_id": security_group.id}})
+
def get_metadata(self, address):
instance_ref = db.fixed_ip_get_instance(None, address)
if instance_ref is None:
@@ -246,18 +258,195 @@ class CloudController(object):
pass
return True
- def describe_security_groups(self, context, group_names, **kwargs):
- groups = {'securityGroupSet': []}
+ def describe_security_groups(self, context, group_name=None, **kwargs):
+ self._ensure_default_security_group(context)
+ if context.user.is_admin():
+ groups = db.security_group_get_all(context)
+ else:
+ groups = db.security_group_get_by_project(context,
+ context.project.id)
+ groups = [self._format_security_group(context, g) for g in groups]
+ if not group_name is None:
+ groups = [g for g in groups if g.name in group_name]
+
+ return {'securityGroupInfo': groups }
+
+ def _format_security_group(self, context, group):
+ g = {}
+ g['groupDescription'] = group.description
+ g['groupName'] = group.name
+ g['ownerId'] = group.project_id
+ g['ipPermissions'] = []
+ for rule in group.rules:
+ r = {}
+ r['ipProtocol'] = rule.protocol
+ r['fromPort'] = rule.from_port
+ r['toPort'] = rule.to_port
+ r['groups'] = []
+ r['ipRanges'] = []
+ if rule.group_id:
+ source_group = db.security_group_get(context, rule.group_id)
+ r['groups'] += [{'groupName': source_group.name,
+ 'userId': source_group.project_id}]
+ else:
+ r['ipRanges'] += [{'cidrIp': rule.cidr}]
+ g['ipPermissions'] += [r]
+ return g
+
+
+ def _authorize_revoke_rule_args_to_dict(self, context,
+ to_port=None, from_port=None,
+ ip_protocol=None, cidr_ip=None,
+ user_id=None,
+ source_security_group_name=None,
+ source_security_group_owner_id=None):
+
+ values = {}
+
+ if source_security_group_name:
+ source_project_id = self._get_source_project_id(context,
+ source_security_group_owner_id)
+
+ source_security_group = \
+ db.security_group_get_by_name(context,
+ source_project_id,
+ source_security_group_name)
+ values['group_id'] = source_security_group['id']
+ elif cidr_ip:
+ # If this fails, it throws an exception. This is what we want.
+ IPy.IP(cidr_ip)
+ values['cidr'] = cidr_ip
+ else:
+ values['cidr'] = '0.0.0.0/0'
+
+ if ip_protocol and from_port and to_port:
+ from_port = int(from_port)
+ to_port = int(to_port)
+ ip_protocol = str(ip_protocol)
+
+ if ip_protocol.upper() not in ['TCP','UDP','ICMP']:
+ raise InvalidInputException('%s is not a valid ipProtocol' %
+ (ip_protocol,))
+ if ((min(from_port, to_port) < -1) or
+ (max(from_port, to_port) > 65535)):
+ raise InvalidInputException('Invalid port range')
+
+ values['protocol'] = ip_protocol
+ values['from_port'] = from_port
+ values['to_port'] = to_port
+ else:
+ # If cidr based filtering, protocol and ports are mandatory
+ if 'cidr' in values:
+ return None
+
+ return values
- # Stubbed for now to unblock other things.
- return groups
- def create_security_group(self, context, group_name, **kwargs):
+ def _security_group_rule_exists(self, security_group, values):
+ """Indicates whether the specified rule values are already
+ defined in the given security group.
+ """
+ for rule in security_group.rules:
+ if 'group_id' in values:
+ if rule['group_id'] == values['group_id']:
+ return True
+ else:
+ is_duplicate = True
+ for key in ('cidr', 'from_port', 'to_port', 'protocol'):
+ if rule[key] != values[key]:
+ is_duplicate = False
+ break
+ if is_duplicate:
+ return True
+ return False
+
+
+ def revoke_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+ if criteria == None:
+ raise exception.ApiError("No rule for the specified parameters.")
+
+ for rule in security_group.rules:
+ match = True
+ for (k,v) in criteria.iteritems():
+ if getattr(rule, k, False) != v:
+ match = False
+ if match:
+ db.security_group_rule_destroy(context, rule['id'])
+ self._trigger_refresh_security_group(security_group)
+ return True
+ raise exception.ApiError("No rule for the specified parameters.")
+
+ # TODO(soren): This has only been tested with Boto as the client.
+ # Unfortunately, it seems Boto is using an old API
+ # for these operations, so support for newer API versions
+ # is sketchy.
+ def authorize_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ values = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+ values['parent_group_id'] = security_group.id
+
+ if self._security_group_rule_exists(security_group, values):
+ raise exception.ApiError('This rule already exists in group %s' %
+ group_name)
+
+ security_group_rule = db.security_group_rule_create(context, values)
+
+ self._trigger_refresh_security_group(security_group)
+
return True
+
+ def _get_source_project_id(self, context, source_security_group_owner_id):
+ if source_security_group_owner_id:
+ # Parse user:project for source group.
+ source_parts = source_security_group_owner_id.split(':')
+
+ # If no project name specified, assume it's same as user name.
+ # Since we're looking up by project name, the user name is not
+ # used here. It's only read for EC2 API compatibility.
+ if len(source_parts) == 2:
+ source_project_id = source_parts[1]
+ else:
+ source_project_id = source_parts[0]
+ else:
+ source_project_id = context.project.id
+
+ return source_project_id
+
+
+ def create_security_group(self, context, group_name, group_description):
+ self._ensure_default_security_group(context)
+ if db.security_group_exists(context, context.project.id, group_name):
+ raise exception.ApiError('group %s already exists' % group_name)
+
+ group = {'user_id' : context.user.id,
+ 'project_id': context.project.id,
+ 'name': group_name,
+ 'description': group_description}
+ group_ref = db.security_group_create(context, group)
+
+ return {'securityGroupSet': [self._format_security_group(context,
+ group_ref)]}
+
+
def delete_security_group(self, context, group_name, **kwargs):
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+ db.security_group_destroy(context, security_group.id)
return True
+
def get_console_output(self, context, instance_id, **kwargs):
# instance_id is passed in as a list of instances
ec2_id = instance_id[0]
@@ -554,6 +743,18 @@ class CloudController(object):
"project_id": context.project.id}})
return db.queue_get_for(context, FLAGS.network_topic, host)
+ def _ensure_default_security_group(self, context):
+ try:
+ db.security_group_get_by_name(context,
+ context.project.id,
+ 'default')
+ except exception.NotFound:
+ values = { 'name' : 'default',
+ 'description' : 'default',
+ 'user_id' : context.user.id,
+ 'project_id' : context.project.id }
+ group = db.security_group_create(context, values)
+
def run_instances(self, context, **kwargs):
instance_type = kwargs.get('instance_type', 'm1.small')
if instance_type not in INSTANCE_TYPES:
@@ -601,8 +802,17 @@ class CloudController(object):
kwargs['key_name'])
key_data = key_pair_ref['public_key']
- # TODO: Get the real security group of launch in here
- security_group = "default"
+ security_group_arg = kwargs.get('security_group', ["default"])
+ if not type(security_group_arg) is list:
+ security_group_arg = [security_group_arg]
+
+ security_groups = []
+ self._ensure_default_security_group(context)
+ for security_group_name in security_group_arg:
+ group = db.security_group_get_by_name(context,
+ context.project.id,
+ security_group_name)
+ security_groups.append(group['id'])
reservation_id = utils.generate_uid('r')
base_options = {}
@@ -616,12 +826,12 @@ class CloudController(object):
base_options['user_id'] = context.user.id
base_options['project_id'] = context.project.id
base_options['user_data'] = kwargs.get('user_data', '')
- base_options['security_group'] = security_group
- base_options['instance_type'] = instance_type
+
base_options['display_name'] = kwargs.get('display_name')
base_options['display_description'] = kwargs.get('display_description')
type_data = INSTANCE_TYPES[instance_type]
+ base_options['instance_type'] = instance_type
base_options['memory_mb'] = type_data['memory_mb']
base_options['vcpus'] = type_data['vcpus']
base_options['local_gb'] = type_data['local_gb']
@@ -630,6 +840,10 @@ class CloudController(object):
instance_ref = db.instance_create(context, base_options)
inst_id = instance_ref['id']
+ for security_group_id in security_groups:
+ db.instance_add_security_group(context, inst_id,
+ security_group_id)
+
inst = {}
inst['mac_address'] = utils.generate_mac()
inst['launch_index'] = num
diff --git a/nova/auth/manager.py b/nova/auth/manager.py
index 49235c910..58e33969b 100644
--- a/nova/auth/manager.py
+++ b/nova/auth/manager.py
@@ -490,6 +490,7 @@ class AuthManager(object):
except:
drv.delete_project(project.id)
raise
+
return project
def modify_project(self, project, manager_user=None, description=None):
@@ -565,6 +566,7 @@ class AuthManager(object):
except:
logging.exception('Could not destroy network for %s',
project)
+
with self.driver() as drv:
drv.delete_project(Project.safe_id(project))
diff --git a/nova/compute/manager.py b/nova/compute/manager.py
index 99705d3a9..ef7e9da6f 100644
--- a/nova/compute/manager.py
+++ b/nova/compute/manager.py
@@ -64,6 +64,11 @@ class ComputeManager(manager.Manager):
@defer.inlineCallbacks
@exception.wrap_exception
+ def refresh_security_group(self, context, security_group_id, **_kwargs):
+ yield self.driver.refresh_security_group(security_group_id)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
def run_instance(self, context, instance_id, **_kwargs):
"""Launch a new instance with specified options."""
instance_ref = self.db.instance_get(context, instance_id)
diff --git a/nova/db/api.py b/nova/db/api.py
index 2f0879c5a..4be8df397 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -304,6 +304,11 @@ def instance_update(context, instance_id, values):
return IMPL.instance_update(context, instance_id, values)
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ return IMPL.instance_add_security_group(context, instance_id, security_group_id)
+
+
###################
@@ -571,6 +576,71 @@ def volume_update(context, volume_id, values):
return IMPL.volume_update(context, volume_id, values)
+####################
+
+
+def security_group_get_all(context):
+ """Get all security groups"""
+ return IMPL.security_group_get_all(context)
+
+
+def security_group_get(context, security_group_id):
+ """Get security group by its internal id"""
+ return IMPL.security_group_get(context, security_group_id)
+
+
+def security_group_get_by_name(context, project_id, group_name):
+ """Returns a security group with the specified name from a project"""
+ return IMPL.security_group_get_by_name(context, project_id, group_name)
+
+
+def security_group_get_by_project(context, project_id):
+ """Get all security groups belonging to a project"""
+ return IMPL.security_group_get_by_project(context, project_id)
+
+
+def security_group_get_by_instance(context, instance_id):
+ """Get security groups to which the instance is assigned"""
+ return IMPL.security_group_get_by_instance(context, instance_id)
+
+
+def security_group_exists(context, project_id, group_name):
+ """Indicates if a group name exists in a project"""
+ return IMPL.security_group_exists(context, project_id, group_name)
+
+
+def security_group_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_create(context, values)
+
+
+def security_group_destroy(context, security_group_id):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy(context, security_group_id)
+
+
+def security_group_destroy_all(context):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy_all(context)
+
+
+####################
+
+
+def security_group_rule_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_rule_create(context, values)
+
+
+def security_group_rule_get_by_security_group(context, security_group_id):
+ """Get all rules for a a given security group"""
+ return IMPL.security_group_rule_get_by_security_group(context, security_group_id)
+
+def security_group_rule_destroy(context, security_group_rule_id):
+ """Deletes a security group rule"""
+ return IMPL.security_group_rule_destroy(context, security_group_rule_id)
+
+
###################
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 6f1ea7c23..50d802774 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -29,8 +29,11 @@ from nova.db.sqlalchemy import models
from nova.db.sqlalchemy.session import get_session
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm import joinedload, joinedload_all
-from sqlalchemy.sql import exists, func
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload_all
+from sqlalchemy.sql import exists
+from sqlalchemy.sql import func
+from sqlalchemy.orm.exc import NoResultFound
FLAGS = flags.FLAGS
@@ -571,11 +574,13 @@ def instance_get(context, instance_id, session=None):
if is_admin_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(id=instance_id
).filter_by(deleted=can_read_deleted(context)
).first()
elif is_user_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(id=instance_id
).filter_by(deleted=False
@@ -591,6 +596,7 @@ def instance_get_all(context):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -600,6 +606,7 @@ def instance_get_all_by_user(context, user_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(deleted=can_read_deleted(context)
).filter_by(user_id=user_id
).all()
@@ -612,6 +619,7 @@ def instance_get_all_by_project(context, project_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(project_id=project_id
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -624,12 +632,14 @@ def instance_get_all_by_reservation(context, reservation_id):
if is_admin_context(context):
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(reservation_id=reservation_id
).filter_by(deleted=can_read_deleted(context)
).all()
elif is_user_context(context):
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(reservation_id=reservation_id
).filter_by(deleted=False
@@ -642,11 +652,13 @@ def instance_get_by_internal_id(context, internal_id):
if is_admin_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(internal_id=internal_id
).filter_by(deleted=can_read_deleted(context)
).first()
elif is_user_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(internal_id=internal_id
).filter_by(deleted=False
@@ -718,6 +730,18 @@ def instance_update(context, instance_id, values):
instance_ref.save(session=session)
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ security_group_ref = security_group_get(context,
+ security_group_id,
+ session=session)
+ instance_ref.security_groups += [security_group_ref]
+ instance_ref.save(session=session)
+
+
###################
@@ -1192,6 +1216,7 @@ def volume_get(context, volume_id, session=None):
@require_admin_context
def volume_get_all(context):
+ session = get_session()
return session.query(models.Volume
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -1282,6 +1307,163 @@ def volume_update(context, volume_id, values):
###################
+@require_context
+def security_group_get_all(context):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('rules')
+ ).all()
+
+
+@require_context
+def security_group_get(context, security_group_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=can_read_deleted(context),
+ ).filter_by(id=security_group_id
+ ).options(joinedload_all('rules')
+ ).first()
+ else:
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_id
+ ).filter_by(project_id=context.project_id
+ ).options(joinedload_all('rules')
+ ).first()
+ if not result:
+ raise exception.NotFound("No secuity group with id %s" %
+ security_group_id)
+ return result
+
+
+@require_context
+def security_group_get_by_name(context, project_id, group_name):
+ session = get_session()
+ result = session.query(models.SecurityGroup
+ ).filter_by(project_id=project_id
+ ).filter_by(name=group_name
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).options(joinedload_all('instances')
+ ).first()
+ if not result:
+ raise exception.NotFound(
+ 'No security group named %s for project: %s' \
+ % (group_name, project_id))
+ return result
+
+
+@require_context
+def security_group_get_by_project(context, project_id):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).all()
+
+
+@require_context
+def security_group_get_by_instance(context, instance_id):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).join(models.SecurityGroup.instances
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def security_group_exists(context, project_id, group_name):
+ try:
+ group = security_group_get_by_name(context, project_id, group_name)
+ return group != None
+ except exception.NotFound:
+ return False
+
+
+@require_context
+def security_group_create(context, values):
+ security_group_ref = models.SecurityGroup()
+ # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception
+ # once save() is called. This will get cleaned up in next orm pass.
+ security_group_ref.rules
+ for (key, value) in values.iteritems():
+ security_group_ref[key] = value
+ security_group_ref.save()
+ return security_group_ref
+
+
+@require_context
+def security_group_destroy(context, security_group_id):
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_groups set deleted=1 where id=:id',
+ {'id': security_group_id})
+ session.execute('update security_group_rules set deleted=1 '
+ 'where group_id=:id',
+ {'id': security_group_id})
+
+@require_context
+def security_group_destroy_all(context, session=None):
+ if not session:
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_groups set deleted=1')
+ session.execute('update security_group_rules set deleted=1')
+
+
+###################
+
+
+@require_context
+def security_group_rule_get(context, security_group_rule_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=can_read_deleted(context)
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ else:
+ # TODO(vish): Join to group and check for project_id
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ if not result:
+ raise exception.NotFound("No secuity group rule with id %s" %
+ security_group_rule_id)
+ return result
+
+
+@require_context
+def security_group_rule_create(context, values):
+ security_group_rule_ref = models.SecurityGroupIngressRule()
+ for (key, value) in values.iteritems():
+ security_group_rule_ref[key] = value
+ security_group_rule_ref.save()
+ return security_group_rule_ref
+
+@require_context
+def security_group_rule_destroy(context, security_group_rule_id):
+ session = get_session()
+ with session.begin():
+ security_group_rule = security_group_rule_get(context,
+ security_group_rule_id,
+ session=session)
+ security_group_rule.delete(session=session)
+
+
+###################
+
@require_admin_context
def user_get(context, id, session=None):
if not session:
@@ -1491,6 +1673,8 @@ def user_add_project_role(context, user_id, project_id, role):
###################
+
+@require_admin_context
def host_get_networks(context, host):
session = get_session()
with session.begin():
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index 9809eb7a7..7dfc39f6f 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -187,7 +187,6 @@ class Instance(BASE, NovaBase):
launch_index = Column(Integer)
key_name = Column(String(255))
key_data = Column(Text)
- security_group = Column(String(255))
state = Column(Integer)
state_description = Column(String(255))
@@ -289,10 +288,66 @@ class ExportDevice(BASE, NovaBase):
'ExportDevice.deleted==False)')
+class SecurityGroupInstanceAssociation(BASE, NovaBase):
+ __tablename__ = 'security_group_instance_association'
+ id = Column(Integer, primary_key=True)
+ security_group_id = Column(Integer, ForeignKey('security_groups.id'))
+ instance_id = Column(Integer, ForeignKey('instances.id'))
+
+
+class SecurityGroup(BASE, NovaBase):
+ """Represents a security group"""
+ __tablename__ = 'security_groups'
+ id = Column(Integer, primary_key=True)
+
+ name = Column(String(255))
+ description = Column(String(255))
+ user_id = Column(String(255))
+ project_id = Column(String(255))
+
+ instances = relationship(Instance,
+ secondary="security_group_instance_association",
+ primaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id,"
+ "SecurityGroup.deleted == False)",
+ secondaryjoin="and_(SecurityGroupInstanceAssociation.instance_id == Instance.id,"
+ "Instance.deleted == False)",
+ backref='security_groups')
+
+ @property
+ def user(self):
+ return auth.manager.AuthManager().get_user(self.user_id)
+
+ @property
+ def project(self):
+ return auth.manager.AuthManager().get_project(self.project_id)
+
+
+class SecurityGroupIngressRule(BASE, NovaBase):
+ """Represents a rule in a security group"""
+ __tablename__ = 'security_group_rules'
+ id = Column(Integer, primary_key=True)
+
+ parent_group_id = Column(Integer, ForeignKey('security_groups.id'))
+ parent_group = relationship("SecurityGroup", backref="rules",
+ foreign_keys=parent_group_id,
+ primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id,"
+ "SecurityGroupIngressRule.deleted == False)")
+
+ protocol = Column(String(5)) # "tcp", "udp", or "icmp"
+ from_port = Column(Integer)
+ to_port = Column(Integer)
+ cidr = Column(String(255))
+
+ # Note: This is not the parent SecurityGroup. It's SecurityGroup we're
+ # granting access for.
+ group_id = Column(Integer, ForeignKey('security_groups.id'))
+
+
class KeyPair(BASE, NovaBase):
"""Represents a public key pair for ssh"""
__tablename__ = 'key_pairs'
id = Column(Integer, primary_key=True)
+
name = Column(String(255))
user_id = Column(String(255))
@@ -461,9 +516,10 @@ class FloatingIp(BASE, NovaBase):
def register_models():
"""Register Models and create metadata"""
from sqlalchemy import create_engine
- models = (Service, Instance, Volume, ExportDevice,
- FixedIp, FloatingIp, Network, NetworkIndex,
- AuthToken, UserProjectAssociation, User, Project) # , Image, Host)
+ models = (Service, Instance, Volume, ExportDevice, FixedIp,
+ FloatingIp, Network, NetworkIndex, SecurityGroup,
+ SecurityGroupIngressRule, SecurityGroupInstanceAssociation,
+ AuthToken, User, Project) # , Image, Host
engine = create_engine(FLAGS.sql_connection, echo=False)
for model in models:
model.metadata.create_all(engine)
diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py
index 69a205378..826754f6a 100644
--- a/nova/db/sqlalchemy/session.py
+++ b/nova/db/sqlalchemy/session.py
@@ -36,7 +36,8 @@ def get_session(autocommit=True, expire_on_commit=False):
if not _MAKER:
if not _ENGINE:
_ENGINE = create_engine(FLAGS.sql_connection, echo=False)
- _MAKER = sessionmaker(bind=_ENGINE,
- autocommit=autocommit,
- expire_on_commit=expire_on_commit)
- return _MAKER()
+ _MAKER = (sessionmaker(bind=_ENGINE,
+ autocommit=autocommit,
+ expire_on_commit=expire_on_commit))
+ session = _MAKER()
+ return session
diff --git a/nova/exception.py b/nova/exception.py
index b8894758f..f157fab2d 100644
--- a/nova/exception.py
+++ b/nova/exception.py
@@ -69,6 +69,9 @@ class NotEmpty(Error):
class Invalid(Error):
pass
+class InvalidInputException(Error):
+ pass
+
def wrap_exception(f):
def _wrap(*args, **kw):
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 9c1846dd9..093a6be9a 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -211,7 +211,6 @@ class FlatManager(NetworkManager):
# in the datastore?
net = {}
net['injected'] = True
- net['network_str'] = FLAGS.flat_network_network
net['netmask'] = FLAGS.flat_network_netmask
net['bridge'] = FLAGS.flat_network_bridge
net['gateway'] = FLAGS.flat_network_gateway
diff --git a/nova/process.py b/nova/process.py
index b3cad894b..13cb90e82 100644
--- a/nova/process.py
+++ b/nova/process.py
@@ -113,7 +113,7 @@ class BackRelayWithInput(protocol.ProcessProtocol):
if self.started_deferred:
self.started_deferred.callback(self)
if self.process_input:
- self.transport.write(self.process_input)
+ self.transport.write(str(self.process_input))
self.transport.closeStdin()
def get_process_output(executable, args=None, env=None, path=None,
diff --git a/nova/test.py b/nova/test.py
index 1f4b33272..08e1dea2d 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -31,6 +31,7 @@ from tornado import ioloop
from twisted.internet import defer
from twisted.trial import unittest
+from nova import db
from nova import fakerabbit
from nova import flags
from nova import rpc
@@ -83,6 +84,7 @@ class TrialTestCase(unittest.TestCase):
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
+ db.security_group_destroy_all(None)
super(TrialTestCase, self).tearDown()
diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py
index c040cdad3..7ab27e000 100644
--- a/nova/tests/api_unittest.py
+++ b/nova/tests/api_unittest.py
@@ -91,6 +91,9 @@ class ApiEc2TestCase(test.BaseTestCase):
self.host = '127.0.0.1'
self.app = api.API()
+
+ def expect_http(self, host=None, is_secure=False):
+ """Returns a new EC2 connection"""
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
@@ -100,9 +103,6 @@ class ApiEc2TestCase(test.BaseTestCase):
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
-
- def expect_http(self, host=None, is_secure=False):
- """Returns a new EC2 connection"""
http = FakeHttplibConnection(
self.app, '%s:8773' % (self.host), False)
# pylint: disable-msg=E1103
@@ -138,3 +138,185 @@ class ApiEc2TestCase(test.BaseTestCase):
self.assertEquals(len(results), 1)
self.manager.delete_project(project)
self.manager.delete_user(user)
+
+ def test_get_all_security_groups(self):
+ """Test that we can retrieve security groups"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEquals(len(rv), 1)
+ self.assertEquals(rv[0].name, 'default')
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_create_delete_security_group(self):
+ """Test that we can create a security group"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ self.assertEquals(len(rv), 2)
+ self.assertTrue(security_group_name in [group.name for group in rv])
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_authorize_revoke_security_group_cidr(self):
+ """
+ Test that we can add and remove CIDR based rules
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake')
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(int(group.rules[0].from_port), 80)
+ self.assertEquals(int(group.rules[0].to_port), 81)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.revoke('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEqual(len(rv), 1)
+ self.assertEqual(rv[0].name, 'default')
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ return
+
+ def test_authorize_revoke_security_group_foreign_group(self):
+ """
+ Test that we can grant and revoke another security group access
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+ other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ other_group = self.ec2.create_security_group(other_security_group_name,
+ 'some other group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]),
+ '%s-%s' % (other_security_group_name, 'fake'))
+
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ for group in rv:
+ if group.name == security_group_name:
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+ group.revoke(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ return
diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py
index eb2ee0406..872f1ab23 100644
--- a/nova/tests/objectstore_unittest.py
+++ b/nova/tests/objectstore_unittest.py
@@ -210,7 +210,7 @@ class S3APITestCase(test.TrialTestCase):
"""Setup users, projects, and start a test server."""
super(S3APITestCase, self).setUp()
- FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver',
+ FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets')
self.auth_manager = manager.AuthManager()
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index 730928f39..684347473 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -14,11 +14,16 @@
# License for the specific language governing permissions and limitations
# under the License.
-from xml.etree.ElementTree import fromstring as parseXml
+from xml.etree.ElementTree import fromstring as xml_to_tree
+from xml.dom.minidom import parseString as xml_to_dom
+from nova import db
from nova import flags
from nova import test
+from nova.api import context
+from nova.api.ec2 import cloud
from nova.auth import manager
+
# Needed to get FLAGS.instances_path defined:
from nova.compute import manager as compute_manager
from nova.virt import libvirt_conn
@@ -33,34 +38,49 @@ class LibvirtConnTestCase(test.TrialTestCase):
FLAGS.instances_path = ''
def test_get_uri_and_template(self):
- instance = { 'name' : 'i-cafebabe',
- 'id' : 'i-cafebabe',
+ ip = '10.11.12.13'
+
+ instance = { 'internal_id' : 1,
'memory_kb' : '1024000',
'basepath' : '/some/path',
'bridge_name' : 'br100',
'mac_address' : '02:12:34:46:56:67',
'vcpus' : 2,
'project_id' : 'fake',
- 'ip_address' : '10.11.12.13',
'bridge' : 'br101',
'instance_type' : 'm1.small'}
+ instance_ref = db.instance_create(None, instance)
+ network_ref = db.project_get_network(None, self.project.id)
+
+ fixed_ip = { 'address' : ip,
+ 'network_id' : network_ref['id'] }
+
+ fixed_ip_ref = db.fixed_ip_create(None, fixed_ip)
+ db.fixed_ip_update(None, ip, { 'allocated' : True,
+ 'instance_id' : instance_ref['id'] })
+
type_uri_map = { 'qemu' : ('qemu:///system',
- [(lambda t: t.find('.').tag, 'domain'),
- (lambda t: t.find('.').get('type'), 'qemu'),
+ [(lambda t: t.find('.').get('type'), 'qemu'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'kvm' : ('qemu:///system',
- [(lambda t: t.find('.').tag, 'domain'),
- (lambda t: t.find('.').get('type'), 'kvm'),
+ [(lambda t: t.find('.').get('type'), 'kvm'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'uml' : ('uml:///system',
- [(lambda t: t.find('.').tag, 'domain'),
- (lambda t: t.find('.').get('type'), 'uml'),
+ [(lambda t: t.find('.').get('type'), 'uml'),
(lambda t: t.find('./os/type').text, 'uml')]),
}
+ common_checks = [(lambda t: t.find('.').tag, 'domain'),
+ (lambda t: \
+ t.find('./devices/interface/filterref/parameter') \
+ .get('name'), 'IP'),
+ (lambda t: \
+ t.find('./devices/interface/filterref/parameter') \
+ .get('value'), '10.11.12.13')]
+
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True)
@@ -68,13 +88,18 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, expected_uri)
- xml = conn.to_xml(instance)
- tree = parseXml(xml)
+ xml = conn.to_xml(instance_ref)
+ tree = xml_to_tree(xml)
for i, (check, expected_result) in enumerate(checks):
self.assertEqual(check(tree),
expected_result,
'%s failed check %d' % (xml, i))
+ for i, (check, expected_result) in enumerate(common_checks):
+ self.assertEqual(check(tree),
+ expected_result,
+ '%s failed common check %d' % (xml, i))
+
# Deliberately not just assigning this string to FLAGS.libvirt_uri and
# checking against that later on. This way we make sure the
# implementation doesn't fiddle around with the FLAGS.
@@ -90,3 +115,138 @@ class LibvirtConnTestCase(test.TrialTestCase):
def tearDown(self):
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
+
+class NWFilterTestCase(test.TrialTestCase):
+ def setUp(self):
+ super(NWFilterTestCase, self).setUp()
+
+ class Mock(object):
+ pass
+
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.context = context.APIRequestContext(self.user, self.project)
+
+ self.fake_libvirt_connection = Mock()
+
+ self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
+
+ def tearDown(self):
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+
+
+ def test_cidr_rule_nwfilter_xml(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+
+ security_group = db.security_group_get_by_name(self.context,
+ 'fake',
+ 'testgroup')
+
+ xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
+
+ dom = xml_to_dom(xml)
+ self.assertEqual(dom.firstChild.tagName, 'filter')
+
+ rules = dom.getElementsByTagName('rule')
+ self.assertEqual(len(rules), 1)
+
+ # It's supposed to allow inbound traffic.
+ self.assertEqual(rules[0].getAttribute('action'), 'accept')
+ self.assertEqual(rules[0].getAttribute('direction'), 'in')
+
+ # Must be lower priority than the base filter (which blocks everything)
+ self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)
+
+ ip_conditions = rules[0].getElementsByTagName('tcp')
+ self.assertEqual(len(ip_conditions), 1)
+ self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0')
+ self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
+
+
+ self.teardown_security_group()
+
+ def teardown_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.delete_security_group(self.context, 'testgroup')
+
+
+ def setup_and_return_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+ return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
+
+ def test_creates_base_rule_first(self):
+ # These come pre-defined by libvirt
+ self.defined_filters = ['no-mac-spoofing',
+ 'no-ip-spoofing',
+ 'no-arp-spoofing',
+ 'allow-dhcp-server']
+
+ self.recursive_depends = {}
+ for f in self.defined_filters:
+ self.recursive_depends[f] = []
+
+ def _filterDefineXMLMock(xml):
+ dom = xml_to_dom(xml)
+ name = dom.firstChild.getAttribute('name')
+ self.recursive_depends[name] = []
+ for f in dom.getElementsByTagName('filterref'):
+ ref = f.getAttribute('filter')
+ self.assertTrue(ref in self.defined_filters,
+ ('%s referenced filter that does ' +
+ 'not yet exist: %s') % (name, ref))
+ dependencies = [ref] + self.recursive_depends[ref]
+ self.recursive_depends[name] += dependencies
+
+ self.defined_filters.append(name)
+ return True
+
+ self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
+
+ instance_ref = db.instance_create(self.context,
+ {'user_id': 'fake',
+ 'project_id': 'fake'})
+ inst_id = instance_ref['id']
+
+ def _ensure_all_called(_):
+ instance_filter = 'nova-instance-%s' % instance_ref['name']
+ secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
+ for required in [secgroup_filter, 'allow-dhcp-server',
+ 'no-arp-spoofing', 'no-ip-spoofing',
+ 'no-mac-spoofing']:
+ self.assertTrue(required in self.recursive_depends[instance_filter],
+ "Instance's filter does not include %s" % required)
+
+ self.security_group = self.setup_and_return_security_group()
+
+ db.instance_add_security_group(self.context, inst_id, self.security_group.id)
+ instance = db.instance_get(self.context, inst_id)
+
+ d = self.fw.setup_nwfilters_for_instance(instance)
+ d.addCallback(_ensure_all_called)
+ d.addCallback(lambda _:self.teardown_security_group())
+
+ return d
diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template
index 11df301f6..87b92b84a 100644
--- a/nova/virt/interfaces.template
+++ b/nova/virt/interfaces.template
@@ -10,7 +10,6 @@ auto eth0
iface eth0 inet static
address %(address)s
netmask %(netmask)s
- network %(network)s
broadcast %(broadcast)s
gateway %(gateway)s
dns-nameservers %(dns)s
diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template
index 17bd79b7c..2538b1ade 100644
--- a/nova/virt/libvirt.qemu.xml.template
+++ b/nova/virt/libvirt.qemu.xml.template
@@ -20,6 +20,10 @@
<source bridge='%(bridge_name)s'/>
<mac address='%(mac_address)s'/>
<!-- <model type='virtio'/> CANT RUN virtio network right now -->
+ <filterref filter="nova-instance-%(name)s">
+ <parameter name="IP" value="%(ip_address)s" />
+ <parameter name="DHCPSERVER" value="%(dhcp_server)s" />
+ </filterref>
</interface>
<serial type="file">
<source path='%(basepath)s/console.log'/>
diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template
index c039d6d90..bb8b47911 100644
--- a/nova/virt/libvirt.uml.xml.template
+++ b/nova/virt/libvirt.uml.xml.template
@@ -14,6 +14,10 @@
<interface type='bridge'>
<source bridge='%(bridge_name)s'/>
<mac address='%(mac_address)s'/>
+ <filterref filter="nova-instance-%(name)s">
+ <parameter name="IP" value="%(ip_address)s" />
+ <parameter name="DHCPSERVER" value="%(dhcp_server)s" />
+ </filterref>
</interface>
<console type="file">
<source path='%(basepath)s/console.log'/>
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index d868e083c..319f7d2af 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -25,14 +25,17 @@ import logging
import os
import shutil
+import IPy
from twisted.internet import defer
from twisted.internet import task
+from twisted.internet import threads
from nova import db
from nova import exception
from nova import flags
from nova import process
from nova import utils
+#from nova.api import context
from nova.auth import manager
from nova.compute import disk
from nova.compute import instance_types
@@ -60,6 +63,9 @@ flags.DEFINE_string('libvirt_uri',
'',
'Override the default libvirt URI (which is dependent'
' on libvirt_type)')
+flags.DEFINE_bool('allow_project_net_traffic',
+ True,
+ 'Whether to allow in project network traffic')
def get_connection(read_only):
@@ -134,7 +140,7 @@ class LibvirtConnection(object):
d.addCallback(lambda _: self._cleanup(instance))
# FIXME: What does this comment mean?
# TODO(termie): short-circuit me for tests
- # WE'LL save this for when we do shutdown,
+ # WE'LL save this for when we do shutdown,
# instead of destroy - but destroy returns immediately
timer = task.LoopingCall(f=None)
def _wait_for_shutdown():
@@ -214,6 +220,7 @@ class LibvirtConnection(object):
instance['id'],
power_state.NOSTATE,
'launching')
+ yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance)
yield self._create_image(instance, xml)
yield self._conn.createXML(xml, 0)
# TODO(termie): this should actually register
@@ -285,7 +292,6 @@ class LibvirtConnection(object):
address = db.instance_get_fixed_address(None, inst['id'])
with open(FLAGS.injected_network_template) as f:
net = f.read() % {'address': address,
- 'network': network_ref['network'],
'netmask': network_ref['netmask'],
'gateway': network_ref['gateway'],
'broadcast': network_ref['broadcast'],
@@ -317,6 +323,9 @@ class LibvirtConnection(object):
network = db.project_get_network(None, instance['project_id'])
# FIXME(vish): stick this in db
instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']]
+ ip_address = db.instance_get_fixed_address({}, instance['id'])
+ # Assume that the gateway also acts as the dhcp server.
+ dhcp_server = network['gateway']
xml_info = {'type': FLAGS.libvirt_type,
'name': instance['name'],
'basepath': os.path.join(FLAGS.instances_path,
@@ -324,7 +333,9 @@ class LibvirtConnection(object):
'memory_kb': instance_type['memory_mb'] * 1024,
'vcpus': instance_type['vcpus'],
'bridge_name': network['bridge'],
- 'mac_address': instance['mac_address']}
+ 'mac_address': instance['mac_address'],
+ 'ip_address': ip_address,
+ 'dhcp_server': dhcp_server }
libvirt_xml = self.libvirt_xml % xml_info
logging.debug('instance %s: finished toXML method', instance['name'])
@@ -438,3 +449,195 @@ class LibvirtConnection(object):
"""
domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface)
+
+
+ def refresh_security_group(self, security_group_id):
+ fw = NWFilterFirewall(self._conn)
+ fw.ensure_security_group_filter(security_group_id)
+
+
+class NWFilterFirewall(object):
+ """
+ This class implements a network filtering mechanism versatile
+ enough for EC2 style Security Group filtering by leveraging
+ libvirt's nwfilter.
+
+ First, all instances get a filter ("nova-base-filter") applied.
+ This filter drops all incoming ipv4 and ipv6 connections.
+ Outgoing connections are never blocked.
+
+ Second, every security group maps to a nwfilter filter(*).
+ NWFilters can be updated at runtime and changes are applied
+ immediately, so changes to security groups can be applied at
+ runtime (as mandated by the spec).
+
+ Security group rules are named "nova-secgroup-<id>" where <id>
+ is the internal id of the security group. They're applied only on
+ hosts that have instances in the security group in question.
+
+ Updates to security groups are done by updating the data model
+ (in response to API calls) followed by a request sent to all
+ the nodes with instances in the security group to refresh the
+ security group.
+
+ Each instance has its own NWFilter, which references the above
+ mentioned security group NWFilters. This was done because
+ interfaces can only reference one filter while filters can
+ reference multiple other filters. This has the added benefit of
+ actually being able to add and remove security groups from an
+ instance at run time. This functionality is not exposed anywhere,
+ though.
+
+ Outstanding questions:
+
+ The name is unique, so would there be any good reason to sync
+ the uuid across the nodes (by assigning it from the datamodel)?
+
+
+ (*) This sentence brought to you by the redundancy department of
+ redundancy.
+ """
+
+ def __init__(self, get_connection):
+ self._conn = get_connection
+
+
+ nova_base_filter = '''<filter name='nova-base' chain='root'>
+ <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid>
+ <filterref filter='no-mac-spoofing'/>
+ <filterref filter='no-ip-spoofing'/>
+ <filterref filter='no-arp-spoofing'/>
+ <filterref filter='allow-dhcp-server'/>
+ <filterref filter='nova-allow-dhcp-server'/>
+ <filterref filter='nova-base-ipv4'/>
+ <filterref filter='nova-base-ipv6'/>
+ </filter>'''
+
+ nova_dhcp_filter = '''<filter name='nova-allow-dhcp-server' chain='ipv4'>
+ <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid>
+ <rule action='accept' direction='out'
+ priority='100'>
+ <udp srcipaddr='0.0.0.0'
+ dstipaddr='255.255.255.255'
+ srcportstart='68'
+ dstportstart='67'/>
+ </rule>
+ <rule action='accept' direction='in' priority='100'>
+ <udp srcipaddr='$DHCPSERVER'
+ srcportstart='67'
+ dstportstart='68'/>
+ </rule>
+ </filter>'''
+
+ def nova_base_ipv4_filter(self):
+ retval = "<filter name='nova-base-ipv4' chain='ipv4'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept', 399),
+ ('inout','drop', 400)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def nova_base_ipv6_filter(self):
+ retval = "<filter name='nova-base-ipv6' chain='ipv6'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept',399),
+ ('inout','drop',400)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s-ipv6 />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def nova_project_filter(self, project, net, mask):
+ retval = "<filter name='nova-project-%s' chain='ipv4'>" % project
+ for protocol in ['tcp', 'udp', 'icmp']:
+ retval += """<rule action='accept' direction='in' priority='200'>
+ <%s srcipaddr='%s' srcipmask='%s' />
+ </rule>""" % (protocol, net, mask)
+ retval += '</filter>'
+ return retval
+
+
+ def _define_filter(self, xml):
+ if callable(xml):
+ xml = xml()
+ d = threads.deferToThread(self._conn.nwfilterDefineXML, xml)
+ return d
+
+
+ @staticmethod
+ def _get_net_and_mask(cidr):
+ net = IPy.IP(cidr)
+ return str(net.net()), str(net.netmask())
+
+ @defer.inlineCallbacks
+ def setup_nwfilters_for_instance(self, instance):
+ """
+ Creates an NWFilter for the given instance. In the process,
+ it makes sure the filters for the security groups as well as
+ the base filter are all in place.
+ """
+
+ yield self._define_filter(self.nova_base_ipv4_filter)
+ yield self._define_filter(self.nova_base_ipv6_filter)
+ yield self._define_filter(self.nova_dhcp_filter)
+ yield self._define_filter(self.nova_base_filter)
+
+ nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" +
+ " <filterref filter='nova-base' />\n"
+ ) % instance['name']
+
+ if FLAGS.allow_project_net_traffic:
+ network_ref = db.project_get_network({}, instance['project_id'])
+ net, mask = self._get_net_and_mask(network_ref['cidr'])
+ project_filter = self.nova_project_filter(instance['project_id'],
+ net, mask)
+ yield self._define_filter(project_filter)
+
+ nwfilter_xml += (" <filterref filter='nova-project-%s' />\n"
+ ) % instance['project_id']
+
+ for security_group in instance.security_groups:
+ yield self.ensure_security_group_filter(security_group['id'])
+
+ nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n"
+ ) % security_group['id']
+ nwfilter_xml += "</filter>"
+
+ yield self._define_filter(nwfilter_xml)
+ return
+
+ def ensure_security_group_filter(self, security_group_id):
+ return self._define_filter(
+ self.security_group_to_nwfilter_xml(security_group_id))
+
+
+ def security_group_to_nwfilter_xml(self, security_group_id):
+ security_group = db.security_group_get({}, security_group_id)
+ rule_xml = ""
+ for rule in security_group.rules:
+ rule_xml += "<rule action='accept' direction='in' priority='300'>"
+ if rule.cidr:
+ net, mask = self._get_net_and_mask(rule.cidr)
+ rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask)
+ if rule.protocol in ['tcp', 'udp']:
+ rule_xml += "dstportstart='%s' dstportend='%s' " % \
+ (rule.from_port, rule.to_port)
+ elif rule.protocol == 'icmp':
+ logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port))
+ if rule.from_port != -1:
+ rule_xml += "type='%s' " % rule.from_port
+ if rule.to_port != -1:
+ rule_xml += "code='%s' " % rule.to_port
+
+ rule_xml += '/>\n'
+ rule_xml += "</rule>\n"
+ xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,)
+ return xml
diff --git a/run_tests.py b/run_tests.py
index fa1e6f15b..0b27ec6cf 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -65,6 +65,7 @@ from nova.tests.service_unittest import *
from nova.tests.validator_unittest import *
from nova.tests.virt_unittest import *
from nova.tests.volume_unittest import *
+from nova.tests.virt_unittest import *
FLAGS = flags.FLAGS