summaryrefslogtreecommitdiffstats
path: root/ipsilon/providers/saml2/provider.py
blob: d3ed5daad9e0ea43aad9643e332de387dd43c572 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/python
#
# Copyright (C) 2014  Simo Sorce <simo@redhat.com>
#
# see file 'COPYING' for use and warranty information
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from ipsilon.providers.common import ProviderException
import cherrypy
import lasso


# Short names for the SAML 2.0 NameID formats, each mapped to the matching
# lasso format constant. ServiceProvider.get_valid_nameid() looks names up
# here to translate them into the format value it returns.
NAMEID_MAP = dict(
    email=lasso.SAML2_NAME_IDENTIFIER_FORMAT_EMAIL,
    encrypted=lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENCRYPTED,
    entity=lasso.SAML2_NAME_IDENTIFIER_FORMAT_ENTITY,
    kerberos=lasso.SAML2_NAME_IDENTIFIER_FORMAT_KERBEROS,
    persistent=lasso.SAML2_NAME_IDENTIFIER_FORMAT_PERSISTENT,
    transient=lasso.SAML2_NAME_IDENTIFIER_FORMAT_TRANSIENT,
    unspecified=lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED,
    windows=lasso.SAML2_NAME_IDENTIFIER_FORMAT_WINDOWS,
    x509=lasso.SAML2_NAME_IDENTIFIER_FORMAT_X509
)


class InvalidProviderId(ProviderException):
    """Error raised when a provider ID lookup does not yield a usable entry."""

    def __init__(self, code):
        # The same text serves as the exception message and is handed to
        # self._debug (not defined in this class; presumably inherited
        # from ProviderException).
        text = 'Invalid Provider ID: %s' % code
        super(InvalidProviderId, self).__init__(text)
        self._debug(text)


class NameIdNotAllowed(Exception):
    """Raised when a requested NameID format is not permitted.

    The constructor accepts the rejected NameID as an optional argument.
    ServiceProvider.get_valid_nameid() raises this exception with the
    requested format, which failed with TypeError while the constructor
    took no arguments.
    """

    def __init__(self, nid=None):
        """Build the exception.

        nid -- the rejected NameID format, or None when it is not known.
               When given it is included in the message and kept on the
               instance as ``nid``.
        """
        if nid is None:
            message = 'The specified Name ID is not allowed'
        else:
            # Wrap in a tuple so a tuple-valued nid cannot break formatting.
            message = 'Name ID [%s] is not allowed' % (nid,)
        super(NameIdNotAllowed, self).__init__(message)
        self.message = message
        self.nid = nid

    def __str__(self):
        # Kept as repr() of the message so existing output stays the same.
        return repr(self.message)


class ServiceProvider(object):
    """A service provider entry loaded from the provider configuration.

    The entry's properties are read once in the constructor and exposed
    through read-only properties. Per-provider NameID settings fall back
    to the defaults carried by the configuration object.
    """

    def __init__(self, config, provider_id):
        """Load the entry whose 'id' property equals provider_id.

        config      -- object offering get_data(name=..., value=...) and
                       get_data(idval=...), both returning a mapping keyed
                       by the internal id value, plus the attributes
                       default_allowed_nameids and default_nameid.
        provider_id -- the provider ID to look up.

        Raises InvalidProviderId unless exactly one entry matches.
        """
        self.cfg = config
        data = self.cfg.get_data(name='id', value=provider_id)
        matches = len(data)
        if matches == 0:
            # Report a missing provider accurately instead of calling it
            # 'multiple matches'.
            raise InvalidProviderId('no matches')
        if matches > 1:
            raise InvalidProviderId('multiple matches')
        # next(iter(...)) works whether keys() is a list or a view object;
        # keys()[0] fails on dictionary views.
        idval = next(iter(data))
        data = self.cfg.get_data(idval=idval)
        self._properties = data[idval]

    @property
    def provider_id(self):
        """The 'id' property of this provider."""
        return self._properties['id']

    @property
    def name(self):
        """The 'name' property of this provider."""
        return self._properties['name']

    # NOTE: the property name is misspelt ('namedids'); it is kept as is
    # because callers outside this file may rely on it.
    @property
    def allowed_namedids(self):
        """The provider's 'allowed nameid' value, else the config default."""
        if 'allowed nameid' in self._properties:
            return self._properties['allowed nameid']
        return self.cfg.default_allowed_nameids

    @property
    def default_nameid(self):
        """The provider's 'default nameid' value, else the config default."""
        if 'default nameid' in self._properties:
            return self._properties['default nameid']
        return self.cfg.default_nameid

    def get_valid_nameid(self, nip):
        """Return the NameID format to use for a request.

        nip -- object whose ``format`` attribute holds the requested
               NameID format (or None).

        When no format, or the 'unspecified' format, is requested, the
        provider's default format is returned. Otherwise the requested
        format is returned if it is among the allowed ones.

        Raises NameIdNotAllowed when the requested format is not allowed.
        """
        requested = nip.format
        self._debug('Requested NameId [%s]' % (requested,))
        if (requested is None or
                requested == lasso.SAML2_NAME_IDENTIFIER_FORMAT_UNSPECIFIED):
            return NAMEID_MAP[self.default_nameid]

        allowed = self.allowed_namedids
        self._debug('Allowed NameIds %s' % (repr(allowed)))
        for nameid in allowed:
            if requested == NAMEID_MAP[nameid]:
                return requested
        # Raised without arguments: the exception constructor declared in
        # this file takes none, so passing the format raised TypeError
        # instead. The rejected format is already in the debug log above.
        raise NameIdNotAllowed()

    def _debug(self, fact):
        """Log fact through cherrypy when the 'debug' config flag is set."""
        if cherrypy.config.get('debug', False):
            cherrypy.log(fact)

    def normalize_username(self, username):
        """Return username, cut at the first '@' if 'strip domain' is set.

        Only the presence of the 'strip domain' property is checked, not
        its value.
        """
        if 'strip domain' in self._properties:
            return username.split('@', 1)[0]
        return username