summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSoren Hansen <soren@linux2go.dk>2011-07-22 22:41:29 +0200
committerSoren Hansen <soren@linux2go.dk>2011-07-22 22:41:29 +0200
commitc3cdcc1eb0c9fd37f49701d976c7ceae8df44caf (patch)
tree09fb706f6c3294e553ac2d81b02a30ce7b0a0b21
parentfa2cdbc5d4201ace6c1a6459bbd653b0b63b7667 (diff)
This is me being all cocky, thinking I'll make it use ipsets...
-rw-r--r--nova/compute/api.py18
-rw-r--r--nova/db/sqlalchemy/models.py6
-rw-r--r--nova/network/linux_net.py30
-rw-r--r--nova/network/manager.py24
-rw-r--r--nova/tests/test_iptables_network.py39
-rw-r--r--nova/virt/libvirt/firewall.py44
6 files changed, 139 insertions, 22 deletions
diff --git a/nova/compute/api.py b/nova/compute/api.py
index 432658bbb..65a594d2c 100644
--- a/nova/compute/api.py
+++ b/nova/compute/api.py
@@ -305,10 +305,6 @@ class API(base.Base):
updates['hostname'] = self.hostname_factory(instance)
instance = self.update(context, instance_id, **updates)
-
- for group_id in security_groups:
- self.trigger_security_group_members_refresh(elevated, group_id)
-
return instance
def _ask_scheduler_to_create_instance(self, context, base_options,
@@ -464,19 +460,22 @@ class API(base.Base):
{"method": "refresh_security_group_rules",
"args": {"security_group_id": security_group.id}})
- def trigger_security_group_members_refresh(self, context, group_id):
+ def trigger_security_group_members_refresh(self, context, group_ids):
"""Called when a security group gains a new or loses a member.
Sends an update request to each compute node for whom this is
relevant.
"""
- # First, we get the security group rules that reference this group as
+ # First, we get the security group rules that reference these groups as
# the grantee..
- security_group_rules = \
+ security_group_rules = set()
+ for group_id in group_ids:
+ security_group_rules.update(
self.db.security_group_rule_get_by_security_group_grantee(
context,
- group_id)
+ group_id))
+ LOG.info('rules: %r', security_group_rules)
# ..then we distill the security groups to which they belong..
security_groups = set()
for rule in security_group_rules:
@@ -485,12 +484,14 @@ class API(base.Base):
rule['parent_group_id'])
security_groups.add(security_group)
+ LOG.info('security_groups: %r', security_groups)
# ..then we find the instances that are members of these groups..
instances = set()
for security_group in security_groups:
for instance in security_group['instances']:
instances.add(instance)
+ LOG.info('instances: %r', instances)
# ...then we find the hosts where they live...
hosts = set()
for instance in instances:
@@ -500,6 +501,7 @@ class API(base.Base):
# ...and finally we tell these nodes to refresh their view of this
# particular security group.
for host in hosts:
+ LOG.info('host: %r', host)
rpc.cast(context,
self.db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "refresh_security_group_members",
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index d29d3d6f1..023821dfb 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -491,6 +491,12 @@ class SecurityGroupIngressRule(BASE, NovaBase):
# Note: This is not the parent SecurityGroup. It's SecurityGroup we're
# granting access for.
group_id = Column(Integer, ForeignKey('security_groups.id'))
+ grantee_group = relationship("SecurityGroup",
+ foreign_keys=group_id,
+ primaryjoin='and_('
+ 'SecurityGroupIngressRule.group_id == SecurityGroup.id,'
+ 'SecurityGroupIngressRule.deleted == False)')
+
class ProviderFirewallRule(BASE, NovaBase):
diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py
index 283a5aca1..0e021a40f 100644
--- a/nova/network/linux_net.py
+++ b/nova/network/linux_net.py
@@ -96,6 +96,33 @@ class IptablesRule(object):
chain = self.chain
return '-A %s %s' % (chain, self.rule)
+class IpSet(object):
+ """A class for handling large collections of IPs efficiently"""
+
+ def __init__(self, name, execute=None):
+ self.name = name
+ self._ips = set()
+ if not execute:
+ self.execute = _execute
+ else:
+ self.execute = execute
+
+ def __contains__(self, addr):
+ return addr in self._ips
+
+ def _set_name(self):
+ return '%s-%s' % (binary_name, self.name)
+
+ def add_ip(self, addr):
+ self._ips.add(addr)
+ self.execute('ipset', '-A', self._set_name(), addr)
+
+ def remove_ip(self, addr):
+ self._ips.remove(addr)
+ self.execute('ipset', '-D', self._set_name(), addr)
+
+ def iptables_source_match(self):
+ return ['-m set --match-set %s src' % (self._set_name(),)]
class IptablesTable(object):
"""An iptables table."""
@@ -281,6 +308,9 @@ class IptablesManager(object):
self.ipv4['nat'].add_chain('floating-snat')
self.ipv4['nat'].add_rule('snat', '-j $floating-snat')
+ def ipset_supported(self):
+ return False
+
@utils.synchronized('iptables', external=True)
def apply(self):
"""Apply the current in-memory set of iptables rules.
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 824e8d24d..928cb09f6 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -63,6 +63,7 @@ from nova import quota
from nova import utils
from nova import rpc
from nova.network import api as network_api
+from nova.compute import api as compute_api
import random
@@ -297,6 +298,7 @@ class NetworkManager(manager.SchedulerDependentManager):
network_driver = FLAGS.network_driver
self.driver = utils.import_object(network_driver)
self.network_api = network_api.API()
+ self.compute_api = compute_api.API()
super(NetworkManager, self).__init__(service_name='network',
*args, **kwargs)
@@ -350,6 +352,15 @@ class NetworkManager(manager.SchedulerDependentManager):
# return so worker will only grab 1 (to help scale flatter)
return self.set_network_host(context, network['id'])
+ def _do_trigger_security_group_members_refresh_for_instance(self,
+ context,
+ instance_id):
+ instance_ref = db.instance_get(context, instance_id)
+ groups = instance_ref.security_groups
+ group_ids = [group.id for group in groups]
+ self.compute_api.trigger_security_group_members_refresh(context,
+ group_ids)
+
def _get_networks_for_instance(self, context, instance_id, project_id):
"""Determine & return which networks an instance should connect to."""
# TODO(tr3buchet) maybe this needs to be updated in the future if
@@ -511,6 +522,9 @@ class NetworkManager(manager.SchedulerDependentManager):
address = self.db.fixed_ip_associate_pool(context.elevated(),
network['id'],
instance_id)
+ self._do_trigger_security_group_members_refresh_for_instance(
+ context,
+ instance_id)
vif = self.db.virtual_interface_get_by_instance_and_network(context,
instance_id,
network['id'])
@@ -524,6 +538,12 @@ class NetworkManager(manager.SchedulerDependentManager):
self.db.fixed_ip_update(context, address,
{'allocated': False,
'virtual_interface_id': None})
+ fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
+ instance_ref = fixed_ip_ref['instance']
+ instance_id = instance_ref['id']
+ self._do_trigger_security_group_members_refresh_for_instance(
+ context,
+ instance_id)
def lease_fixed_ip(self, context, address):
"""Called by dhcp-bridge when ip is leased."""
@@ -825,7 +845,9 @@ class VlanManager(RPCAllocateFixedIP, FloatingIP, NetworkManager):
address = self.db.fixed_ip_associate_pool(context,
network['id'],
instance_id)
-
+ self._do_trigger_security_group_members_refresh_for_instance(
+ context,
+ instance_id)
vif = self.db.virtual_interface_get_by_instance_and_network(context,
instance_id,
network['id'])
diff --git a/nova/tests/test_iptables_network.py b/nova/tests/test_iptables_network.py
index 918034269..d0a8c052c 100644
--- a/nova/tests/test_iptables_network.py
+++ b/nova/tests/test_iptables_network.py
@@ -17,11 +17,46 @@
# under the License.
"""Unit Tests for network code."""
-import os
-
from nova import test
from nova.network import linux_net
+class IpSetTestCase(test.TestCase):
+ def test_add(self):
+ """Adding an address"""
+ ipset = linux_net.IpSet('somename')
+
+ ipset.add_ip('1.2.3.4')
+ self.assertTrue('1.2.3.4' in ipset)
+
+
+ def test_add_remove(self):
+ """Adding and then removing an address"""
+
+ self.verify_cmd_call_count = 0
+ def verify_cmd(*args):
+ self.assertEquals(args, self.expected_cmd)
+ self.verify_cmd_call_count += 1
+
+ self.expected_cmd = ('ipset', '-A', 'run_tests.py-somename', '1.2.3.4')
+ ipset = linux_net.IpSet('somename',execute=verify_cmd)
+ ipset.add_ip('1.2.3.4')
+ self.assertTrue('1.2.3.4' in ipset)
+
+ self.expected_cmd = ('ipset', '-D', 'run_tests.py-somename', '1.2.3.4')
+ ipset.remove_ip('1.2.3.4')
+ self.assertTrue('1.2.3.4' not in ipset)
+ self.assertEquals(self.verify_cmd_call_count, 2)
+
+
+ def test_two_adds_one_remove(self):
+ """Adding the same address twice works. Removing it once removes it entirely."""
+ ipset = linux_net.IpSet('somename')
+
+ ipset.add_ip('1.2.3.4')
+ ipset.add_ip('1.2.3.4')
+ ipset.remove_ip('1.2.3.4')
+ self.assertTrue('1.2.3.4' not in ipset)
+
class IptablesManagerTestCase(test.TestCase):
sample_filter = ['#Generated by iptables-save on Fri Feb 18 15:17:05 2011',
diff --git a/nova/virt/libvirt/firewall.py b/nova/virt/libvirt/firewall.py
index 379197398..aa36e4184 100644
--- a/nova/virt/libvirt/firewall.py
+++ b/nova/virt/libvirt/firewall.py
@@ -663,11 +663,10 @@ class IptablesFirewallDriver(FirewallDriver):
LOG.debug(_('Adding security group rule: %r'), rule)
if not rule.cidr:
- # Eventually, a mechanism to grant access for security
- # groups will turn up here. It'll use ipsets.
- continue
+ version = 4
+ else:
+ version = netutils.get_ip_version(rule.cidr)
- version = netutils.get_ip_version(rule.cidr)
if version == 4:
fw_rules = ipv4_rules
else:
@@ -677,16 +676,16 @@ class IptablesFirewallDriver(FirewallDriver):
if version == 6 and rule.protocol == 'icmp':
protocol = 'icmpv6'
- args = ['-p', protocol, '-s', rule.cidr]
+ args = ['-j ACCEPT', '-p', protocol]
- if rule.protocol in ['udp', 'tcp']:
+ if protocol in ['udp', 'tcp']:
if rule.from_port == rule.to_port:
args += ['--dport', '%s' % (rule.from_port,)]
else:
args += ['-m', 'multiport',
'--dports', '%s:%s' % (rule.from_port,
rule.to_port)]
- elif rule.protocol == 'icmp':
+ elif protocol == 'icmp':
icmp_type = rule.from_port
icmp_code = rule.to_port
@@ -705,9 +704,30 @@ class IptablesFirewallDriver(FirewallDriver):
args += ['-m', 'icmp6', '--icmpv6-type',
icmp_type_arg]
- args += ['-j ACCEPT']
- fw_rules += [' '.join(args)]
-
+ if rule.cidr:
+ LOG.info('Using cidr %r', rule.cidr)
+ args += ['-s', rule.cidr]
+ fw_rules += [' '.join(args)]
+ else:
+ LOG.info('Not using cidr %r', rule.cidr)
+ if self.iptables.ipset_supported():
+ LOG.info('ipset supported %r', rule.cidr)
+ ipset = linux_net.IpSet('%s' % rule.group_id)
+ args += ipset.iptables_source_match()
+ fw_rules += [' '.join(args)]
+ else:
+ LOG.info('ipset unsupported %r', rule.cidr)
+ LOG.info('rule.grantee_group.instances: %r', rule.grantee_group.instances)
+ for instance in rule.grantee_group.instances:
+ LOG.info('instance: %r', instance)
+ ips = db.instance_get_fixed_addresses(ctxt,
+ instance['id'])
+ LOG.info('ips: %r', ips)
+ for ip in ips:
+ subrule = args + ['-s %s' % ip]
+ fw_rules += [' '.join(subrule)]
+
+ LOG.info('Using fw_rules: %r', fw_rules)
ipv4_rules += ['-j $sg-fallback']
ipv6_rules += ['-j $sg-fallback']
@@ -718,7 +738,9 @@ class IptablesFirewallDriver(FirewallDriver):
return self.nwfilter.instance_filter_exists(instance)
def refresh_security_group_members(self, security_group):
- pass
+ if not self.iptables.ipset_supported():
+ self.do_refresh_security_group_rules(security_group)
+ self.iptables.apply()
def refresh_security_group_rules(self, security_group, network_info=None):
self.do_refresh_security_group_rules(security_group, network_info)