summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2013-07-30 14:21:27 +0000
committerGerrit Code Review <review@openstack.org>2013-07-30 14:21:27 +0000
commitf58c936582a17cffcdcdb5bf624f24d574abaf0a (patch)
tree79d1b7c24678ce19c4da4240a840c29fbf3576eb
parentdf53b596491563aa1681d4ad21fd1ade1108d2e3 (diff)
parent3626b6db91ce1e897d1993fb4e13ac4237d04b7f (diff)
downloadoslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.tar.gz
oslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.tar.xz
oslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.zip
Merge "Fix policy default_rule issue"
-rw-r--r--openstack/common/policy.py18
-rw-r--r--tests/unit/test_policy.py47
2 files changed, 50 insertions, 15 deletions
diff --git a/openstack/common/policy.py b/openstack/common/policy.py
index 00531e5..02335ca 100644
--- a/openstack/common/policy.py
+++ b/openstack/common/policy.py
@@ -115,12 +115,18 @@ class Rules(dict):
def __missing__(self, key):
"""Implements the default rule handling."""
+ if isinstance(self.default_rule, dict):
+ raise KeyError(key)
+
# If the default rule isn't actually defined, do something
# reasonably intelligent
if not self.default_rule or self.default_rule not in self:
raise KeyError(key)
- return self[self.default_rule]
+ if isinstance(self.default_rule, BaseCheck):
+ return self.default_rule
+ elif isinstance(self.default_rule, six.string_types):
+ return self[self.default_rule]
def __str__(self):
"""Dumps a string representation of the rules."""
@@ -153,7 +159,7 @@ class Enforcer(object):
"""
def __init__(self, policy_file=None, rules=None, default_rule=None):
- self.rules = Rules(rules)
+ self.rules = Rules(rules, default_rule)
self.default_rule = default_rule or CONF.policy_default_rule
self.policy_path = None
@@ -172,13 +178,14 @@ class Enforcer(object):
"got %s instead") % type(rules))
if overwrite:
- self.rules = Rules(rules)
+ self.rules = Rules(rules, self.default_rule)
else:
- self.update(rules)
+ self.rules.update(rules)
def clear(self):
"""Clears Enforcer rules, policy's cache and policy's path."""
self.set_rules({})
+ self.default_rule = None
self.policy_path = None
def load_rules(self, force_reload=False):
@@ -194,8 +201,7 @@ class Enforcer(object):
reloaded, data = fileutils.read_cached_file(self.policy_path,
force_reload=force_reload)
-
- if reloaded:
+ if reloaded or not self.rules:
rules = Rules.load_json(data, self.default_rule)
self.set_rules(rules)
LOG.debug(_("Rules successfully reloaded"))
diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py
index b7d38a3..2ccf71e 100644
--- a/tests/unit/test_policy.py
+++ b/tests/unit/test_policy.py
@@ -170,6 +170,44 @@ class EnforcerTest(PolicyBaseTestCase):
creds = {'roles': ''}
self.assertEqual(self.enforcer.enforce(action, {}, creds), True)
+ def test_enforcer_with_default_rule(self):
+ rules_json = """{
+ "deny_stack_user": "not role:stack_user",
+ "cloudwatch:PutMetricData": ""
+ }"""
+ rules = policy.Rules.load_json(rules_json)
+ default_rule = policy.TrueCheck()
+ enforcer = policy.Enforcer(default_rule=default_rule)
+ enforcer.set_rules(rules)
+ action = "cloudwatch:PutMetricData"
+ creds = {'roles': ''}
+ self.assertEqual(enforcer.enforce(action, {}, creds), True)
+
+ def test_enforcer_force_reload_true(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.load_rules(force_reload=True)
+ self.assertNotIn({'test': 'test'}, self.enforcer.rules)
+ self.assertIn('default', self.enforcer.rules)
+ self.assertIn('admin', self.enforcer.rules)
+
+ def test_enforcer_force_reload_false(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.load_rules(force_reload=False)
+ self.assertIn('test', self.enforcer.rules)
+ self.assertNotIn('default', self.enforcer.rules)
+ self.assertNotIn('admin', self.enforcer.rules)
+
+ def test_enforcer_overwrite_rules(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.set_rules({'test': 'test1'}, overwrite=True)
+ self.assertEquals(self.enforcer.rules, {'test': 'test1'})
+
+ def test_enforcer_update_rules(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.set_rules({'test1': 'test1'}, overwrite=False)
+ self.assertEquals(self.enforcer.rules, {'test': 'test',
+ 'test1': 'test1'})
+
class FakeCheck(policy.BaseCheck):
def __init__(self, result=None):
@@ -187,24 +225,15 @@ class FakeCheck(policy.BaseCheck):
class CheckFunctionTestCase(PolicyBaseTestCase):
def test_check_explicit(self):
- self.enforcer.load_rules()
- self.enforcer.rules = None
rule = FakeCheck()
result = self.enforcer.enforce(rule, "target", "creds")
-
self.assertEqual(result, ("target", "creds", self.enforcer))
- self.assertEqual(self.enforcer.rules, None)
def test_check_no_rules(self):
- self.enforcer.load_rules()
- self.enforcer.rules = None
result = self.enforcer.enforce('rule', "target", "creds")
-
self.assertEqual(result, False)
- self.assertEqual(self.enforcer.rules, None)
def test_check_missing_rule(self):
- self.enforcer.rules = {}
result = self.enforcer.enforce('rule', 'target', 'creds')
self.assertEqual(result, False)