diff options
| author | Jenkins <jenkins@review.openstack.org> | 2013-07-30 14:21:27 +0000 |
|---|---|---|
| committer | Gerrit Code Review <review@openstack.org> | 2013-07-30 14:21:27 +0000 |
| commit | f58c936582a17cffcdcdb5bf624f24d574abaf0a (patch) | |
| tree | 79d1b7c24678ce19c4da4240a840c29fbf3576eb | |
| parent | df53b596491563aa1681d4ad21fd1ade1108d2e3 (diff) | |
| parent | 3626b6db91ce1e897d1993fb4e13ac4237d04b7f (diff) | |
| download | oslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.tar.gz oslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.tar.xz oslo-f58c936582a17cffcdcdb5bf624f24d574abaf0a.zip | |
Merge "Fix policy default_rule issue"
| -rw-r--r-- | openstack/common/policy.py | 18 | ||||
| -rw-r--r-- | tests/unit/test_policy.py | 47 |
2 files changed, 50 insertions, 15 deletions
diff --git a/openstack/common/policy.py b/openstack/common/policy.py index 00531e5..02335ca 100644 --- a/openstack/common/policy.py +++ b/openstack/common/policy.py @@ -115,12 +115,18 @@ class Rules(dict): def __missing__(self, key): """Implements the default rule handling.""" + if isinstance(self.default_rule, dict): + raise KeyError(key) + # If the default rule isn't actually defined, do something # reasonably intelligent if not self.default_rule or self.default_rule not in self: raise KeyError(key) - return self[self.default_rule] + if isinstance(self.default_rule, BaseCheck): + return self.default_rule + elif isinstance(self.default_rule, six.string_types): + return self[self.default_rule] def __str__(self): """Dumps a string representation of the rules.""" @@ -153,7 +159,7 @@ class Enforcer(object): """ def __init__(self, policy_file=None, rules=None, default_rule=None): - self.rules = Rules(rules) + self.rules = Rules(rules, default_rule) self.default_rule = default_rule or CONF.policy_default_rule self.policy_path = None @@ -172,13 +178,14 @@ class Enforcer(object): "got %s instead") % type(rules)) if overwrite: - self.rules = Rules(rules) + self.rules = Rules(rules, self.default_rule) else: - self.update(rules) + self.rules.update(rules) def clear(self): """Clears Enforcer rules, policy's cache and policy's path.""" self.set_rules({}) + self.default_rule = None self.policy_path = None def load_rules(self, force_reload=False): @@ -194,8 +201,7 @@ class Enforcer(object): reloaded, data = fileutils.read_cached_file(self.policy_path, force_reload=force_reload) - - if reloaded: + if reloaded or not self.rules: rules = Rules.load_json(data, self.default_rule) self.set_rules(rules) LOG.debug(_("Rules successfully reloaded")) diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py index b7d38a3..2ccf71e 100644 --- a/tests/unit/test_policy.py +++ b/tests/unit/test_policy.py @@ -170,6 +170,44 @@ class EnforcerTest(PolicyBaseTestCase): creds = {'roles': ''} self.assertEqual(self.enforcer.enforce(action, {}, creds), True) + def test_enforcer_with_default_rule(self): + rules_json = """{ + "deny_stack_user": "not role:stack_user", + "cloudwatch:PutMetricData": "" + }""" + rules = policy.Rules.load_json(rules_json) + default_rule = policy.TrueCheck() + enforcer = policy.Enforcer(default_rule=default_rule) + enforcer.set_rules(rules) + action = "cloudwatch:PutMetricData" + creds = {'roles': ''} + self.assertEqual(enforcer.enforce(action, {}, creds), True) + + def test_enforcer_force_reload_true(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.load_rules(force_reload=True) + self.assertNotIn({'test': 'test'}, self.enforcer.rules) + self.assertIn('default', self.enforcer.rules) + self.assertIn('admin', self.enforcer.rules) + + def test_enforcer_force_reload_false(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.load_rules(force_reload=False) + self.assertIn('test', self.enforcer.rules) + self.assertNotIn('default', self.enforcer.rules) + self.assertNotIn('admin', self.enforcer.rules) + + def test_enforcer_overwrite_rules(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.set_rules({'test': 'test1'}, overwrite=True) + self.assertEquals(self.enforcer.rules, {'test': 'test1'}) + + def test_enforcer_update_rules(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.set_rules({'test1': 'test1'}, overwrite=False) + self.assertEquals(self.enforcer.rules, {'test': 'test', + 'test1': 'test1'}) + class FakeCheck(policy.BaseCheck): def __init__(self, result=None): @@ -187,24 +225,15 @@ class FakeCheck(policy.BaseCheck): class CheckFunctionTestCase(PolicyBaseTestCase): def test_check_explicit(self): - self.enforcer.load_rules() - self.enforcer.rules = None rule = FakeCheck() result = self.enforcer.enforce(rule, "target", "creds") - self.assertEqual(result, ("target", "creds", self.enforcer)) - self.assertEqual(self.enforcer.rules, None) def test_check_no_rules(self): - self.enforcer.load_rules() - self.enforcer.rules = None result = self.enforcer.enforce('rule', "target", "creds") - self.assertEqual(result, False) - self.assertEqual(self.enforcer.rules, None) def test_check_missing_rule(self): - self.enforcer.rules = {} result = self.enforcer.enforce('rule', 'target', 'creds') self.assertEqual(result, False) |
