# Source: roles/ask/files/sanction-client.py
# (git blob 737f2114ef6c7de95ed609cb4107b7830bc01720)
# vim: set ts=4 sw=4 et:
""" OAuth 2.0 client library
"""

from json import loads
from datetime import datetime, timedelta
from time import mktime
def _parse_content_charset(header_lines, default='utf-8'):
    """ Return the lower-cased ``charset`` parameter of a Content-Type header.

    :param header_lines: Raw header lines, e.g.
                         ``'Content-Type: text/html; charset=UTF-8\\r\\n'``
                         (the shape of Python 2's ``HTTPMessage.headers``).
    :param default: Returned when there is no Content-Type header or it has
                    no (non-empty) charset parameter.
    """
    for line in header_lines:
        name, sep, value = line.partition(':')
        # Compare the header *name* exactly (case-insensitively) so headers
        # such as ``X-Content-Type-Options`` are not mistaken for it.
        if not sep or name.strip().lower() != 'content-type':
            continue
        for param in value.split(';')[1:]:
            key, _, val = param.partition('=')
            if key.strip().lower() == 'charset':
                charset = val.strip().strip('"\'').lower()
                if charset:
                    return charset
        # only the first Content-Type header is considered
        return default
    return default


try:
    # Python 2 names
    from urllib import urlencode
    from urllib2 import Request, urlopen
    from urlparse import urlsplit, urlunsplit, parse_qsl

    # Python 2's HTTPMessage has no get_content_charset(), which Client
    # calls on response info objects, so monkeypatch one in.
    from httplib import HTTPMessage

    def get_charset(self):
        """ Charset from the Content-Type header, or ``'utf-8'``. """
        return _parse_content_charset(self.headers)
    HTTPMessage.get_content_charset = get_charset
except ImportError:
    # Python 3: the same names live in urllib.parse / urllib.request, and
    # the response info objects already provide get_content_charset().
    from urllib.parse import urlencode, urlsplit, urlunsplit, parse_qsl
    from urllib.request import Request, urlopen


class Client(object):
    """ OAuth 2.0 client object

    Holds one provider's endpoint URLs and client credentials. After
    :meth:`request_token`, it also holds the token state: every field of the
    token response is set as an attribute (``access_token``, ...), plus
    ``token_expires``.
    """

    def __init__(self, auth_endpoint=None, token_endpoint=None,
        resource_endpoint=None, client_id=None, client_secret=None,
        redirect_uri=None, token_transport=None):
        """
        :param auth_endpoint: URL of the authorization endpoint.
        :param token_endpoint: URL that token requests are POSTed to.
        :param resource_endpoint: Base URL that :meth:`request` paths are
                                  appended to.
        :param client_id: OAuth client identifier.
        :param client_secret: OAuth client secret.
        :param redirect_uri: Sent with authorization and token requests when
                             not ``None``.
        :param token_transport: How :meth:`request` attaches the access
                                token: ``'query'`` (default), ``'headers'``,
                                or a callable
                                ``f(url, access_token, data=..., method=...)``
                                whose result is passed to ``urlopen``.
        """
        assert(hasattr(token_transport, '__call__') or
            token_transport in ('headers', 'query', None))

        self.auth_endpoint = auth_endpoint
        self.token_endpoint = token_endpoint
        self.resource_endpoint = resource_endpoint
        self.redirect_uri = redirect_uri
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token = None
        self.token_transport = token_transport or 'query'
        # -1 until a token response with expires_in arrives (or set manually)
        self.token_expires = -1
        self.refresh_token = None

    def auth_uri(self, scope=None, scope_delim=None, state=None, **kwargs):
        """ Builds the auth URI for the authorization endpoint

        :param scope: Requested scopes. Either an iterable of scope names
                      (joined with ``scope_delim``) or an already-formatted
                      string, which is sent unchanged.
        :param scope_delim: Separator used to join ``scope``; defaults to a
                            single space.
        :param state: Value added as the ``state`` parameter when given.
        :param kwargs: Extra query parameters. ``response_type`` defaults to
                       ``'code'`` but a caller-supplied value is kept.
        :returns: ``auth_endpoint`` followed by ``?`` and the urlencoded
                  parameters.
        """
        scope_delim = scope_delim or ' '
        kwargs['client_id'] = self.client_id
        kwargs.setdefault('response_type', 'code')

        if scope is not None:
            # A bare string is already formatted; joining it would put the
            # delimiter between every character ('email' -> 'e m a i l').
            if isinstance(scope, (str, type(u''))):
                kwargs['scope'] = scope
            else:
                kwargs['scope'] = scope_delim.join(scope)

        if state is not None:
            kwargs['state'] = state

        if self.redirect_uri is not None:
            kwargs['redirect_uri'] = self.redirect_uri

        return '%s?%s' % (self.auth_endpoint, urlencode(kwargs))

    def request_token(self, parser=None, exclude=None, **kwargs):
        """ Request an access token from the token endpoint.

        This is largely a helper method and expects the client code to
        understand what the server expects. Anything that's passed into
        ``**kwargs`` is urlencoded and POSTed to the endpoint. The client
        secret and client ID are included automatically, so they are not
        required as kwargs. For example::

            # if requesting access token from auth flow:
            {
                'code': rval_from_auth,
            }

            # if refreshing access token:
            {
                'refresh_token': stored_refresh_token,
                'grant_type': 'refresh_token',
            }

        Every key of the parsed response is set as an attribute on the
        client. When an ``expires_in`` attribute is present, ``token_expires``
        is set to the POSIX timestamp at which the token expires.

        :param exclude: An iterable of fields to exclude from the ``POST``
                        data. This is useful for fields such as
                        ``redirect_uri``, which is required during the initial
                        code/token exchange but causes errors with some
                        providers when exchanging refresh tokens for new
                        access tokens.
        :param parser: Callback to deal with the returned data. Not all
                       providers use JSON. Defaults to ``json.loads``.
        :raises AssertionError: if no ``access_token`` is set afterwards.
        """
        exclude = exclude or ()
        parser = parser or loads
        kwargs.update({
            'client_id': self.client_id,
            'client_secret': self.client_secret,
            # a (truthy) grant_type passed by the caller wins
            'grant_type': kwargs.get('grant_type') or 'authorization_code',
        })
        if self.redirect_uri is not None and 'redirect_uri' not in exclude:
            kwargs['redirect_uri'] = self.redirect_uri

        msg = urlopen(self.token_endpoint, urlencode(kwargs).encode('utf-8'))
        try:
            body = msg.read()
            charset = msg.info().get_content_charset() or 'utf-8'
        finally:
            # release the connection even if reading fails
            msg.close()
        data = parser(body.decode(charset))

        for key in data:
            setattr(self, key, data[key])

        # expires_in is RFC-compliant. if anything else is used by the
        # provider, token_expires must be set manually
        if hasattr(self, 'expires_in'):
            # mktime() interprets its argument as *local* time, so the base
            # must be datetime.now(); utcnow() skewed the timestamp by the
            # UTC offset. float() accepts providers sending a string.
            self.token_expires = mktime((datetime.now() + timedelta(
                seconds=float(self.expires_in))).timetuple())

        assert(self.access_token is not None)

    def refresh(self):
        """ Exchange ``refresh_token`` for a new access token.

        ``redirect_uri`` is excluded from the POSTed data.
        """
        assert self.refresh_token is not None
        self.request_token(refresh_token=self.refresh_token,
            grant_type='refresh_token', exclude=('redirect_uri',))

    def request(self, url, method=None, data=None, parser=None):
        """ Request user data from the resource endpoint

        :param url: The path to the resource and querystring if required;
                    appended to ``resource_endpoint``.
        :param method: HTTP method. Defaults to ``GET`` unless data is truthy,
                       in which case it defaults to ``POST``.
        :param data: Data to be POSTed to the resource endpoint.
        :param parser: Parser callback to deal with the returned data.
                       Defaults to ``json.loads``.
        :returns: Whatever ``parser`` returns. If the body cannot be decoded
                  with its declared charset (or utf-8), the raw bytes are
                  passed to ``parser`` instead.
        """
        assert(self.access_token is not None)
        parser = parser or loads

        if not method:
            method = 'POST' if data else 'GET'

        if hasattr(self.token_transport, '__call__'):
            transport = self.token_transport
        else:
            # 'query' / 'headers' select a module-level _transport_* function
            transport = globals()['_transport_{0}'.format(
                self.token_transport)]

        req = transport('{0}{1}'.format(self.resource_endpoint, url),
            self.access_token, data=data, method=method)

        resp = urlopen(req)
        try:
            body = resp.read()
            charset = resp.info().get_content_charset() or 'utf-8'
        finally:
            resp.close()

        try:
            text = body.decode(charset)
        except (UnicodeDecodeError, LookupError):
            # Undecodable body or unknown charset: the calling code had
            # better know how to deal with it. Some providers (i.e.
            # stackexchange) gzip their responses, so this lets the client
            # code handle the raw bytes directly.
            return parser(body)
        return parser(text)

def _make_request(url, data=None, method=None):
    """ Build a ``Request`` for ``url`` with the given HTTP ``method``.

    Python 3's ``Request`` accepts ``method=``; Python 2's raises
    ``TypeError`` for it, so there ``get_method`` is overridden instead.
    It is overridden only when a method was actually given, so that ``None``
    keeps Request's default GET/POST choice.
    """
    try:
        return Request(url, data=data, method=method)
    except TypeError:
        req = Request(url, data=data)
        if method:
            req.get_method = lambda: method
        return req


def _transport_headers(url, access_token, data=None, method=None):
    """ Build a request carrying the access token in the ``Authorization``
    header as ``Bearer <token>``.
    """
    req = _make_request(url, data=data, method=method)
    req.add_header('Authorization', 'Bearer {0}'.format(access_token))
    return req


def _transport_query(url, access_token, data=None, method=None):
    """ Build a request carrying the access token as the ``access_token``
    query parameter, keeping the URL's other query parameters intact.
    """
    parts = urlsplit(url)
    # An ordered list of pairs (not a dict) keeps repeated keys
    # (?id=1&id=2) and blank values (?q=); only an existing access_token
    # parameter is replaced.
    query = [(key, value)
             for key, value in parse_qsl(parts.query, keep_blank_values=True)
             if key != 'access_token']
    query.append(('access_token', access_token))
    url = urlunsplit((parts.scheme, parts.netloc, parts.path,
        urlencode(query), parts.fragment))
    return _make_request(url, data=data, method=method)