summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--keystone/token/backends/memcache.py91
-rw-r--r--tests/test_backend_memcache.py84
2 files changed, 167 insertions, 8 deletions
diff --git a/keystone/token/backends/memcache.py b/keystone/token/backends/memcache.py
index a62f3421..e9d8482f 100644
--- a/keystone/token/backends/memcache.py
+++ b/keystone/token/backends/memcache.py
@@ -19,15 +19,20 @@ import copy
import memcache
+from keystone.common import logging
from keystone.common import utils
from keystone import config
from keystone import exception
from keystone.openstack.common import jsonutils
+from keystone.openstack.common import timeutils
from keystone import token
CONF = config.CONF
config.register_str('servers', group='memcache', default='localhost:11211')
+config.register_int('max_compare_and_set_retry', group='memcache', default=16)
+
+LOG = logging.getLogger(__name__)
class Token(token.Driver):
@@ -42,7 +47,13 @@ class Token(token.Driver):
def _get_memcache_client(self):
memcache_servers = CONF.memcache.servers.split(',')
- self._memcache_client = memcache.Client(memcache_servers, debug=0)
+ # NOTE(morganfainberg): The memcache client library for python is NOT
+ # thread safe and should not be passed between threads. This is highly
+ # specific to the cas() (compare and set) methods and the caching of
+ # the previous value(s). It appears greenthread should ensure there is
+ # a single data structure per spawned greenthread.
+ self._memcache_client = memcache.Client(memcache_servers, debug=0,
+ cache_cas=True)
return self._memcache_client
def _prefix_token_id(self, token_id):
@@ -77,13 +88,81 @@ class Token(token.Driver):
token_data = jsonutils.dumps(token_id)
user_id = data['user']['id']
user_key = self._prefix_user_id(user_id)
- if not self.client.append(user_key, ',%s' % token_data):
- if not self.client.add(user_key, token_data):
- if not self.client.append(user_key, ',%s' % token_data):
- msg = _('Unable to add token user list.')
- raise exception.UnexpectedError(msg)
+ # Append the new token_id to the token-index-list stored in the
+ # user-key within memcache.
+ self._update_user_list_with_cas(user_key, token_data)
return copy.deepcopy(data_copy)
+ def _update_user_list_with_cas(self, user_key, token_id):
+ cas_retry = 0
+ max_cas_retry = CONF.memcache.max_compare_and_set_retry
+ current_time = timeutils.normalize_time(
+ timeutils.parse_isotime(timeutils.isotime()))
+
+ self.client.reset_cas()
+
+ while cas_retry <= max_cas_retry:
+ # NOTE(morganfainberg): cas or "compare and set" is a function of
+ # memcache. It will return false if the value has changed since the
+ # last call to client.gets(). This is the memcache supported method
+ # of avoiding race conditions on set(). Memcache is already atomic
+ # on the back-end and serializes operations.
+ #
+ # cas_retry is for tracking our iterations before we give up (in
+ # case memcache is down or something horrible happens we don't
+ # iterate forever trying to compare and set the new value.
+ cas_retry += 1
+ record = self.client.gets(user_key)
+ filtered_list = []
+
+ if record is not None:
+ token_list = jsonutils.loads('[%s]' % record)
+ for token_i in token_list:
+ ptk = self._prefix_token_id(token.unique_id(token_i))
+ token_ref = self.client.get(ptk)
+ if not token_ref:
+ # skip tokens that do not exist in memcache
+ continue
+
+ if 'expires' in token_ref:
+ expires_at = timeutils.normalize_time(
+ token_ref['expires'])
+ if expires_at < current_time:
+ # skip tokens that are expired.
+ continue
+
+ # Add the still valid token_id to the list.
+ filtered_list.append(jsonutils.dumps(token_i))
+ # Add the new token_id to the list.
+ filtered_list.append(token_id)
+
+ # Use compare-and-set (cas) to set the new value for the
+ # token-index-list for the user-key. Cas is used to prevent race
+ # conditions from causing the loss of valid token ids from this
+ # list.
+ if self.client.cas(user_key, ','.join(filtered_list)):
+ msg = _('Successful set of token-index-list for user-key '
+ '"%(user_key)s", #%(count)d records')
+ LOG.debug(msg, {'user_key': user_key,
+ 'count': len(filtered_list)})
+ return filtered_list
+
+ # The cas function will return true if it succeeded or false if it
+ # failed for any reason, including memcache server being down, cas
+ # id changed since gets() called (the data changed between when
+ # this loop started and this point, etc.
+ error_msg = _('Failed to set token-index-list for user-key '
+ '"%(user_key)s". Attempt %(cas_retry)d of '
+ '%(cas_retry_max)d')
+ LOG.debug(error_msg,
+ {'user_key': user_key,
+ 'cas_retry': cas_retry,
+ 'cas_retry_max': max_cas_retry})
+
+ # Exceeded the maximum retry attempts.
+ error_msg = _('Unable to add token user list')
+ raise exception.UnexpectedError(error_msg)
+
def _add_to_revocation_list(self, data):
data_json = jsonutils.dumps(data)
if not self.client.append(self.revocation_key, ',%s' % data_json):
diff --git a/tests/test_backend_memcache.py b/tests/test_backend_memcache.py
index f5999002..5391d7f9 100644
--- a/tests/test_backend_memcache.py
+++ b/tests/test_backend_memcache.py
@@ -14,14 +14,18 @@
# License for the specific language governing permissions and limitations
# under the License.
+import copy
+import datetime
import uuid
import memcache
from keystone.common import utils
from keystone import exception
+from keystone.openstack.common import jsonutils
from keystone.openstack.common import timeutils
from keystone import test
+from keystone import token
from keystone.token.backends import memcache as token_memcache
import test_backend
@@ -33,6 +37,7 @@ class MemcacheClient(object):
def __init__(self, *args, **kwargs):
"""Ignores the passed in args."""
self.cache = {}
+ self.reject_cas = False
def add(self, key, value):
if self.get(key):
@@ -50,20 +55,50 @@ class MemcacheClient(object):
if not isinstance(key, str):
raise memcache.Client.MemcachedStringEncodingError()
+ def gets(self, key):
+ #Call self.get() since we don't really do 'cas' here.
+ return self.get(key)
+
def get(self, key):
"""Retrieves the value for a key or None."""
self.check_key(key)
obj = self.cache.get(key)
now = utils.unixtime(timeutils.utcnow())
if obj and (obj[1] == 0 or obj[1] > now):
- return obj[0]
+ # NOTE(morganfainberg): This behaves more like memcache
+ # actually does and prevents modification of the passed in
+ # reference from affecting the cached back-end data. This makes
+ # tests a little easier to write.
+ #
+ # The back-end store should only change with an explicit
+ # set/delete/append/etc
+ data_copy = copy.deepcopy(obj[0])
+ return data_copy
def set(self, key, value, time=0):
"""Sets the value for a key."""
self.check_key(key)
- self.cache[key] = (value, time)
+ # NOTE(morganfainberg): This behaves more like memcache
+ # actually does and prevents modification of the passed in
+ # reference from affecting the cached back-end data. This makes
+ # tests a little easier to write.
+ #
+ # The back-end store should only change with an explicit
+ # set/delete/append/etc
+ data_copy = copy.deepcopy(value)
+ self.cache[key] = (data_copy, time)
return True
+ def cas(self, key, value, time=0, min_compress_len=0):
+ # Call self.set() since we don't really do 'cas' here.
+ if self.reject_cas:
+ return False
+ return self.set(key, value, time=time)
+
+ def reset_cas(self):
+ #This is a stub for the memcache client reset_cas function.
+ pass
+
def delete(self, key):
self.check_key(key)
try:
@@ -101,3 +136,48 @@ class MemcacheToken(test.TestCase, test_backend.TokenTests):
def test_flush_expired_token(self):
with self.assertRaises(exception.NotImplemented):
self.token_api.flush_expired_tokens()
+
+ def test_cleanup_user_index_on_create(self):
+ valid_token_id = uuid.uuid4().hex
+ second_valid_token_id = uuid.uuid4().hex
+ expired_token_id = uuid.uuid4().hex
+ user_id = unicode(uuid.uuid4().hex)
+
+ expire_delta = datetime.timedelta(seconds=86400)
+
+ valid_data = {'id': valid_token_id, 'a': 'b',
+ 'user': {'id': user_id}}
+ second_valid_data = {'id': second_valid_token_id, 'a': 'b',
+ 'user': {'id': user_id}}
+ expired_data = {'id': expired_token_id, 'a': 'b',
+ 'user': {'id': user_id}}
+ self.token_api.create_token(valid_token_id, valid_data)
+ self.token_api.create_token(expired_token_id, expired_data)
+ # NOTE(morganfainberg): Directly access the data cache since we need to
+ # get expired tokens as well as valid tokens. token_api.list_tokens()
+ # will not return any expired tokens in the list.
+ user_key = self.token_api._prefix_user_id(user_id)
+ user_record = self.token_api.client.get(user_key)
+ user_token_list = jsonutils.loads('[%s]' % user_record)
+ self.assertEquals(len(user_token_list), 2)
+ expired_token_ptk = self.token_api._prefix_token_id(
+ token.unique_id(expired_token_id))
+ expired_token = self.token_api.client.get(expired_token_ptk)
+ expired_token['expires'] = (timeutils.utcnow() - expire_delta)
+ self.token_api.client.set(expired_token_ptk, expired_token)
+
+ self.token_api.create_token(second_valid_token_id, second_valid_data)
+ user_record = self.token_api.client.get(user_key)
+ user_token_list = jsonutils.loads('[%s]' % user_record)
+ self.assertEquals(len(user_token_list), 2)
+
+ def test_cas_failure(self):
+ self.token_api.client.reject_cas = True
+ token_id = uuid.uuid4().hex
+ user_id = unicode(uuid.uuid4().hex)
+ user_key = self.token_api._prefix_user_id(user_id)
+ token_data = jsonutils.dumps(token_id)
+ self.assertRaises(
+ exception.UnexpectedError,
+ self.token_api._update_user_list_with_cas,
+ user_key, token_data)