diff options
-rw-r--r-- | keystone/token/backends/memcache.py | 91 | ||||
-rw-r--r-- | tests/test_backend_memcache.py | 84 |
2 files changed, 167 insertions, 8 deletions
diff --git a/keystone/token/backends/memcache.py b/keystone/token/backends/memcache.py index a62f3421..e9d8482f 100644 --- a/keystone/token/backends/memcache.py +++ b/keystone/token/backends/memcache.py @@ -19,15 +19,20 @@ import copy import memcache +from keystone.common import logging from keystone.common import utils from keystone import config from keystone import exception from keystone.openstack.common import jsonutils +from keystone.openstack.common import timeutils from keystone import token CONF = config.CONF config.register_str('servers', group='memcache', default='localhost:11211') +config.register_int('max_compare_and_set_retry', group='memcache', default=16) + +LOG = logging.getLogger(__name__) class Token(token.Driver): @@ -42,7 +47,13 @@ class Token(token.Driver): def _get_memcache_client(self): memcache_servers = CONF.memcache.servers.split(',') - self._memcache_client = memcache.Client(memcache_servers, debug=0) + # NOTE(morganfainberg): The memcache client library for python is NOT + # thread safe and should not be passed between threads. This is highly + # specific to the cas() (compare and set) methods and the caching of + # the previous value(s). It appears greenthread should ensure there is + # a single data structure per spawned greenthread. + self._memcache_client = memcache.Client(memcache_servers, debug=0, + cache_cas=True) return self._memcache_client def _prefix_token_id(self, token_id): @@ -77,13 +88,81 @@ class Token(token.Driver): token_data = jsonutils.dumps(token_id) user_id = data['user']['id'] user_key = self._prefix_user_id(user_id) - if not self.client.append(user_key, ',%s' % token_data): - if not self.client.add(user_key, token_data): - if not self.client.append(user_key, ',%s' % token_data): - msg = _('Unable to add token user list.') - raise exception.UnexpectedError(msg) + # Append the new token_id to the token-index-list stored in the + # user-key within memcache. + self._update_user_list_with_cas(user_key, token_data) return copy.deepcopy(data_copy) + def _update_user_list_with_cas(self, user_key, token_id): + cas_retry = 0 + max_cas_retry = CONF.memcache.max_compare_and_set_retry + current_time = timeutils.normalize_time( + timeutils.parse_isotime(timeutils.isotime())) + + self.client.reset_cas() + + while cas_retry <= max_cas_retry: + # NOTE(morganfainberg): cas or "compare and set" is a function of + # memcache. It will return false if the value has changed since the + # last call to client.gets(). This is the memcache supported method + # of avoiding race conditions on set(). Memcache is already atomic + # on the back-end and serializes operations. + # + # cas_retry is for tracking our iterations before we give up (in + # case memcache is down or something horrible happens we don't + # iterate forever trying to compare and set the new value. + cas_retry += 1 + record = self.client.gets(user_key) + filtered_list = [] + + if record is not None: + token_list = jsonutils.loads('[%s]' % record) + for token_i in token_list: + ptk = self._prefix_token_id(token.unique_id(token_i)) + token_ref = self.client.get(ptk) + if not token_ref: + # skip tokens that do not exist in memcache + continue + + if 'expires' in token_ref: + expires_at = timeutils.normalize_time( + token_ref['expires']) + if expires_at < current_time: + # skip tokens that are expired. + continue + + # Add the still valid token_id to the list. + filtered_list.append(jsonutils.dumps(token_i)) + # Add the new token_id to the list. + filtered_list.append(token_id) + + # Use compare-and-set (cas) to set the new value for the + # token-index-list for the user-key. Cas is used to prevent race + # conditions from causing the loss of valid token ids from this + # list. + if self.client.cas(user_key, ','.join(filtered_list)): + msg = _('Successful set of token-index-list for user-key ' + '"%(user_key)s", #%(count)d records') + LOG.debug(msg, {'user_key': user_key, + 'count': len(filtered_list)}) + return filtered_list + + # The cas function will return true if it succeeded or false if it + # failed for any reason, including memcache server being down, cas + # id changed since gets() called (the data changed between when + # this loop started and this point, etc. + error_msg = _('Failed to set token-index-list for user-key ' + '"%(user_key)s". Attempt %(cas_retry)d of ' + '%(cas_retry_max)d') + LOG.debug(error_msg, + {'user_key': user_key, + 'cas_retry': cas_retry, + 'cas_retry_max': max_cas_retry}) + + # Exceeded the maximum retry attempts. + error_msg = _('Unable to add token user list') + raise exception.UnexpectedError(error_msg) + def _add_to_revocation_list(self, data): data_json = jsonutils.dumps(data) if not self.client.append(self.revocation_key, ',%s' % data_json): diff --git a/tests/test_backend_memcache.py b/tests/test_backend_memcache.py index f5999002..5391d7f9 100644 --- a/tests/test_backend_memcache.py +++ b/tests/test_backend_memcache.py @@ -14,14 +14,18 @@ # License for the specific language governing permissions and limitations # under the License. +import copy +import datetime import uuid import memcache from keystone.common import utils from keystone import exception +from keystone.openstack.common import jsonutils from keystone.openstack.common import timeutils from keystone import test +from keystone import token from keystone.token.backends import memcache as token_memcache import test_backend @@ -33,6 +37,7 @@ class MemcacheClient(object): def __init__(self, *args, **kwargs): """Ignores the passed in args.""" self.cache = {} + self.reject_cas = False def add(self, key, value): if self.get(key): @@ -50,20 +55,50 @@ class MemcacheClient(object): if not isinstance(key, str): raise memcache.Client.MemcachedStringEncodingError() + def gets(self, key): + #Call self.get() since we don't really do 'cas' here. + return self.get(key) + def get(self, key): """Retrieves the value for a key or None.""" self.check_key(key) obj = self.cache.get(key) now = utils.unixtime(timeutils.utcnow()) if obj and (obj[1] == 0 or obj[1] > now): - return obj[0] + # NOTE(morganfainberg): This behaves more like memcache + # actually does and prevents modification of the passed in + # reference from affecting the cached back-end data. This makes + # tests a little easier to write. + # + # The back-end store should only change with an explicit + # set/delete/append/etc + data_copy = copy.deepcopy(obj[0]) + return data_copy def set(self, key, value, time=0): """Sets the value for a key.""" self.check_key(key) - self.cache[key] = (value, time) + # NOTE(morganfainberg): This behaves more like memcache + # actually does and prevents modification of the passed in + # reference from affecting the cached back-end data. This makes + # tests a little easier to write. + # + # The back-end store should only change with an explicit + # set/delete/append/etc + data_copy = copy.deepcopy(value) + self.cache[key] = (data_copy, time) return True + def cas(self, key, value, time=0, min_compress_len=0): + # Call self.set() since we don't really do 'cas' here. + if self.reject_cas: + return False + return self.set(key, value, time=time) + + def reset_cas(self): + #This is a stub for the memcache client reset_cas function. + pass + def delete(self, key): self.check_key(key) try: @@ -101,3 +136,48 @@ class MemcacheToken(test.TestCase, test_backend.TokenTests): def test_flush_expired_token(self): with self.assertRaises(exception.NotImplemented): self.token_api.flush_expired_tokens() + + def test_cleanup_user_index_on_create(self): + valid_token_id = uuid.uuid4().hex + second_valid_token_id = uuid.uuid4().hex + expired_token_id = uuid.uuid4().hex + user_id = unicode(uuid.uuid4().hex) + + expire_delta = datetime.timedelta(seconds=86400) + + valid_data = {'id': valid_token_id, 'a': 'b', + 'user': {'id': user_id}} + second_valid_data = {'id': second_valid_token_id, 'a': 'b', + 'user': {'id': user_id}} + expired_data = {'id': expired_token_id, 'a': 'b', + 'user': {'id': user_id}} + self.token_api.create_token(valid_token_id, valid_data) + self.token_api.create_token(expired_token_id, expired_data) + # NOTE(morganfainberg): Directly access the data cache since we need to + # get expired tokens as well as valid tokens. token_api.list_tokens() + # will not return any expired tokens in the list. + user_key = self.token_api._prefix_user_id(user_id) + user_record = self.token_api.client.get(user_key) + user_token_list = jsonutils.loads('[%s]' % user_record) + self.assertEquals(len(user_token_list), 2) + expired_token_ptk = self.token_api._prefix_token_id( + token.unique_id(expired_token_id)) + expired_token = self.token_api.client.get(expired_token_ptk) + expired_token['expires'] = (timeutils.utcnow() - expire_delta) + self.token_api.client.set(expired_token_ptk, expired_token) + + self.token_api.create_token(second_valid_token_id, second_valid_data) + user_record = self.token_api.client.get(user_key) + user_token_list = jsonutils.loads('[%s]' % user_record) + self.assertEquals(len(user_token_list), 2) + + def test_cas_failure(self): + self.token_api.client.reject_cas = True + token_id = uuid.uuid4().hex + user_id = unicode(uuid.uuid4().hex) + user_key = self.token_api._prefix_user_id(user_id) + token_data = jsonutils.dumps(token_id) + self.assertRaises( + exception.UnexpectedError, + self.token_api._update_user_list_with_cas, + user_key, token_data) |