summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSoren Hansen <soren.hansen@rackspace.com>2010-10-04 21:58:22 +0200
committerSoren Hansen <soren.hansen@rackspace.com>2010-10-04 21:58:22 +0200
commita4720c03a8260fb920035d072799d3ecc478db99 (patch)
tree03fd5726225a9ea6756bfbcc35996e6c19617346
parent3543d8430e02c1b22f1932cb9d0af028d9ef648b (diff)
Merge security group related changes from lp:~anso/nova/deploy
-rw-r--r--nova/api/ec2/cloud.py31
-rw-r--r--nova/db/sqlalchemy/api.py105
-rw-r--r--nova/tests/virt_unittest.py33
-rw-r--r--nova/virt/libvirt_conn.py39
4 files changed, 162 insertions, 46 deletions
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py
index 839b84b4e..4cd4c78ae 100644
--- a/nova/api/ec2/cloud.py
+++ b/nova/api/ec2/cloud.py
@@ -327,6 +327,26 @@ class CloudController(object):
return values
+
+ def _security_group_rule_exists(self, security_group, values):
+ """Indicates whether the specified rule values are already
+ defined in the given security group.
+ """
+ for rule in security_group.rules:
+ if 'group_id' in values:
+ if rule['group_id'] == values['group_id']:
+ return True
+ else:
+ is_duplicate = True
+ for key in ('cidr', 'from_port', 'to_port', 'protocol'):
+ if rule[key] != values[key]:
+ is_duplicate = False
+ break
+ if is_duplicate:
+ return True
+ return False
+
+
def revoke_security_group_ingress(self, context, group_name, **kwargs):
self._ensure_default_security_group(context)
security_group = db.security_group_get_by_name(context,
@@ -348,9 +368,6 @@ class CloudController(object):
return True
raise exception.ApiError("No rule for the specified parameters.")
- # TODO(soren): Dupe detection. Adding the same rule twice actually
- # adds the same rule twice to the rule set, which is
- # pointless.
# TODO(soren): This has only been tested with Boto as the client.
# Unfortunately, it seems Boto is using an old API
# for these operations, so support for newer API versions
@@ -364,6 +381,10 @@ class CloudController(object):
values = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
values['parent_group_id'] = security_group.id
+ if self._security_group_rule_exists(security_group, values):
+ raise exception.ApiError('This rule already exists in group %s' %
+ group_name)
+
security_group_rule = db.security_group_rule_create(context, values)
self._trigger_refresh_security_group(security_group)
@@ -709,7 +730,7 @@ class CloudController(object):
'description' : 'default',
'user_id' : context.user.id,
'project_id' : context.project.id }
- group = db.security_group_create({}, values)
+ group = db.security_group_create(context, values)
def run_instances(self, context, **kwargs):
instance_type = kwargs.get('instance_type', 'm1.small')
@@ -797,7 +818,7 @@ class CloudController(object):
inst_id = instance_ref['id']
for security_group_id in security_groups:
- db.instance_add_security_group(context, inst_id,
+ db.instance_add_security_group(context.admin(), inst_id,
security_group_id)
inst = {}
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index d395b7e2c..bc5ef5a9b 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -572,11 +572,13 @@ def instance_get(context, instance_id, session=None):
if is_admin_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(id=instance_id
).filter_by(deleted=can_read_deleted(context)
).first()
elif is_user_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(id=instance_id
).filter_by(deleted=False
@@ -592,6 +594,7 @@ def instance_get_all(context):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -601,6 +604,7 @@ def instance_get_all_by_user(context, user_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(deleted=can_read_deleted(context)
).filter_by(user_id=user_id
).all()
@@ -613,6 +617,7 @@ def instance_get_all_by_project(context, project_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(project_id=project_id
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -625,12 +630,14 @@ def instance_get_all_by_reservation(context, reservation_id):
if is_admin_context(context):
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(reservation_id=reservation_id
).filter_by(deleted=can_read_deleted(context)
).all()
elif is_user_context(context):
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(reservation_id=reservation_id
).filter_by(deleted=False
@@ -643,11 +650,13 @@ def instance_get_by_ec2_id(context, ec2_id):
if is_admin_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(ec2_id=ec2_id
).filter_by(deleted=can_read_deleted(context)
).first()
elif is_user_context(context):
result = session.query(models.Instance
+ ).options(joinedload('security_groups')
).filter_by(project_id=context.project.id
).filter_by(ec2_id=ec2_id
).filter_by(deleted=False
@@ -721,9 +730,10 @@ def instance_add_security_group(context, instance_id, security_group_id):
"""Associate the given security group with the given instance"""
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
- security_group_ref = models.SecurityGroup.find(security_group_id,
- session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
+ security_group_ref = security_group_get(context,
+ security_group_id,
+ session=session)
instance_ref.security_groups += [security_group_ref]
instance_ref.save(session=session)
@@ -1202,6 +1212,7 @@ def volume_get(context, volume_id, session=None):
@require_admin_context
def volume_get_all(context):
+ session = get_session()
return session.query(models.Volume
).filter_by(deleted=can_read_deleted(context)
).all()
@@ -1292,27 +1303,39 @@ def volume_update(context, volume_id, values):
###################
-def security_group_get_all(_context):
+@require_context
+def security_group_get_all(context):
session = get_session()
return session.query(models.SecurityGroup
- ).filter_by(deleted=False
+ ).filter_by(deleted=can_read_deleted(context)
).options(joinedload_all('rules')
).all()
-def security_group_get(_context, security_group_id):
- session = get_session()
- result = session.query(models.SecurityGroup
- ).filter_by(deleted=False
- ).filter_by(id=security_group_id
- ).options(joinedload_all('rules')
- ).first()
+@require_context
+def security_group_get(context, security_group_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=can_read_deleted(context),
+ ).filter_by(id=security_group_id
+ ).options(joinedload_all('rules')
+ ).first()
+ else:
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_id
+ ).filter_by(project_id=context.project_id
+ ).options(joinedload_all('rules')
+ ).first()
if not result:
raise exception.NotFound("No secuity group with id %s" %
security_group_id)
return result
+@require_context
def security_group_get_by_name(context, project_id, group_name):
session = get_session()
result = session.query(models.SecurityGroup
@@ -1329,7 +1352,8 @@ def security_group_get_by_name(context, project_id, group_name):
return result
-def security_group_get_by_project(_context, project_id):
+@require_context
+def security_group_get_by_project(context, project_id):
session = get_session()
return session.query(models.SecurityGroup
).filter_by(project_id=project_id
@@ -1338,7 +1362,8 @@ def security_group_get_by_project(_context, project_id):
).all()
-def security_group_get_by_instance(_context, instance_id):
+@require_context
+def security_group_get_by_instance(context, instance_id):
session = get_session()
return session.query(models.SecurityGroup
).filter_by(deleted=False
@@ -1349,15 +1374,17 @@ def security_group_get_by_instance(_context, instance_id):
).all()
-def security_group_exists(_context, project_id, group_name):
+@require_context
+def security_group_exists(context, project_id, group_name):
try:
- group = security_group_get_by_name(_context, project_id, group_name)
+ group = security_group_get_by_name(context, project_id, group_name)
return group != None
except exception.NotFound:
return False
-def security_group_create(_context, values):
+@require_context
+def security_group_create(context, values):
security_group_ref = models.SecurityGroup()
# FIXME(devcamcar): Unless I do this, rules fails with lazy load exception
# once save() is called. This will get cleaned up in next orm pass.
@@ -1368,7 +1395,8 @@ def security_group_create(_context, values):
return security_group_ref
-def security_group_destroy(_context, security_group_id):
+@require_context
+def security_group_destroy(context, security_group_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -1378,35 +1406,62 @@ def security_group_destroy(_context, security_group_id):
'where group_id=:id',
{'id': security_group_id})
-def security_group_destroy_all(_context):
- session = get_session()
+@require_context
+def security_group_destroy_all(context, session=None):
+ if not session:
+ session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
session.execute('update security_groups set deleted=1')
session.execute('update security_group_rules set deleted=1')
+
###################
-def security_group_rule_create(_context, values):
+@require_context
+def security_group_rule_get(context, security_group_rule_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=can_read_deleted(context)
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ else:
+ # TODO(vish): Join to group and check for project_id
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ if not result:
+ raise exception.NotFound("No secuity group rule with id %s" %
+ security_group_rule_id)
+ return result
+
+
+@require_context
+def security_group_rule_create(context, values):
security_group_rule_ref = models.SecurityGroupIngressRule()
for (key, value) in values.iteritems():
security_group_rule_ref[key] = value
security_group_rule_ref.save()
return security_group_rule_ref
-def security_group_rule_destroy(_context, security_group_rule_id):
+@require_context
+def security_group_rule_destroy(context, security_group_rule_id):
session = get_session()
with session.begin():
- model = models.SecurityGroupIngressRule
- security_group_rule = model.find(security_group_rule_id,
- session=session)
+ security_group_rule = security_group_rule_get(context,
+ security_group_rule_id,
+ session=session)
security_group_rule.delete(session=session)
###################
+@require_admin_context
def host_get_networks(context, host):
session = get_session()
with session.begin():
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index 7fa8e52ac..8b0de6c29 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -19,7 +19,9 @@ from xml.dom.minidom import parseString
from nova import db
from nova import flags
from nova import test
+from nova.api import context
from nova.api.ec2 import cloud
+from nova.auth import manager
from nova.virt import libvirt_conn
FLAGS = flags.FLAGS
@@ -83,17 +85,20 @@ class NWFilterTestCase(test.TrialTestCase):
class Mock(object):
pass
- self.context = Mock()
- self.context.user = Mock()
- self.context.user.id = 'fake'
- self.context.user.is_superuser = lambda:True
- self.context.project = Mock()
- self.context.project.id = 'fake'
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.context = context.APIRequestContext(self.user, self.project)
self.fake_libvirt_connection = Mock()
self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
+ def tearDown(self):
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+
+
def test_cidr_rule_nwfilter_xml(self):
cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context,
@@ -107,7 +112,9 @@ class NWFilterTestCase(test.TrialTestCase):
cidr_ip='0.0.0.0/0')
- security_group = db.security_group_get_by_name({}, 'fake', 'testgroup')
+ security_group = db.security_group_get_by_name(self.context,
+ 'fake',
+ 'testgroup')
xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
@@ -126,7 +133,8 @@ class NWFilterTestCase(test.TrialTestCase):
ip_conditions = rules[0].getElementsByTagName('tcp')
self.assertEqual(len(ip_conditions), 1)
- self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0')
+ self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0')
+ self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
@@ -150,7 +158,7 @@ class NWFilterTestCase(test.TrialTestCase):
ip_protocol='tcp',
cidr_ip='0.0.0.0/0')
- return db.security_group_get_by_name({}, 'fake', 'testgroup')
+ return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def test_creates_base_rule_first(self):
# These come pre-defined by libvirt
@@ -180,7 +188,8 @@ class NWFilterTestCase(test.TrialTestCase):
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
- instance_ref = db.instance_create({}, {'user_id': 'fake',
+ instance_ref = db.instance_create(self.context,
+ {'user_id': 'fake',
'project_id': 'fake'})
inst_id = instance_ref['id']
@@ -195,8 +204,8 @@ class NWFilterTestCase(test.TrialTestCase):
self.security_group = self.setup_and_return_security_group()
- db.instance_add_security_group({}, inst_id, self.security_group.id)
- instance = db.instance_get({}, inst_id)
+ db.instance_add_security_group(self.context, inst_id, self.security_group.id)
+ instance = db.instance_get(self.context, inst_id)
d = self.fw.setup_nwfilters_for_instance(instance)
d.addCallback(_ensure_all_called)
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index 9d889cf29..319f7d2af 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -25,6 +25,7 @@ import logging
import os
import shutil
+import IPy
from twisted.internet import defer
from twisted.internet import task
from twisted.internet import threads
@@ -34,6 +35,7 @@ from nova import exception
from nova import flags
from nova import process
from nova import utils
+#from nova.api import context
from nova.auth import manager
from nova.compute import disk
from nova.compute import instance_types
@@ -61,6 +63,9 @@ flags.DEFINE_string('libvirt_uri',
'',
'Override the default libvirt URI (which is dependent'
' on libvirt_type)')
+flags.DEFINE_bool('allow_project_net_traffic',
+ True,
+ 'Whether to allow in project network traffic')
def get_connection(read_only):
@@ -135,7 +140,7 @@ class LibvirtConnection(object):
d.addCallback(lambda _: self._cleanup(instance))
# FIXME: What does this comment mean?
# TODO(termie): short-circuit me for tests
- # WE'LL save this for when we do shutdown,
+ # WE'LL save this for when we do shutdown,
# instead of destroy - but destroy returns immediately
timer = task.LoopingCall(f=None)
def _wait_for_shutdown():
@@ -550,6 +555,16 @@ class NWFilterFirewall(object):
return retval
+ def nova_project_filter(self, project, net, mask):
+ retval = "<filter name='nova-project-%s' chain='ipv4'>" % project
+ for protocol in ['tcp', 'udp', 'icmp']:
+ retval += """<rule action='accept' direction='in' priority='200'>
+ <%s srcipaddr='%s' srcipmask='%s' />
+ </rule>""" % (protocol, net, mask)
+ retval += '</filter>'
+ return retval
+
+
def _define_filter(self, xml):
if callable(xml):
xml = xml()
@@ -557,6 +572,11 @@ class NWFilterFirewall(object):
return d
+ @staticmethod
+ def _get_net_and_mask(cidr):
+ net = IPy.IP(cidr)
+ return str(net.net()), str(net.netmask())
+
@defer.inlineCallbacks
def setup_nwfilters_for_instance(self, instance):
"""
@@ -570,9 +590,19 @@ class NWFilterFirewall(object):
yield self._define_filter(self.nova_dhcp_filter)
yield self._define_filter(self.nova_base_filter)
- nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" +
+ nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" +
" <filterref filter='nova-base' />\n"
- ) % instance['name']
+ ) % instance['name']
+
+ if FLAGS.allow_project_net_traffic:
+ network_ref = db.project_get_network({}, instance['project_id'])
+ net, mask = self._get_net_and_mask(network_ref['cidr'])
+ project_filter = self.nova_project_filter(instance['project_id'],
+ net, mask)
+ yield self._define_filter(project_filter)
+
+ nwfilter_xml += (" <filterref filter='nova-project-%s' />\n"
+ ) % instance['project_id']
for security_group in instance.security_groups:
yield self.ensure_security_group_filter(security_group['id'])
@@ -595,7 +625,8 @@ class NWFilterFirewall(object):
for rule in security_group.rules:
rule_xml += "<rule action='accept' direction='in' priority='300'>"
if rule.cidr:
- rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr)
+ net, mask = self._get_net_and_mask(rule.cidr)
+ rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask)
if rule.protocol in ['tcp', 'udp']:
rule_xml += "dstportstart='%s' dstportend='%s' " % \
(rule.from_port, rule.to_port)