summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTodd Willey <todd@ansolabs.com>2011-04-07 01:42:49 -0400
committerTodd Willey <todd@ansolabs.com>2011-04-07 01:42:49 -0400
commit2b79fa82872c55368167fc7433cb28a2369f5191 (patch)
tree89220f5f198ad73f2cad47a1878a70e381a7cee5
parent26d2a6ca8939156e8957e31dd17906070283ff24 (diff)
test provider fw rules at the virt/ipteables layer.
lowercase protocol names in admin api to match what the firewall driver expects. add provider fw rule chain in iptables6 as well. fix a couple of small typos and copy-paste errors.
-rw-r--r--nova/api/ec2/admin.py6
-rw-r--r--nova/tests/test_virt.py65
-rw-r--r--nova/virt/libvirt_conn.py23
3 files changed, 80 insertions, 14 deletions
diff --git a/nova/api/ec2/admin.py b/nova/api/ec2/admin.py
index 0b27854ef..c0c2bcd0d 100644
--- a/nova/api/ec2/admin.py
+++ b/nova/api/ec2/admin.py
@@ -356,11 +356,11 @@ class AdminController(object):
IPy.IP(cidr)
rule = {'cidr': cidr}
tcp_rule = rule.copy()
- tcp_rule.update({"protocol": "TCP", "from_port": 1, "to_port": 65535})
+ tcp_rule.update({"protocol": "tcp", "from_port": 1, "to_port": 65535})
udp_rule = rule.copy()
- udp_rule.update({"protocol": "UDP", "from_port": 1, "to_port": 65535})
+ udp_rule.update({"protocol": "udp", "from_port": 1, "to_port": 65535})
icmp_rule = rule.copy()
- icmp_rule.update({"protocol": "ICMP", "from_port": -1,
+ icmp_rule.update({"protocol": "icmp", "from_port": -1,
"to_port": None})
rules_added = 0
if not self._provider_fw_rule_exists(context, tcp_rule):
diff --git a/nova/tests/test_virt.py b/nova/tests/test_virt.py
index 958c8e3e2..34ecd09c6 100644
--- a/nova/tests/test_virt.py
+++ b/nova/tests/test_virt.py
@@ -566,7 +566,9 @@ class IptablesFirewallTestCase(test.TestCase):
self.network = utils.import_object(FLAGS.network_manager)
class FakeLibvirtConnection(object):
- pass
+ def nwfilterDefineXML(*args, **kwargs):
+ """setup_basic_rules in nwfilter calls this."""
+ pass
self.fake_libvirt_connection = FakeLibvirtConnection()
self.fw = libvirt_conn.IptablesFirewallDriver(
get_connection=lambda: self.fake_libvirt_connection)
@@ -728,6 +730,67 @@ class IptablesFirewallTestCase(test.TestCase):
"TCP port 80/81 acceptance rule wasn't added")
db.instance_destroy(admin_ctxt, instance_ref['id'])
+ def test_provider_firewall_rules(self):
+ # keep from changing state of actual firewall
+ #def fake_function(*args, **kwargs):
+ # pass
+ #self.fw.iptables.apply = fake_function
+ #self.fw.nwfilter.setup_basic_filtering = fake_function
+
+ # setup basic instance data
+ instance_ref = db.instance_create(self.context,
+ {'user_id': 'fake',
+ 'project_id': 'fake',
+ 'mac_address': '56:12:12:12:12:12'})
+ ip = '10.11.12.13'
+ network_ref = db.project_get_network(self.context, 'fake')
+ admin_ctxt = context.get_admin_context()
+ fixed_ip = {'address': ip, 'network_id': network_ref['id']}
+ db.fixed_ip_create(admin_ctxt, fixed_ip)
+ db.fixed_ip_update(admin_ctxt, ip, {'allocated': True,
+ 'instance_id': instance_ref['id']})
+ # FRAGILE: peeks at how the firewall names chains
+ chain_name = 'inst-%s' % instance_ref['id']
+
+ # create a firewall via setup_basic_filtering like libvirt_conn.spawn
+ # should have a chain with 0 rules
+ self.fw.setup_basic_filtering(instance_ref, network_info=None)
+ self.assertTrue('provider' in self.fw.iptables.ipv4['filter'].chains)
+ rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
+ if rule.chain == 'provider']
+ self.assertEqual(0, len(rules))
+
+ # add a rule and send the update message, check for 1 rule
+ provider_fw0 = db.provider_fw_rule_create(admin_ctxt,
+ {'protocol': 'tcp',
+ 'cidr': '10.99.99.99/32',
+ 'from_port': 1,
+ 'to_port': 65535})
+ self.fw.refresh_provider_fw_rules()
+ rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
+ if rule.chain == 'provider']
+ self.assertEqual(1, len(rules))
+
+ # Add another, refresh, and make sure number of rules goes to two
+ provider_fw1 = db.provider_fw_rule_create(admin_ctxt,
+ {'protocol': 'udp',
+ 'cidr': '10.99.99.99/32',
+ 'from_port': 1,
+ 'to_port': 65535})
+ self.fw.refresh_provider_fw_rules()
+ rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
+ if rule.chain == 'provider']
+ self.assertEqual(2, len(rules))
+
+ # create the instance filter and make sure it has a jump rule
+ self.fw.prepare_instance_filter(instance_ref, network_info=None)
+ self.fw.apply_instance_filter(instance_ref)
+ inst_rules = [rule for rule in self.fw.iptables.ipv4['filter'].rules
+ if rule.chain == chain_name]
+ jump_rules = [rule for rule in inst_rules if '-j' in rule.rule]
+ prov_rules = [rule for rule in jump_rules if 'provider' in rule.rule]
+ self.assertEqual(1, len(prov_rules))
+
class NWFilterTestCase(test.TestCase):
def setUp(self):
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index 0d92e2e70..38ba21521 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -1939,6 +1939,7 @@ class IptablesFirewallDriver(FirewallDriver):
network_info = _get_network_info(instance)
self.nwfilter.setup_basic_filtering(instance, network_info)
if not self.basicly_filtered:
+ LOG.debug("Setup Basic Filtering")
self.refresh_provider_fw_rules()
self.basicly_filtered = True
@@ -1967,6 +1968,7 @@ class IptablesFirewallDriver(FirewallDriver):
chain_name = self._instance_chain_name(instance)
self.iptables.ipv4['filter'].add_chain(chain_name)
+ self.iptables.ipv4['filter'].empty_chain(chain_name)
ips_v4 = [ip['ip'] for (_, mapping) in network_info
for ip in mapping['ips']]
@@ -1978,6 +1980,7 @@ class IptablesFirewallDriver(FirewallDriver):
if FLAGS.use_ipv6:
self.iptables.ipv6['filter'].add_chain(chain_name)
+ self.iptables.ipv6['filter'].empty_chain(chain_name)
ips_v6 = [ip['ip'] for (_, mapping) in network_info
for ip in mapping['ip6s']]
@@ -1991,9 +1994,6 @@ class IptablesFirewallDriver(FirewallDriver):
for rule in ipv4_rules:
self.iptables.ipv4['filter'].add_rule(chain_name, rule)
- for rule in ipv4_rules:
- self.iptables.ipv4['filter'].add_rule(chain_name, rule)
-
if FLAGS.use_ipv6:
for rule in ipv6_rules:
self.iptables.ipv6['filter'].add_rule(chain_name, rule)
@@ -2042,7 +2042,7 @@ class IptablesFirewallDriver(FirewallDriver):
# they're not worth the clutter.
if FLAGS.use_ipv6:
# Allow RA responses
- gateways_v6 = [network['gateway_v6'] for (network, _) in
+ gateways_v6 = [network['gateway_v6'] for (network, _m) in
network_info]
for gateway_v6 in gateways_v6:
ipv6_rules.append(
@@ -2065,7 +2065,7 @@ class IptablesFirewallDriver(FirewallDriver):
security_group['id'])
for rule in rules:
- LOG.debug(_('Adding security group rule: %r'), rule)
+ LOG.debug(_("Adding security group rule: %r"), rule)
if not rule.cidr:
# Eventually, a mechanism to grant access for security
@@ -2139,8 +2139,8 @@ class IptablesFirewallDriver(FirewallDriver):
@utils.synchronized('iptables', external=True)
def _do_refresh_provider_fw_rules(self):
"""Internal, synchronized version of refresh_provider_fw_rules."""
- self.purge_provider_fw_rules(self)
- self.build_provider_fw_rules(self)
+ self._purge_provider_fw_rules()
+ self._build_provider_fw_rules()
def _purge_provider_fw_rules(self):
"""Remove all rules from the provider chains."""
@@ -2150,6 +2150,9 @@ class IptablesFirewallDriver(FirewallDriver):
def _build_provider_fw_rules(self):
"""Create all rules for the provider IP DROPs."""
+ self.iptables.ipv4['filter'].add_chain('provider')
+ if FLAGS.use_ipv6:
+ self.iptables.ipv6['filter'].add_chain('provider')
ipv4_rules, ipv6_rules = self._provider_rules()
for rule in ipv4_rules:
self.iptables.ipv4['filter'].add_rule('provider', rule)
@@ -2179,19 +2182,19 @@ class IptablesFirewallDriver(FirewallDriver):
fw_rules = ipv6_rules
protocol = rule.protocol
- if version == 6 and rule.protocol == 'icmp':
+ if version == 6 and protocol == 'icmp':
protocol = 'icmpv6'
args = ['-p', protocol, '-s', rule.cidr]
- if rule.protocol in ['udp', 'tcp']:
+ if protocol in ['udp', 'tcp']:
if rule.from_port == rule.to_port:
args += ['--dport', '%s' % (rule.from_port,)]
else:
args += ['-m', 'multiport',
'--dports', '%s:%s' % (rule.from_port,
rule.to_port)]
- elif rule.protocol == 'icmp':
+ elif protocol == 'icmp':
icmp_type = rule.from_port
icmp_code = rule.to_port