summaryrefslogtreecommitdiffstats
path: root/keystone/contrib/oauth1/backends/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone/contrib/oauth1/backends/sql.py')
-rw-r--r--keystone/contrib/oauth1/backends/sql.py284
1 files changed, 284 insertions, 0 deletions
diff --git a/keystone/contrib/oauth1/backends/sql.py b/keystone/contrib/oauth1/backends/sql.py
new file mode 100644
index 00000000..9dc0665c
--- /dev/null
+++ b/keystone/contrib/oauth1/backends/sql.py
@@ -0,0 +1,284 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import random
+import uuid
+
+from keystone.common import sql
+from keystone.common.sql import migration
+from keystone.contrib.oauth1 import core
+from keystone import exception
+from keystone.openstack.common import timeutils
+
+
+class Consumer(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'consumer'
+ attributes = ['id', 'description', 'secret']
+ id = sql.Column(sql.String(64), primary_key=True, nullable=False)
+ description = sql.Column(sql.String(64), nullable=False)
+ secret = sql.Column(sql.String(64), nullable=False)
+ extra = sql.Column(sql.JsonBlob(), nullable=False)
+
+
+class RequestToken(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'request_token'
+ attributes = ['id', 'request_secret',
+ 'verifier', 'authorizing_user_id', 'requested_project_id',
+ 'requested_roles', 'consumer_id', 'expires_at']
+ id = sql.Column(sql.String(64), primary_key=True, nullable=False)
+ request_secret = sql.Column(sql.String(64), nullable=False)
+ verifier = sql.Column(sql.String(64), nullable=True)
+ authorizing_user_id = sql.Column(sql.String(64), nullable=True)
+ requested_project_id = sql.Column(sql.String(64), nullable=False)
+ requested_roles = sql.Column(sql.Text(), nullable=False)
+ consumer_id = sql.Column(sql.String(64), nullable=False, index=True)
+ expires_at = sql.Column(sql.String(64), nullable=True)
+
+ @classmethod
+ def from_dict(cls, user_dict):
+ return cls(**user_dict)
+
+ def to_dict(self):
+ return dict(self.iteritems())
+
+
+class AccessToken(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'access_token'
+ attributes = ['id', 'access_secret', 'authorizing_user_id',
+ 'project_id', 'requested_roles', 'consumer_id',
+ 'expires_at']
+ id = sql.Column(sql.String(64), primary_key=True, nullable=False)
+ access_secret = sql.Column(sql.String(64), nullable=False)
+ authorizing_user_id = sql.Column(sql.String(64), nullable=False,
+ index=True)
+ project_id = sql.Column(sql.String(64), nullable=False)
+ requested_roles = sql.Column(sql.Text(), nullable=False)
+ consumer_id = sql.Column(sql.String(64), nullable=False)
+ expires_at = sql.Column(sql.String(64), nullable=True)
+
+ @classmethod
+ def from_dict(cls, user_dict):
+ return cls(**user_dict)
+
+ def to_dict(self):
+ return dict(self.iteritems())
+
+
+class OAuth1(sql.Base):
+ def db_sync(self):
+ migration.db_sync()
+
+ def _get_consumer(self, consumer_id):
+ session = self.get_session()
+ consumer_ref = session.query(Consumer).get(consumer_id)
+ if consumer_ref is None:
+ raise exception.NotFound(_('Consumer not found'))
+ return consumer_ref
+
+ def get_consumer(self, consumer_id):
+ session = self.get_session()
+ consumer_ref = session.query(Consumer).get(consumer_id)
+ if consumer_ref is None:
+ raise exception.NotFound(_('Consumer not found'))
+ return core.filter_consumer(consumer_ref.to_dict())
+
+ def create_consumer(self, consumer):
+ consumer['secret'] = uuid.uuid4().hex
+ if not consumer.get('description'):
+ consumer['description'] = None
+ session = self.get_session()
+ with session.begin():
+ consumer_ref = Consumer.from_dict(consumer)
+ session.add(consumer_ref)
+ session.flush()
+ return consumer_ref.to_dict()
+
+ def _delete_consumer(self, session, consumer_id):
+ consumer_ref = self._get_consumer(session, consumer_id)
+ q = session.query(Consumer)
+ q = q.filter_by(id=consumer_id)
+ q.delete(False)
+ session.delete(consumer_ref)
+
+ def _delete_request_tokens(self, session, consumer_id):
+ q = session.query(RequestToken)
+ req_tokens = q.filter_by(consumer_id=consumer_id)
+ req_tokens_list = set([x.id for x in req_tokens])
+ for token_id in req_tokens_list:
+ token_ref = self._get_request_token(session, token_id)
+ q = session.query(RequestToken)
+ q = q.filter_by(id=token_id)
+ q.delete(False)
+ session.delete(token_ref)
+
+ def _delete_access_tokens(self, session, consumer_id):
+ q = session.query(AccessToken)
+ acc_tokens = q.filter_by(consumer_id=consumer_id)
+ acc_tokens_list = set([x.id for x in acc_tokens])
+ for token_id in acc_tokens_list:
+ token_ref = self._get_access_token(session, token_id)
+ q = session.query(AccessToken)
+ q = q.filter_by(id=token_id)
+ q.delete(False)
+ session.delete(token_ref)
+
+ def delete_consumer(self, consumer_id):
+ session = self.get_session()
+ with session.begin():
+ self._delete_consumer(session, consumer_id)
+ self._delete_request_tokens(session, consumer_id)
+ self._delete_access_tokens(session, consumer_id)
+ session.flush()
+
+ def list_consumers(self):
+ session = self.get_session()
+ cons = session.query(Consumer)
+ return [core.filter_consumer(x.to_dict()) for x in cons]
+
+ def update_consumer(self, consumer_id, consumer):
+ session = self.get_session()
+ with session.begin():
+ consumer_ref = self._get_consumer(consumer_id)
+ old_consumer_dict = consumer_ref.to_dict()
+ old_consumer_dict.update(consumer)
+ new_consumer = Consumer.from_dict(old_consumer_dict)
+ for attr in Consumer.attributes:
+ if (attr != 'id' or attr != 'secret'):
+ setattr(consumer_ref,
+ attr,
+ getattr(new_consumer, attr))
+ consumer_ref.extra = new_consumer.extra
+ session.flush()
+ return core.filter_consumer(consumer_ref.to_dict())
+
+ def create_request_token(self, consumer_id, roles,
+ project_id, token_duration):
+ expiry_date = None
+ if token_duration:
+ now = timeutils.utcnow()
+ future = now + datetime.timedelta(seconds=token_duration)
+ expiry_date = timeutils.isotime(future, subsecond=True)
+
+ ref = {}
+ request_token_id = uuid.uuid4().hex
+ ref['id'] = request_token_id
+ ref['request_secret'] = uuid.uuid4().hex
+ ref['verifier'] = None
+ ref['authorizing_user_id'] = None
+ ref['requested_project_id'] = project_id
+ ref['requested_roles'] = roles
+ ref['consumer_id'] = consumer_id
+ ref['expires_at'] = expiry_date
+ session = self.get_session()
+ with session.begin():
+ token_ref = RequestToken.from_dict(ref)
+ session.add(token_ref)
+ session.flush()
+ return token_ref.to_dict()
+
+ def _get_request_token(self, session, request_token_id):
+ token_ref = session.query(RequestToken).get(request_token_id)
+ if token_ref is None:
+ raise exception.NotFound(_('Request token not found'))
+ return token_ref
+
+ def get_request_token(self, request_token_id):
+ session = self.get_session()
+ token_ref = self._get_request_token(session, request_token_id)
+ return token_ref.to_dict()
+
+ def authorize_request_token(self, request_token_id, user_id):
+ session = self.get_session()
+ with session.begin():
+ token_ref = self._get_request_token(session, request_token_id)
+ token_dict = token_ref.to_dict()
+ token_dict['authorizing_user_id'] = user_id
+ token_dict['verifier'] = str(random.randint(1000, 9999))
+
+ new_token = RequestToken.from_dict(token_dict)
+ for attr in RequestToken.attributes:
+ if (attr == 'authorizing_user_id' or attr == 'verifier'):
+ setattr(token_ref, attr, getattr(new_token, attr))
+
+ session.flush()
+ return token_ref.to_dict()
+
+ def create_access_token(self, request_token_id, token_duration):
+ session = self.get_session()
+ with session.begin():
+ req_token_ref = self._get_request_token(session, request_token_id)
+ token_dict = req_token_ref.to_dict()
+
+ expiry_date = None
+ if token_duration:
+ now = timeutils.utcnow()
+ future = now + datetime.timedelta(seconds=token_duration)
+ expiry_date = timeutils.isotime(future, subsecond=True)
+
+ # add Access Token
+ ref = {}
+ access_token_id = uuid.uuid4().hex
+ ref['id'] = access_token_id
+ ref['access_secret'] = uuid.uuid4().hex
+ ref['authorizing_user_id'] = token_dict['authorizing_user_id']
+ ref['project_id'] = token_dict['requested_project_id']
+ ref['requested_roles'] = token_dict['requested_roles']
+ ref['consumer_id'] = token_dict['consumer_id']
+ ref['expires_at'] = expiry_date
+ token_ref = AccessToken.from_dict(ref)
+ session.add(token_ref)
+
+ # remove request token, it's been used
+ q = session.query(RequestToken)
+ q = q.filter_by(id=request_token_id)
+ q.delete(False)
+ session.delete(req_token_ref)
+
+ session.flush()
+ return token_ref.to_dict()
+
+ def _get_access_token(self, session, access_token_id):
+ token_ref = session.query(AccessToken).get(access_token_id)
+ if token_ref is None:
+ raise exception.NotFound(_('Access token not found'))
+ return token_ref
+
+ def get_access_token(self, access_token_id):
+ session = self.get_session()
+ token_ref = self._get_access_token(session, access_token_id)
+ return token_ref.to_dict()
+
+ def list_access_tokens(self, user_id):
+ session = self.get_session()
+ q = session.query(AccessToken)
+ user_auths = q.filter_by(authorizing_user_id=user_id)
+ return [core.filter_token(x.to_dict()) for x in user_auths]
+
+ def delete_access_token(self, user_id, access_token_id):
+ session = self.get_session()
+ with session.begin():
+ token_ref = self._get_access_token(session, access_token_id)
+ token_dict = token_ref.to_dict()
+ if token_dict['authorizing_user_id'] != user_id:
+ raise exception.Unauthorized(_('User IDs do not match'))
+
+ q = session.query(AccessToken)
+ q = q.filter_by(id=access_token_id)
+ q.delete(False)
+
+ session.delete(token_ref)
+ session.flush()