summaryrefslogtreecommitdiffstats
path: root/keystone/common/controller.py
blob: affc34de3b8fb8221071a07643008626fd09466f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import collections
import functools
import uuid

from keystone.common import dependency
from keystone.common import logging
from keystone.common import wsgi
from keystone import config
from keystone import exception


LOG = logging.getLogger(__name__)
CONF = config.CONF
# Domain assigned by the (domain-unaware) v2 controllers and used as the
# fallback domain by v3 calls. NOTE(review): this is read once, when the module
# is imported, so a later change to ``identity.default_domain_id`` (e.g. config
# parsed after this import) is not reflected here -- confirm this is intended.
DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id


def _build_policy_check_credentials(self, action, context, kwargs):
    """Build the credentials dict used for an RBAC policy check.

    The token referenced by ``context['token_id']`` is loaded through
    ``self.token_api``, checked with ``wsgi.validate_token_bind`` and then
    reduced to the fields the policy engine needs.

    :param self: controller providing ``token_api`` and ``identity_api``.
    :param action: policy action name; only used for logging.
    :param context: request context; must contain ``token_id``.
    :param kwargs: arguments of the protected call; only used for logging.
    :returns: dict containing ``user_id`` and, when the token carries them,
              ``project_id``, ``domain_id`` and ``roles`` (role names).
    :raises keystone.exception.Unauthorized: if the token cannot be found or
            does not reference a usable user.
    """
    LOG.debug(_('RBAC: Authorizing %(action)s(%(kwargs)s)') % {
        'action': action,
        'kwargs': ', '.join(['%s=%s' % (k, kwargs[k]) for k in kwargs])})

    try:
        token_ref = self.token_api.get_token(context['token_id'])
    except exception.TokenNotFound:
        LOG.warning(_('RBAC: Invalid token'))
        raise exception.Unauthorized()

    # NOTE(jamielennox): whilst this maybe shouldn't be within this function
    # it would otherwise need to reload the token_ref from backing store.
    wsgi.validate_token_bind(context, token_ref)

    if 'token_data' in token_ref and 'token' in token_ref['token_data']:
        return _v3_token_credentials(token_ref['token_data']['token'])
    return _v2_token_credentials(self, token_ref)


def _v3_token_credentials(token_data):
    """Extract policy credentials from the ``token`` section of a v3 token."""
    creds = {}
    try:
        creds['user_id'] = token_data['user']['id']
    except (AttributeError, KeyError, TypeError):
        # A missing user raises KeyError and a null/non-dict user raises
        # TypeError; both mean the token cannot be attributed to a user.
        LOG.warning(_('RBAC: Invalid user'))
        raise exception.Unauthorized()

    if 'project' in token_data:
        creds['project_id'] = token_data['project']['id']
    else:
        LOG.debug(_('RBAC: Proceeding without project'))

    if 'domain' in token_data:
        creds['domain_id'] = token_data['domain']['id']

    if 'roles' in token_data:
        creds['roles'] = [role['name'] for role in token_data['roles']]

    return creds


def _v2_token_credentials(self, token_ref):
    """Extract policy credentials from a v2 token reference.

    The token's ``metadata`` (copied) is the starting point; any ``roles``
    it lists are resolved to role names via ``self.identity_api.get_role``.
    """
    creds = token_ref.get('metadata', {}).copy()
    try:
        creds['user_id'] = token_ref['user'].get('id')
    except (AttributeError, KeyError):
        # No 'user' key (KeyError) or a null user (AttributeError on .get).
        LOG.warning(_('RBAC: Invalid user'))
        raise exception.Unauthorized()

    try:
        creds['project_id'] = token_ref['tenant'].get('id')
    except (AttributeError, KeyError):
        # No (or a null) tenant on the token: continue without project_id.
        LOG.debug(_('RBAC: Proceeding without tenant'))

    # NOTE(vish): this is pretty inefficient
    creds['roles'] = [self.identity_api.get_role(role)['name']
                      for role in creds.get('roles', [])]
    return creds


try:
    from collections.abc import MutableMapping as _MutableMapping
except ImportError:  # Python 2: the ABCs live directly in ``collections``.
    _MutableMapping = collections.MutableMapping


def flatten(d, parent_key=''):
    """Flatten a nested dictionary

    Converts a dictionary with nested values to a single level flat
    dictionary, with dotted notation for each key. Nested mutable mappings
    are expanded recursively (an empty nested mapping contributes no keys);
    every other value is copied as-is.

    :param d: the (possibly nested) mapping to flatten.
    :param parent_key: prefix prepended, with a '.', to every key produced.
    :returns: a new, single-level dict.

    """
    flat = {}
    for key, value in d.items():
        new_key = '%s.%s' % (parent_key, key) if parent_key else key
        if isinstance(value, _MutableMapping):
            flat.update(flatten(value, new_key))
        else:
            flat[new_key] = value
    return flat


def protected(f):
    """Decorate an API method with role based access control (RBAC).

    Unless the request context is flagged ``is_admin``, the caller's token
    is turned into policy credentials and ``policy_api.enforce`` is called
    for the action ``identity:<method name>``, using the flattened keyword
    arguments of the call as the target. The wrapped method then runs with
    its original arguments.
    """
    @functools.wraps(f)
    def wrapper(self, context, *args, **kwargs):
        if context.get('is_admin'):
            LOG.warning(_('RBAC: Bypassing authorization'))
            return f(self, context, *args, **kwargs)

        action = 'identity:%s' % f.__name__
        creds = _build_policy_check_credentials(self, action, context, kwargs)
        # The call's keyword arguments double as the policy target; for a
        # get/update/delete they normally include the entity's primary key.
        target = flatten(kwargs)
        self.policy_api.enforce(creds, action, target)
        LOG.debug(_('RBAC: Authorization granted'))
        return f(self, context, *args, **kwargs)

    return wrapper


def filterprotected(*filters):
    """Wraps filtered API calls with role based access controls (RBAC).

    :param filters: names of query-string parameters the wrapped call may
                    filter on. Those present in the request are added to the
                    policy target, and the tuple of names is passed to the
                    wrapped call positionally, right after ``context``.
    """

    def _filterprotected(f):
        @functools.wraps(f)
        def wrapper(self, context, **kwargs):
            # Same admin test as ``protected``: a context without an
            # 'is_admin' entry is treated as non-admin instead of raising
            # KeyError.
            if context.get('is_admin'):
                LOG.warning(_('RBAC: Bypassing authorization'))
                return f(self, context, filters, **kwargs)

            action = 'identity:%s' % f.__name__
            creds = _build_policy_check_credentials(self, action,
                                                    context, kwargs)
            # Build the target dict for the policy check from:
            #
            # - any declared query filter parameters present in the request
            # - data from the main url (the kwargs), which typically include
            #   the prime key of a get/update/delete call; these are added
            #   last, so they win over a query filter of the same name.
            target = {}
            if filters:
                query = context['query_string']
                for name in filters:
                    if name in query:
                        target[name] = query[name]

                LOG.debug(_('RBAC: Adding query filter params (%s)') % (
                    ', '.join(['%s=%s' % (name, target[name])
                               for name in target])))

            target.update(kwargs)

            self.policy_api.enforce(creds, action, flatten(target))

            LOG.debug(_('RBAC: Authorization granted'))
            return f(self, context, filters, **kwargs)
        return wrapper
    return _filterprotected


@dependency.requires('identity_api', 'policy_api', 'token_api',
                     'trust_api', 'catalog_api', 'credential_api',
                     'assignment_api')
class V2Controller(wsgi.Application):
    """Common base class for the Identity API v2 controllers.

    ``dependency.requires`` lists the manager APIs (``identity_api``,
    ``token_api``, ``trust_api``, ...) that the methods below reach through
    attributes of ``self``.
    """

    def _delete_tokens_for_trust(self, user_id, trust_id):
        """Delete the tokens of ``user_id`` associated with ``trust_id``."""
        self.token_api.delete_tokens(user_id, trust_id=trust_id)

    def _delete_tokens_for_user(self, user_id, project_id=None):
        """Delete ``user_id``'s tokens, including those obtained via trusts.

        :param project_id: forwarded to ``token_api.delete_tokens`` as
                           ``tenant_id`` for the user's own tokens.
        """
        # Tokens that could be used to get other tokens are removed first.
        self.token_api.delete_tokens(user_id, tenant_id=project_id)

        # Trusts where the user is the trustee: the tokens belong to the user.
        for trust_ref in self.trust_api.list_trusts_for_trustee(user_id):
            self._delete_tokens_for_trust(user_id, trust_ref['id'])

        # Trusts where the user is the trustor: the tokens belong to the
        # trustee recorded on the trust.
        for trust_ref in self.trust_api.list_trusts_for_trustor(user_id):
            trustee_id = trust_ref['trustee_user_id']
            self._delete_tokens_for_trust(trustee_id, trust_ref['id'])

    def _require_attribute(self, ref, attr):
        """Raise ValidationError unless ``ref[attr]`` is set and non-empty."""
        value = ref.get(attr)
        if value is None or value == '':
            raise exception.ValidationError(
                message='%s field is required and cannot be empty' % attr)

    def _normalize_domain_id(self, context, ref):
        """Force ``ref['domain_id']`` to the default domain.

        v2 calls are not domain-aware, so any domain_id the caller supplied
        is overwritten. ``context`` is unused here; it keeps the signature
        shared with the v3 override. ``ref`` is modified in place and
        returned.
        """
        ref['domain_id'] = DEFAULT_DOMAIN_ID
        return ref

    def _filter_domain_id(self, ref):
        """Strip ``domain_id`` from ``ref`` (in place) for v2 callers."""
        if 'domain_id' in ref:
            del ref['domain_id']
        return ref


class V3Controller(V2Controller):
    """Base controller class for Identity API v3.

    Child classes should set the ``collection_name`` and ``member_name`` class
    attributes, representing the collection of entities they are exposing to
    the API. This is required for supporting self-referential links,
    pagination, etc.

    """

    collection_name = 'entities'
    member_name = 'entity'

    def _delete_tokens_for_group(self, group_id):
        """Delete the tokens of every user that is a member of ``group_id``."""
        user_refs = self.identity_api.list_users_in_group(group_id)
        for user in user_refs:
            self._delete_tokens_for_user(user['id'])

    @classmethod
    def base_url(cls, path=None):
        """Return the absolute v3 URL for ``path`` or for this collection.

        :param path: request path appended verbatim to ``<endpoint>/v3``
                     (presumably beginning with '/'); when empty,
                     ``/<collection_name>`` is appended instead.
        """
        endpoint = CONF.public_endpoint % CONF

        # allow a missing trailing slash in the config; endswith() also copes
        # with an empty endpoint, where indexing [-1] would raise IndexError.
        if not endpoint.endswith('/'):
            endpoint += '/'

        url = endpoint + 'v3'

        if path:
            return url + path
        return url + '/' + cls.collection_name

    @classmethod
    def _add_self_referential_link(cls, ref):
        """Set ``ref['links']['self']`` to the URL of ``ref`` (in place)."""
        ref.setdefault('links', {})
        ref['links']['self'] = cls.base_url() + '/' + ref['id']

    @classmethod
    def wrap_member(cls, context, ref):
        """Add a self link to ``ref`` and wrap it under ``member_name``."""
        cls._add_self_referential_link(ref)
        return {cls.member_name: ref}

    @classmethod
    def wrap_collection(cls, context, refs, filters=None):
        """Filter, paginate and wrap ``refs`` under ``collection_name``.

        :param filters: iterable of query-string attribute names to filter
                        ``refs`` by (see ``filter_by_attribute``); ``None``
                        (the default) means no filtering. A ``None`` default
                        replaces the former shared mutable ``[]`` default.
        :returns: dict with the collection and a ``links`` dict whose
                  ``self`` entry is built from ``context['path']``.
        """
        for attr in filters or ():
            refs = cls.filter_by_attribute(context, refs, attr)

        refs = cls.paginate(context, refs)

        for ref in refs:
            cls.wrap_member(context, ref)

        container = {cls.collection_name: refs}
        container['links'] = {
            'next': None,
            'self': cls.base_url(path=context['path']),
            'previous': None}
        return container

    @classmethod
    def paginate(cls, context, refs):
        """Paginates a list of references by page & per_page query strings.

        Pagination is currently disabled: ``refs`` is returned unchanged.
        """
        # FIXME(dolph): client needs to support pagination first. When this
        # is implemented, note that ``page``/``per_page`` come from the query
        # string (presumably as strings) and need int() conversion before
        # being used to slice ``refs``.
        return refs

    @classmethod
    def filter_by_attribute(cls, context, refs, attr):
        """Filters a list of references by query string value.

        Refs are kept when their (flattened) ``attr`` matches the value of
        the ``attr`` query parameter; without that parameter ``refs`` is
        returned unchanged.
        """

        def _attr_match(ref_attr, val_attr):
            """Matches attributes allowing for booleans as strings.

            When the ref's attribute is a bool, only the query value '0'
            means False; any other value (including an empty one) means
            True, so the mere presence of the parameter selects True.

            """
            if type(ref_attr) is bool:
                # Comparing with '0' directly gives the same result as the
                # former ``isinstance(val_attr, basestring)`` guard for the
                # string values a query string yields, and works on Python 3.
                return ref_attr == (val_attr != '0')
            return ref_attr == val_attr

        if attr in context['query_string']:
            value = context['query_string'][attr]
            return [r for r in refs if _attr_match(
                flatten(r).get(attr), value)]
        return refs

    def _require_matching_id(self, value, ref):
        """Ensures the value matches the reference's ID, if any."""
        if 'id' in ref and ref['id'] != value:
            raise exception.ValidationError('Cannot change ID')

    def _assign_unique_id(self, ref):
        """Return a copy of ``ref`` with a freshly generated unique ``id``."""
        ref = ref.copy()
        ref['id'] = uuid.uuid4().hex
        return ref

    def _normalize_domain_id(self, context, ref):
        """Fill in domain_id if not specified in a v3 call.

        Admin contexts get ``DEFAULT_DOMAIN_ID``. Otherwise the token named
        by ``context['token_id']`` is loaded and its ``domain`` id is used
        when present, else ``DEFAULT_DOMAIN_ID``. ``ref`` is modified in
        place and returned.

        :raises keystone.exception.Unauthorized: if the token is not found.
        """
        if 'domain_id' not in ref:
            # .get() matches ``protected``: a context without 'is_admin' is
            # treated as non-admin rather than raising KeyError.
            if context.get('is_admin'):
                ref['domain_id'] = DEFAULT_DOMAIN_ID
            else:
                # Fish the domain_id out of the token
                #
                # We could make this more efficient by loading the domain_id
                # into the context in the wrapper function above (since
                # this version of normalize_domain will only be called inside
                # a v3 protected call).  However, given that we only use this
                # for creating entities, this optimization is probably not
                # worth the duplication of state
                try:
                    token_ref = self.token_api.get_token(
                        token_id=context['token_id'])
                except exception.TokenNotFound:
                    LOG.warning(_('Invalid token in normalize_domain_id'))
                    raise exception.Unauthorized()

                if 'domain' in token_ref:
                    ref['domain_id'] = token_ref['domain']['id']
                else:
                    # FIXME(henry-nash) Revisit this once v3 token scoping
                    # across domains has been hashed out
                    ref['domain_id'] = DEFAULT_DOMAIN_ID
        return ref

    def _filter_domain_id(self, ref):
        """Override v2 filter to let domain_id out for v3 calls."""
        return ref