diff options
48 files changed, 1673 insertions, 674 deletions
diff --git a/.testr.conf b/.testr.conf new file mode 100644 index 0000000..d54ffb8 --- /dev/null +++ b/.testr.conf @@ -0,0 +1,9 @@ +[DEFAULT] +TESTS_PATH=./test +test_command=OS_STDOUT_CAPTURE=${OS_STDOUT_CAPTURE:-1} \ + OS_STDERR_CAPTURE=${OS_STDERR_CAPTURE:-1} \ + OS_TEST_TIMEOUT=${OS_TEST_TIMEOUT:-60} \ + ${PYTHON:-python} -m subunit.run discover -t ./ $TESTS_PATH $LISTOPT $IDOPTION + +test_id_option=--load-list $IDFILE +test_list_option=--list diff --git a/HACKING.rst b/HACKING.rst index 3cea316..846c1b1 100644 --- a/HACKING.rst +++ b/HACKING.rst @@ -46,6 +46,16 @@ General pass +- Use 'raise' instead of 'raise e' to preserve original traceback or exception being reraised:: + + except Exception as e: + ... + raise e # BAD + + except Exception: + ... + raise # OKAY + TODO vs FIXME ------------- diff --git a/MAINTAINERS b/MAINTAINERS index 128cb22..0500c53 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -199,12 +199,6 @@ M: Michael Still <mikal@stillhq.com> S: Maintained F: periodic_task.py -== plugins == - -M: -S: Orphan -F: plugin/ - == policy == M: diff --git a/TESTING.rst b/TESTING.rst new file mode 100644 index 0000000..4191b1b --- /dev/null +++ b/TESTING.rst @@ -0,0 +1,88 @@ +=========================== +Testing Your OpenStack Code +=========================== +------------ +A Quickstart +------------ + +This is designed to be enough information for you to run your first tests. +Detailed information on testing can be found here: https://wiki.openstack.org/wiki/Testing + +*Install pip*:: + + [apt-get | yum] install python-pip +More information on pip here: http://www.pip-installer.org/en/latest/ + +*Use pip to install tox*:: + + pip install tox + +Run The Tests +------------- + +*Navigate to the project's root directory and execute*:: + + tox +Note: completing this command may take a long time (depends on system resources) +also, you might not see any output until tox is complete. + +Information about tox can be found here: http://testrun.org/tox/latest/ + + +Run The Tests in One Environment +-------------------------------- + +Tox will run your entire test suite in the environments specified in the project tox.ini:: + + [tox] + + envlist = <list of available environments> + +To run the test suite in just one of the environments in envlist execute:: + + tox -e <env> +so for example, *run the test suite in py26*:: + + tox -e py26 + +Run One Test +------------ + +To run individual tests with tox: + +if testr is in tox.ini, for example:: + + [testenv] + + includes "python setup.py testr --slowest --testr-args='{posargs}'" + +run individual tests with the following syntax:: + + tox -e <env> -- path.to.module:Class.test +so for example, *run the cpu_limited test in Nova*:: + + tox -e py27 -- nova.tests.test_claims:ClaimTestCase.test_cpu_unlimited + +if nose is in tox.ini, for example:: + + [testenv] + + includes "nosetests {posargs}" + +run individual tests with the following syntax:: + + tox -e <env> -- --tests path.to.module:Class.test +so for example, *run the list test in Glance*:: + + tox -e py27 -- --tests glance.tests.unit.test_auth.py:TestImageRepoProxy.test_list + +Need More Info? +--------------- + +More information about testr: https://wiki.openstack.org/wiki/Testr + +More information about nose: https://nose.readthedocs.org/en/latest/ + + +More information about testing OpenStack code can be found here: +https://wiki.openstack.org/wiki/Testing diff --git a/openstack/common/config/generator.py b/openstack/common/config/generator.py index 09649e7..0dd7c97 100755 --- a/openstack/common/config/generator.py +++ b/openstack/common/config/generator.py @@ -188,7 +188,12 @@ def _get_my_ip(): def _sanitize_default(s): """Set up a reasonably sensible default for pybasedir, my_ip and host.""" - if s.startswith(BASEDIR): + if s.startswith(sys.prefix): + # NOTE(jd) Don't use os.path.join, because it is likely to think the + # second part is an absolute pathname and therefore drop the first + # part. + s = os.path.normpath("/usr/" + s[len(sys.prefix):]) + elif s.startswith(BASEDIR): return s.replace(BASEDIR, '/usr/lib/python/site-packages') elif BASEDIR in s: return s.replace(BASEDIR, '') @@ -205,6 +210,7 @@ def _print_opt(opt): opt_name, opt_default, opt_help = opt.dest, opt.default, opt.help if not opt_help: sys.stderr.write('WARNING: "%s" is missing help string.\n' % opt_name) + opt_help = "" opt_type = None try: opt_type = OPTION_REGEX.search(str(type(opt))).group(0) diff --git a/openstack/common/context.py b/openstack/common/context.py index 3899c2c..81772bc 100644 --- a/openstack/common/context.py +++ b/openstack/common/context.py @@ -61,7 +61,7 @@ class RequestContext(object): 'request_id': self.request_id} -def get_admin_context(show_deleted="no"): +def get_admin_context(show_deleted=False): context = RequestContext(None, tenant=None, is_admin=True, diff --git a/openstack/common/crypto/__init__.py b/openstack/common/crypto/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/openstack/common/crypto/__init__.py diff --git a/openstack/common/crypto/utils.py b/openstack/common/crypto/utils.py new file mode 100644 index 0000000..61c1a50 --- /dev/null +++ b/openstack/common/crypto/utils.py @@ -0,0 +1,179 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 + +from Crypto.Hash import HMAC +from Crypto import Random + +from openstack.common.gettextutils import _ +from openstack.common.importutils import import_module + + +class CryptoutilsException(Exception): + """Generic Exception for Crypto utilities.""" + + message = _("An unknown error occurred in crypto utils.") + + +class CipherBlockLengthTooBig(CryptoutilsException): + """The block size is too big.""" + + def __init__(self, requested, permitted): + msg = _("Block size of %(given)d is too big, max = %(maximum)d") + message = msg % {'given': requested, 'maximum': permitted} + super(CryptoutilsException, self).__init__(message) + + +class HKDFOutputLengthTooLong(CryptoutilsException): + """The amount of Key Material asked is too much.""" + + def __init__(self, requested, permitted): + msg = _("Length of %(given)d is too long, max = %(maximum)d") + message = msg % {'given': requested, 'maximum': permitted} + super(CryptoutilsException, self).__init__(message) + + +class HKDF(object): + """An HMAC-based Key Derivation Function implementation (RFC5869) + + This class creates an object that allows to use HKDF to derive keys. + """ + + def __init__(self, hashtype='SHA256'): + self.hashfn = import_module('Crypto.Hash.' + hashtype) + self.max_okm_length = 255 * self.hashfn.digest_size + + def extract(self, ikm, salt=None): + """An extract function that can be used to derive a robust key given + weak Input Key Material (IKM) which could be a password. + Returns a pseudorandom key (of HashLen octets) + + :param ikm: input keying material (ex a password) + :param salt: optional salt value (a non-secret random value) + """ + if salt is None: + salt = '\x00' * self.hashfn.digest_size + + return HMAC.new(salt, ikm, self.hashfn).digest() + + def expand(self, prk, info, length): + """An expand function that will return arbitrary length output that can + be used as keys. + Returns a buffer usable as key material. + + :param prk: a pseudorandom key of at least HashLen octets + :param info: optional string (can be a zero-length string) + :param length: length of output keying material (<= 255 * HashLen) + """ + if length > self.max_okm_length: + raise HKDFOutputLengthTooLong(length, self.max_okm_length) + + N = (length + self.hashfn.digest_size - 1) / self.hashfn.digest_size + + okm = "" + tmp = "" + for block in range(1, N + 1): + tmp = HMAC.new(prk, tmp + info + chr(block), self.hashfn).digest() + okm += tmp + + return okm[:length] + + +MAX_CB_SIZE = 256 + + +class SymmetricCrypto(object): + """Symmetric Key Crypto object. + + This class creates a Symmetric Key Crypto object that can be used + to encrypt, decrypt, or sign arbitrary data. + + :param enctype: Encryption Cipher name (default: AES) + :param hashtype: Hash/HMAC type name (default: SHA256) + """ + + def __init__(self, enctype='AES', hashtype='SHA256'): + self.cipher = import_module('Crypto.Cipher.' + enctype) + self.hashfn = import_module('Crypto.Hash.' + hashtype) + + def new_key(self, size): + return Random.new().read(size) + + def encrypt(self, key, msg, b64encode=True): + """Encrypt the provided msg and returns the cyphertext optionally + base64 encoded. + + Uses AES-128-CBC with a Random IV by default. + + The plaintext is padded to reach blocksize length. + The last byte of the block is the length of the padding. + The length of the padding does not include the length byte itself. + + :param key: The Encryption key. + :param msg: the plain text. + + :returns encblock: a block of encrypted data. + """ + iv = Random.new().read(self.cipher.block_size) + cipher = self.cipher.new(key, self.cipher.MODE_CBC, iv) + + # CBC mode requires a fixed block size. Append padding and length of + # padding. + if self.cipher.block_size > MAX_CB_SIZE: + raise CipherBlockLengthTooBig(self.cipher.block_size, MAX_CB_SIZE) + r = len(msg) % self.cipher.block_size + padlen = self.cipher.block_size - r - 1 + msg += '\x00' * padlen + msg += chr(padlen) + + enc = iv + cipher.encrypt(msg) + if b64encode: + enc = base64.b64encode(enc) + return enc + + def decrypt(self, key, msg, b64decode=True): + """Decrypts the provided ciphertext, optionally base 64 encoded, and + returns the plaintext message, after padding is removed. + + Uses AES-128-CBC with an IV by default. + + :param key: The Encryption key. + :param msg: the ciphetext, the first block is the IV + """ + if b64decode: + msg = base64.b64decode(msg) + iv = msg[:self.cipher.block_size] + cipher = self.cipher.new(key, self.cipher.MODE_CBC, iv) + + padded = cipher.decrypt(msg[self.cipher.block_size:]) + l = ord(padded[-1]) + 1 + plain = padded[:-l] + return plain + + def sign(self, key, msg, b64encode=True): + """Signs a message string and returns a base64 encoded signature. + + Uses HMAC-SHA-256 by default. + + :param key: The Signing key. + :param msg: the message to sign. + """ + h = HMAC.new(key, msg, self.hashfn) + out = h.digest() + if b64encode: + out = base64.b64encode(out) + return out diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py index b5e10f1..7400b17 100644 --- a/openstack/common/db/sqlalchemy/session.py +++ b/openstack/common/db/sqlalchemy/session.py @@ -260,8 +260,6 @@ from openstack.common.gettextutils import _ from openstack.common import log as logging from openstack.common import timeutils -DEFAULT = 'DEFAULT' - sqlite_db_opts = [ cfg.StrOpt('sqlite_db', default='oslo.sqlite', @@ -278,8 +276,10 @@ database_opts = [ '../', '$sqlite_db')), help='The SQLAlchemy connection string used to connect to the ' 'database', - deprecated_name='sql_connection', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE')], secret=True), cfg.StrOpt('slave_connection', default='', @@ -288,56 +288,71 @@ database_opts = [ secret=True), cfg.IntOpt('idle_timeout', default=3600, - deprecated_name='sql_idle_timeout', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE')], help='timeout before idle sql connections are reaped'), cfg.IntOpt('min_pool_size', default=1, - deprecated_name='sql_min_pool_size', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], help='Minimum number of SQL connections to keep open in a ' 'pool'), cfg.IntOpt('max_pool_size', default=None, - deprecated_name='sql_max_pool_size', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], help='Maximum number of SQL connections to keep open in a ' 'pool'), cfg.IntOpt('max_retries', default=10, - deprecated_name='sql_max_retries', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + group='DATABASE')], help='maximum db connection retries during startup. ' '(setting -1 implies an infinite retry count)'), cfg.IntOpt('retry_interval', default=10, - deprecated_name='sql_retry_interval', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], help='interval between retries of opening a sql connection'), cfg.IntOpt('max_overflow', default=None, - deprecated_name='sql_max_overflow', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], help='If set, use this value for max_overflow with sqlalchemy'), cfg.IntOpt('connection_debug', default=0, - deprecated_name='sql_connection_debug', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], help='Verbosity of SQL debugging information. 0=None, ' '100=Everything'), cfg.BoolOpt('connection_trace', default=False, - deprecated_name='sql_connection_trace', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], help='Add python stack traces to SQL as comment strings'), cfg.IntOpt('pool_timeout', default=None, + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], help='If set, use this value for pool_timeout with sqlalchemy'), ] CONF = cfg.CONF CONF.register_opts(sqlite_db_opts) CONF.register_opts(database_opts, 'database') + LOG = logging.getLogger(__name__) _ENGINE = None diff --git a/openstack/common/eventlet_backdoor.py b/openstack/common/eventlet_backdoor.py index 57b89ae..f2102d6 100644 --- a/openstack/common/eventlet_backdoor.py +++ b/openstack/common/eventlet_backdoor.py @@ -18,8 +18,11 @@ from __future__ import print_function +import errno import gc +import os import pprint +import socket import sys import traceback @@ -28,14 +31,34 @@ import eventlet.backdoor import greenlet from oslo.config import cfg +from openstack.common.gettextutils import _ +from openstack.common import log as logging + +help_for_backdoor_port = 'Acceptable ' + \ + 'values are 0, <port> and <start>:<end>, where 0 results in ' + \ + 'listening on a random tcp port number, <port> results in ' + \ + 'listening on the specified port number and not enabling backdoor' + \ + 'if it is in use and <start>:<end> results in listening on the ' + \ + 'smallest unused port number within the specified range of port ' + \ + 'numbers. The chosen port is displayed in the service\'s log file.' eventlet_backdoor_opts = [ - cfg.IntOpt('backdoor_port', + cfg.StrOpt('backdoor_port', default=None, - help='port for eventlet backdoor to listen') + help='Enable eventlet backdoor. %s' % help_for_backdoor_port) ] CONF = cfg.CONF CONF.register_opts(eventlet_backdoor_opts) +LOG = logging.getLogger(__name__) + + +class EventletBackdoorConfigValueError(Exception): + def __init__(self, port_range, help_msg, ex): + msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. ' + '%(help)s' % + {'range': port_range, 'ex': ex, 'help': help_msg}) + super(EventletBackdoorConfigValueError, self).__init__(msg) + self.port_range = port_range def _dont_use_this(): @@ -60,6 +83,33 @@ def _print_nativethreads(): print() +def _parse_port_range(port_range): + if ':' not in port_range: + start, end = port_range, port_range + else: + start, end = port_range.split(':', 1) + try: + start, end = int(start), int(end) + if end < start: + raise ValueError + return start, end + except ValueError as ex: + raise EventletBackdoorConfigValueError(port_range, ex, + help_for_backdoor_port) + + +def _listen(host, start_port, end_port, listen_func): + try_port = start_port + while True: + try: + return listen_func((host, try_port)) + except socket.error as exc: + if (exc.errno != errno.EADDRINUSE or + try_port >= end_port): + raise + try_port += 1 + + def initialize_if_enabled(): backdoor_locals = { 'exit': _dont_use_this, # So we don't exit the entire process @@ -72,6 +122,8 @@ def initialize_if_enabled(): if CONF.backdoor_port is None: return None + start_port, end_port = _parse_port_range(str(CONF.backdoor_port)) + # NOTE(johannes): The standard sys.displayhook will print the value of # the last expression and set it to __builtin__._, which overwrites # the __builtin__._ that gettext sets. Let's switch to using pprint @@ -82,8 +134,13 @@ def initialize_if_enabled(): pprint.pprint(val) sys.displayhook = displayhook - sock = eventlet.listen(('localhost', CONF.backdoor_port)) + sock = _listen('localhost', start_port, end_port, eventlet.listen) + + # In the case of backdoor port being zero, a port number is assigned by + # listen(). In any case, pull the port number out here. port = sock.getsockname()[1] + LOG.info(_('Eventlet backdoor listening on %(port)s for process %(pid)d') % + {'port': port, 'pid': os.getpid()}) eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, locals=backdoor_locals) return port diff --git a/openstack/common/exception.py b/openstack/common/exception.py index cdf40f3..f6c8463 100644 --- a/openstack/common/exception.py +++ b/openstack/common/exception.py @@ -122,9 +122,9 @@ class OpenstackException(Exception): try: self._error_string = self.message % kwargs - except Exception as e: + except Exception: if _FATAL_EXCEPTION_FORMAT_ERRORS: - raise e + raise else: # at least get the core message out if something happened self._error_string = self.message diff --git a/openstack/common/excutils.py b/openstack/common/excutils.py index 06d6e29..336e147 100644 --- a/openstack/common/excutils.py +++ b/openstack/common/excutils.py @@ -19,16 +19,15 @@ Exception related utilities. """ -import contextlib import logging import sys +import time import traceback from openstack.common.gettextutils import _ -@contextlib.contextmanager -def save_and_reraise_exception(): +class save_and_reraise_exception(object): """Save current exception, run some code and then re-raise. In some cases the exception context can be cleared, resulting in None @@ -40,12 +39,60 @@ def save_and_reraise_exception(): To work around this, we save the exception state, run handler code, and then re-raise the original exception. If another exception occurs, the saved exception is logged and the new exception is re-raised. - """ - type_, value, tb = sys.exc_info() - try: - yield + + In some cases the caller may not want to re-raise the exception, and + for those circumstances this context provides a reraise flag that + can be used to suppress the exception. For example: + except Exception: - logging.error(_('Original exception being dropped: %s'), - traceback.format_exception(type_, value, tb)) - raise - raise type_, value, tb + with save_and_reraise_exception() as ctxt: + decide_if_need_reraise() + if not should_be_reraised: + ctxt.reraise = False + """ + def __init__(self): + self.reraise = True + + def __enter__(self): + self.type_, self.value, self.tb, = sys.exc_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + logging.error(_('Original exception being dropped: %s'), + traceback.format_exception(self.type_, + self.value, + self.tb)) + return False + if self.reraise: + raise self.type_, self.value, self.tb + + +def forever_retry_uncaught_exceptions(infunc): + def inner_func(*args, **kwargs): + last_log_time = 0 + last_exc_message = None + exc_count = 0 + while True: + try: + return infunc(*args, **kwargs) + except Exception as exc: + if exc.message == last_exc_message: + exc_count += 1 + else: + exc_count = 1 + # Do not log any more frequently than once a minute unless + # the exception message changes + cur_time = int(time.time()) + if (cur_time - last_log_time > 60 or + exc.message != last_exc_message): + logging.exception( + _('Unexpected exception occurred %d time(s)... ' + 'retrying.') % exc_count) + last_log_time = cur_time + last_exc_message = exc.message + exc_count = 0 + # This should be a very rare event. In case it isn't, do + # a sleep. + time.sleep(1) + return inner_func diff --git a/openstack/common/jsonutils.py b/openstack/common/jsonutils.py index bf23403..9c72376 100644 --- a/openstack/common/jsonutils.py +++ b/openstack/common/jsonutils.py @@ -41,6 +41,7 @@ import json import types import xmlrpclib +import netaddr import six from openstack.common import timeutils @@ -137,6 +138,8 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, # Likely an instance of something. Watch for cycles. # Ignore class member vars. return recursive(value.__dict__, level=level + 1) + elif isinstance(value, netaddr.IPAddress): + return six.text_type(value) else: if any(test(value) for test in _nasty_type_tests): return six.text_type(value) diff --git a/openstack/common/network_utils.py b/openstack/common/network_utils.py index 0fbf171..dbed1ce 100644 --- a/openstack/common/network_utils.py +++ b/openstack/common/network_utils.py @@ -19,6 +19,8 @@ Network-related utilities and helper functions. """ +import urlparse + def parse_host_port(address, default_port=None): """Interpret a string as a host:port pair. @@ -62,3 +64,18 @@ def parse_host_port(address, default_port=None): port = default_port return (host, None if port is None else int(port)) + + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL using urlparse.urlsplit(), splitting query and fragments. + This function papers over Python issue9374 when needed. + + The parameters are the same as urlparse.urlsplit. + """ + scheme, netloc, path, query, fragment = urlparse.urlsplit( + url, scheme, allow_fragments) + if allow_fragments and '#' in path: + path, fragment = path.split('#', 1) + if '?' in path: + path, query = path.split('?', 1) + return urlparse.SplitResult(scheme, netloc, path, query, fragment) diff --git a/openstack/common/plugin/__init__.py b/openstack/common/plugin/__init__.py deleted file mode 100644 index b706747..0000000 --- a/openstack/common/plugin/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. diff --git a/openstack/common/plugin/callbackplugin.py b/openstack/common/plugin/callbackplugin.py deleted file mode 100644 index 2de7fb0..0000000 --- a/openstack/common/plugin/callbackplugin.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from openstack.common import log as logging -from openstack.common.plugin import plugin - - -LOG = logging.getLogger(__name__) - - -class _CallbackNotifier(object): - """Manages plugin-defined notification callbacks. - - For each Plugin, a CallbackNotifier will be added to the - notification driver list. Calls to notify() with appropriate - messages will be hooked and prompt callbacks. - - A callback should look like this: - def callback(context, message, user_data) - """ - - def __init__(self): - self._callback_dict = {} - - def _add_callback(self, event_type, callback, user_data): - callback_list = self._callback_dict.get(event_type, []) - callback_list.append({'function': callback, - 'user_data': user_data}) - self._callback_dict[event_type] = callback_list - - def _remove_callback(self, callback): - for callback_list in self._callback_dict.values(): - for entry in callback_list: - if entry['function'] == callback: - callback_list.remove(entry) - - def notify(self, context, message): - if message.get('event_type') not in self._callback_dict: - return - - for entry in self._callback_dict[message.get('event_type')]: - entry['function'](context, message, entry.get('user_data')) - - def callbacks(self): - return self._callback_dict - - -class CallbackPlugin(plugin.Plugin): - """Plugin with a simple callback interface. - - This class is provided as a convenience for producing a simple - plugin that only watches a couple of events. For example, here's - a subclass which prints a line the first time an instance is created. - - class HookInstanceCreation(CallbackPlugin): - - def __init__(self, _service_name): - super(HookInstanceCreation, self).__init__() - self._add_callback(self.magic, 'compute.instance.create.start') - - def magic(self): - print "An instance was created!" - self._remove_callback(self, self.magic) - """ - - def __init__(self, service_name): - super(CallbackPlugin, self).__init__(service_name) - self._callback_notifier = _CallbackNotifier() - self._add_notifier(self._callback_notifier) - - def _add_callback(self, callback, event_type, user_data=None): - """Add callback for a given event notification. - - Subclasses can call this as an alternative to implementing - a fullblown notify notifier. - """ - self._callback_notifier._add_callback(event_type, callback, user_data) - - def _remove_callback(self, callback): - """Remove all notification callbacks to specified function.""" - self._callback_notifier._remove_callback(callback) diff --git a/openstack/common/plugin/plugin.py b/openstack/common/plugin/plugin.py deleted file mode 100644 index d2be0b3..0000000 --- a/openstack/common/plugin/plugin.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from openstack.common import log as logging - - -LOG = logging.getLogger(__name__) - - -class Plugin(object): - """Defines an interface for adding functionality to an OpenStack service. - - A plugin interacts with a service via the following pathways: - - - An optional set of notifiers, managed by calling add_notifier() - or by overriding _notifiers() - - - A set of api extensions, managed via add_api_extension_descriptor() - - - Direct calls to service functions. - - - Whatever else the plugin wants to do on its own. - - This is the reference implementation. - """ - - # The following functions are provided as convenience methods - # for subclasses. Subclasses should call them but probably not - # override them. - def _add_api_extension_descriptor(self, descriptor): - """Subclass convenience method which adds an extension descriptor. - - Subclass constructors should call this method when - extending a project's REST interface. - - Note that once the api service has loaded, the - API extension set is more-or-less fixed, so - this should mainly be called by subclass constructors. - """ - self._api_extension_descriptors.append(descriptor) - - def _add_notifier(self, notifier): - """Subclass convenience method which adds a notifier. - - Notifier objects should implement the function notify(message). - Each notifier receives a notify() call whenever an openstack - service broadcasts a notification. - - Best to call this during construction. Notifiers are enumerated - and registered by the pluginmanager at plugin load time. - """ - self._notifiers.append(notifier) - - # The following methods are called by OpenStack services to query - # plugin features. Subclasses should probably not override these. - def _notifiers(self): - """Returns list of notifiers for this plugin.""" - return self._notifiers - - notifiers = property(_notifiers) - - def _api_extension_descriptors(self): - """Return a list of API extension descriptors. - - Called by a project API during its load sequence. - """ - return self._api_extension_descriptors - - api_extension_descriptors = property(_api_extension_descriptors) - - # Most plugins will override this: - def __init__(self, service_name): - self._notifiers = [] - self._api_extension_descriptors = [] diff --git a/openstack/common/plugin/pluginmanager.py b/openstack/common/plugin/pluginmanager.py deleted file mode 100644 index 3962447..0000000 --- a/openstack/common/plugin/pluginmanager.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import pkg_resources - -from oslo.config import cfg - -from openstack.common.gettextutils import _ -from openstack.common import log as logging -from openstack.common.notifier import api as notifier_api - - -CONF = cfg.CONF -LOG = logging.getLogger(__name__) - - -class PluginManager(object): - """Manages plugin entrypoints and loading. - - For a service to implement this plugin interface for callback purposes: - - - Make use of the openstack-common notifier system - - Instantiate this manager in each process (passing in - project and service name) - - For an API service to extend itself using this plugin interface, - it needs to query the plugin_extension_factory provided by - the already-instantiated PluginManager. - """ - - def __init__(self, project_name, service_name): - """Construct Plugin Manager; load and initialize plugins. - - project_name (e.g. 'nova' or 'glance') is used - to construct the entry point that identifies plugins. - - The service_name (e.g. 'compute') is passed on to - each plugin as a raw string for it to do what it will. - """ - self._project_name = project_name - self._service_name = service_name - self.plugins = [] - - def load_plugins(self): - self.plugins = [] - - for entrypoint in pkg_resources.iter_entry_points('%s.plugin' % - self._project_name): - try: - pluginclass = entrypoint.load() - plugin = pluginclass(self._service_name) - self.plugins.append(plugin) - except Exception as exc: - LOG.error(_("Failed to load plugin %(plug)s: %(exc)s") % - {'plug': entrypoint, 'exc': exc}) - - # Register individual notifiers. - for plugin in self.plugins: - for notifier in plugin.notifiers: - notifier_api.add_driver(notifier) - - def plugin_extension_factory(self, ext_mgr): - for plugin in self.plugins: - descriptors = plugin.api_extension_descriptors - for descriptor in descriptors: - ext_mgr.load_extension(descriptor) diff --git a/openstack/common/rootwrap/filters.py b/openstack/common/rootwrap/filters.py index 0cc55ce..b40fdfd 100644 --- a/openstack/common/rootwrap/filters.py +++ b/openstack/common/rootwrap/filters.py @@ -217,7 +217,8 @@ class KillFilter(CommandFilter): return (os.path.isabs(command) and kill_command == os.path.basename(command) and - os.path.dirname(command) in os.environ['PATH'].split(':')) + os.path.dirname(command) in os.environ.get('PATH', '' + ).split(':')) class ReadFileFilter(CommandFilter): @@ -235,3 +236,116 @@ class ReadFileFilter(CommandFilter): if len(userargs) != 2: return False return True + + +class IpFilter(CommandFilter): + """Specific filter for the ip utility to that does not match exec.""" + + def match(self, userargs): + if userargs[0] == 'ip': + if userargs[1] == 'netns': + return (userargs[2] in ('list', 'add', 'delete')) + else: + return True + + +class EnvFilter(CommandFilter): + """Specific filter for the env utility. + + Behaves like CommandFilter, except that it handles + leading env A=B.. strings appropriately. + """ + + def _extract_env(self, arglist): + """Extract all leading NAME=VALUE arguments from arglist.""" + + envs = set() + for arg in arglist: + if '=' not in arg: + break + envs.add(arg.partition('=')[0]) + return envs + + def __init__(self, exec_path, run_as, *args): + super(EnvFilter, self).__init__(exec_path, run_as, *args) + + env_list = self._extract_env(self.args) + # Set exec_path to X when args are in the form of + # env A=a B=b C=c X Y Z + if "env" in exec_path and len(env_list) < len(self.args): + self.exec_path = self.args[len(env_list)] + + def match(self, userargs): + # ignore leading 'env' + if userargs[0] == 'env': + userargs.pop(0) + + # require one additional argument after configured ones + if len(userargs) < len(self.args): + return False + + # extract all env args + user_envs = self._extract_env(userargs) + filter_envs = self._extract_env(self.args) + user_command = userargs[len(user_envs):len(user_envs) + 1] + + # match first non-env argument with CommandFilter + return (super(EnvFilter, self).match(user_command) + and len(filter_envs) and user_envs == filter_envs) + + def exec_args(self, userargs): + args = userargs[:] + + # ignore leading 'env' + if args[0] == 'env': + args.pop(0) + + # Throw away leading NAME=VALUE arguments + while args and '=' in args[0]: + args.pop(0) + + return args + + def get_command(self, userargs, exec_dirs=[]): + to_exec = self.get_exec(exec_dirs=exec_dirs) or self.exec_path + return [to_exec] + self.exec_args(userargs)[1:] + + def get_environment(self, userargs): + env = os.environ.copy() + + # ignore leading 'env' + if userargs[0] == 'env': + userargs.pop(0) + + # Handle leading NAME=VALUE pairs + for a in userargs: + env_name, equals, env_value = a.partition('=') + if not equals: + break + if env_name and env_value: + env[env_name] = env_value + + return env + + +class ChainingFilter(CommandFilter): + def exec_args(self, userargs): + return [] + + +class IpNetnsExecFilter(ChainingFilter): + """Specific filter for the ip utility to that does match exec.""" + + def match(self, userargs): + # Network namespaces currently require root + # require <ns> argument + if self.run_as != "root" or len(userargs) < 4: + return False + + return (userargs[:3] == ['ip', 'netns', 'exec']) + + def exec_args(self, userargs): + args = userargs[4:] + if args: + args[0] = os.path.basename(args[0]) + return args diff --git a/openstack/common/rootwrap/wrapper.py b/openstack/common/rootwrap/wrapper.py index 5390c1b..6bd829e 100644 --- a/openstack/common/rootwrap/wrapper.py +++ b/openstack/common/rootwrap/wrapper.py @@ -46,8 +46,10 @@ class RootwrapConfig(object): if config.has_option("DEFAULT", "exec_dirs"): self.exec_dirs = config.get("DEFAULT", "exec_dirs").split(",") else: + self.exec_dirs = [] # Use system PATH if exec_dirs is not specified - self.exec_dirs = os.environ["PATH"].split(':') + if "PATH" in os.environ: + self.exec_dirs = os.environ['PATH'].split(':') # syslog_log_facility if config.has_option("DEFAULT", "syslog_log_facility"): @@ -131,6 +133,20 @@ def match_filter(filter_list, userargs, exec_dirs=[]): for f in filter_list: if f.match(userargs): + if isinstance(f, filters.ChainingFilter): + # This command calls exec verify that remaining args + # matches another filter. + def non_chain_filter(fltr): + return (fltr.run_as == f.run_as + and not isinstance(fltr, filters.ChainingFilter)) + + leaf_filters = [fltr for fltr in filter_list + if non_chain_filter(fltr)] + args = f.exec_args(userargs) + if (not args or not match_filter(leaf_filters, + args, exec_dirs=exec_dirs)): + continue + # Try other filters if executable is absent if not f.get_exec(exec_dirs=exec_dirs): if not first_not_executable_filter: diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py index 22e01d7..c3e4e26 100644 --- a/openstack/common/rpc/amqp.py +++ b/openstack/common/rpc/amqp.py @@ -151,11 +151,13 @@ class ConnectionContext(rpc_common.Connection): def create_worker(self, topic, proxy, pool_name): self.connection.create_worker(topic, proxy, pool_name) - def join_consumer_pool(self, callback, pool_name, topic, exchange_name): + def join_consumer_pool(self, callback, pool_name, topic, exchange_name, + ack_on_error=True): self.connection.join_consumer_pool(callback, pool_name, topic, - exchange_name) + exchange_name, + ack_on_error) def consume_in_thread(self): self.connection.consume_in_thread() @@ -219,12 +221,7 @@ def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, failure = rpc_common.serialize_remote_exception(failure, log_failure) - try: - msg = {'result': reply, 'failure': failure} - except TypeError: - msg = {'result': dict((k, repr(v)) - for k, v in reply.__dict__.iteritems()), - 'failure': failure} + msg = {'result': reply, 'failure': failure} if ending: msg['ending'] = True _add_unique_id(msg) diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py index c062d9a..36d2fc5 100644 --- a/openstack/common/rpc/impl_kombu.py +++ b/openstack/common/rpc/impl_kombu.py @@ -18,7 +18,6 @@ import functools import itertools import socket import ssl -import sys import time import uuid @@ -30,6 +29,7 @@ import kombu.entity import kombu.messaging from oslo.config import cfg +from openstack.common import excutils from openstack.common.gettextutils import _ from openstack.common import network_utils from openstack.common.rpc import amqp as rpc_amqp @@ -129,6 +129,7 @@ class ConsumerBase(object): self.tag = str(tag) self.kwargs = kwargs self.queue = None + self.ack_on_error = kwargs.get('ack_on_error', True) self.reconnect(channel) def reconnect(self, channel): @@ -138,6 +139,36 @@ class ConsumerBase(object): self.queue = kombu.entity.Queue(**self.kwargs) self.queue.declare() + def _callback_handler(self, message, callback): + """Call callback with deserialized message. + + Messages that are processed without exception are ack'ed. + + If the message processing generates an exception, it will be + ack'ed if ack_on_error=True. Otherwise it will be .reject()'ed. + Rejection is better than waiting for the message to timeout. + Rejected messages are immediately requeued. + """ + + ack_msg = False + try: + msg = rpc_common.deserialize_msg(message.payload) + callback(msg) + ack_msg = True + except Exception: + if self.ack_on_error: + ack_msg = True + LOG.exception(_("Failed to process message" + " ... skipping it.")) + else: + LOG.exception(_("Failed to process message" + " ... will requeue.")) + finally: + if ack_msg: + message.ack() + else: + message.reject() + def consume(self, *args, **kwargs): """Actually declare the consumer on the amqp channel. This will start the flow of messages from the queue. Using the @@ -150,8 +181,6 @@ class ConsumerBase(object): If kwargs['nowait'] is True, then this call will block until a message is read. - Messages will automatically be acked if the callback doesn't - raise an exception """ options = {'consumer_tag': self.tag} @@ -162,13 +191,7 @@ class ConsumerBase(object): def _callback(raw_message): message = self.channel.message_to_python(raw_message) - try: - msg = rpc_common.deserialize_msg(message.payload) - callback(msg) - except Exception: - LOG.exception(_("Failed to process message... skipping it.")) - finally: - message.ack() + self._callback_handler(message, callback) self.queue.consume(*args, callback=_callback, **options) @@ -537,13 +560,11 @@ class Connection(object): log_info.update(params) if self.max_retries and attempt == self.max_retries: - LOG.error(_('Unable to connect to AMQP server on ' - '%(hostname)s:%(port)d after %(max_retries)d ' - 'tries: %(err_str)s') % log_info) - # NOTE(comstud): Copied from original code. There's - # really no better recourse because if this was a queue we - # need to consume on, we have no way to consume anymore. - sys.exit(1) + msg = _('Unable to connect to AMQP server on ' + '%(hostname)s:%(port)d after %(max_retries)d ' + 'tries: %(err_str)s') % log_info + LOG.error(msg) + raise rpc_common.RPCException(msg) if attempt == 1: sleep_time = self.interval_start or 1 @@ -635,8 +656,8 @@ class Connection(object): def _consume(): if info['do_consume']: - queues_head = self.consumers[:-1] - queues_tail = self.consumers[-1] + queues_head = self.consumers[:-1] # not fanout. + queues_tail = self.consumers[-1] # fanout for queue in queues_head: queue.consume(nowait=True) queues_tail.consume(nowait=False) @@ -685,11 +706,12 @@ class Connection(object): self.declare_consumer(DirectConsumer, topic, callback) def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, exchange_name=exchange_name, + ack_on_error=ack_on_error, ), topic, callback) @@ -724,6 +746,7 @@ class Connection(object): def consume_in_thread(self): """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions def _consumer_thread(): try: self.consume() @@ -754,7 +777,7 @@ class Connection(object): self.declare_topic_consumer(topic, proxy_cb, pool_name) def join_consumer_pool(self, callback, pool_name, topic, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Register as a member of a group of consumers for a given topic from the specified exchange. @@ -775,6 +798,7 @@ class Connection(object): topic=topic, exchange_name=exchange_name, callback=callback_wrapper, + ack_on_error=ack_on_error, ) diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py index 7352517..c988ae8 100644 --- a/openstack/common/rpc/impl_qpid.py +++ b/openstack/common/rpc/impl_qpid.py @@ -24,6 +24,7 @@ import eventlet import greenlet from oslo.config import cfg +from openstack.common import excutils from openstack.common.gettextutils import _ from openstack.common import importutils from openstack.common import jsonutils @@ -118,10 +119,17 @@ class ConsumerBase(object): self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) - self.reconnect(session) + self.connect(session) + + def connect(self, session): + """Declare the reciever on connect.""" + self._declare_receiver(session) def reconnect(self, session): """Re-declare the receiver after a qpid reconnect.""" + self._declare_receiver(session) + + def _declare_receiver(self, session): self.session = session self.receiver = session.receiver(self.address) self.receiver.capacity = 1 @@ -152,11 +160,15 @@ class ConsumerBase(object): except Exception: LOG.exception(_("Failed to process message... skipping it.")) finally: + # TODO(sandy): Need support for optional ack_on_error. self.session.acknowledge(message) def get_receiver(self): return self.receiver + def get_node_name(self): + return self.address.split(';')[0] + class DirectConsumer(ConsumerBase): """Queue/consumer class for 'direct'.""" @@ -206,6 +218,7 @@ class FanoutConsumer(ConsumerBase): 'topic' is the topic to listen on 'callback' is the callback to call when messages are received """ + self.conf = conf super(FanoutConsumer, self).__init__( session, callback, @@ -214,6 +227,18 @@ class FanoutConsumer(ConsumerBase): "%s_fanout_%s" % (topic, uuid.uuid4().hex), {"exclusive": True}) + def reconnect(self, session): + topic = self.get_node_name() + params = { + 'session': session, + 'topic': topic, + 'callback': self.callback, + } + + self.__init__(conf=self.conf, **params) + + super(FanoutConsumer, self).reconnect(session) + class Publisher(object): """Base Publisher class.""" @@ -575,6 +600,7 @@ class Connection(object): def consume_in_thread(self): """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions def _consumer_thread(): try: self.consume() @@ -615,7 +641,7 @@ class Connection(object): return consumer def join_consumer_pool(self, callback, pool_name, topic, - exchange_name=None): + exchange_name=None, ack_on_error=True): """Register as a member of a group of consumers for a given topic from the specified exchange. diff --git a/openstack/common/service.py b/openstack/common/service.py index 55e23ed..36cf300 100644 --- a/openstack/common/service.py +++ b/openstack/common/service.py @@ -27,6 +27,7 @@ import sys import time import eventlet +from eventlet import event import logging as std_logging from oslo.config import cfg @@ -51,20 +52,9 @@ class Launcher(object): :returns: None """ - self._services = threadgroup.ThreadGroup() + self.services = Services() self.backdoor_port = eventlet_backdoor.initialize_if_enabled() - @staticmethod - def run_service(service): - """Start and wait for a service to finish. - - :param service: service to run and wait for. - :returns: None - - """ - service.start() - service.wait() - def launch_service(self, service): """Load and start the given service. @@ -73,7 +63,7 @@ class Launcher(object): """ service.backdoor_port = self.backdoor_port - self._services.add_thread(self.run_service, service) + self.services.add(service) def stop(self): """Stop all services which are currently running. @@ -81,7 +71,7 @@ class Launcher(object): :returns: None """ - self._services.stop() + self.services.stop() def wait(self): """Waits until all services have been stopped, and then returns. @@ -89,7 +79,7 @@ class Launcher(object): :returns: None """ - self._services.wait() + self.services.wait() class SignalExit(SystemExit): @@ -124,9 +114,9 @@ class ServiceLauncher(Launcher): except SystemExit as exc: status = exc.code finally: + self.stop() if rpc: rpc.cleanup() - self.stop() return status @@ -189,7 +179,8 @@ class ProcessLauncher(object): random.seed() launcher = Launcher() - launcher.run_service(service) + launcher.launch_service(service) + launcher.wait() def _start_child(self, wrap): if len(wrap.forktimes) > wrap.workers: @@ -313,15 +304,60 @@ class Service(object): def __init__(self, threads=1000): self.tg = threadgroup.ThreadGroup(threads) + # signal that the service is done shutting itself down: + self._done = event.Event() + def start(self): pass def stop(self): self.tg.stop() + self.tg.wait() + self._done.send() + + def wait(self): + self._done.wait() + + +class Services(object): + + def __init__(self): + self.services = [] + self.tg = threadgroup.ThreadGroup() + self.done = event.Event() + + def add(self, service): + self.services.append(service) + self.tg.add_thread(self.run_service, service, self.done) + + def stop(self): + # wait for graceful shutdown of services: + for service in self.services: + service.stop() + service.wait() + + # each service has performed cleanup, now signal that the run_service + # wrapper threads can now die: + self.done.send() + + # reap threads: + self.tg.stop() def wait(self): self.tg.wait() + @staticmethod + def run_service(service, done): + """Service start wrapper. + + :param service: service to run + :param done: event to wait on until a shutdown is triggered + :returns: None + + """ + service.start() + done.wait() + def launch(service, workers=None): if workers: diff --git a/openstack/common/timeutils.py b/openstack/common/timeutils.py index ac2441b..bd60489 100644 --- a/openstack/common/timeutils.py +++ b/openstack/common/timeutils.py @@ -23,6 +23,7 @@ import calendar import datetime import iso8601 +import six # ISO 8601 extended time format with microseconds @@ -75,14 +76,14 @@ def normalize_time(timestamp): def is_older_than(before, seconds): """Return True if before is older than seconds.""" - if isinstance(before, basestring): + if isinstance(before, six.string_types): before = parse_strtime(before).replace(tzinfo=None) return utcnow() - before > datetime.timedelta(seconds=seconds) def is_newer_than(after, seconds): """Return True if after is newer than seconds.""" - if isinstance(after, basestring): + if isinstance(after, six.string_types): after = parse_strtime(after).replace(tzinfo=None) return after - utcnow() > datetime.timedelta(seconds=seconds) diff --git a/requirements.txt b/requirements.txt index 067af58..0a2d1c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,11 +7,13 @@ greenlet>=0.3.2 lxml routes==1.12.3 iso8601>=0.1.4 -anyjson==0.2.4 -kombu==1.0.4 +anyjson>=0.3.3 +kombu>2.4.7 argparse stevedore SQLAlchemy>=0.7.8,<=0.7.9 -oslo.config>=1.1.0 +http://tarballs.openstack.org/oslo.config/oslo.config-1.2.0a2.tar.gz#egg=oslo.config-1.2.0a2 qpid-python six +netaddr +pycrypto>=2.6 diff --git a/test-requirements.txt b/test-requirements.txt index a19b4af..8fbc5ab 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,15 +1,11 @@ coverage +discover fixtures>=0.3.12 flake8==2.0 -hacking>=0.5.3,<0.6 +hacking>=0.5.6,<0.6 mock mox==0.5.3 mysql-python -nose -nose-exclude -nosexcover -openstack.nose_plugin -nosehtmloutput pep8==1.4.5 pyflakes==0.7.2 pylint @@ -17,5 +13,5 @@ pyzmq==2.2.0.1 redis setuptools-git>=0.4 sphinx +testrepository>=0.0.13 testtools>=0.9.22 -webtest diff --git a/tests/unit/crypto/__init__.py b/tests/unit/crypto/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/crypto/__init__.py diff --git a/tests/unit/crypto/test_utils.py b/tests/unit/crypto/test_utils.py new file mode 100644 index 0000000..3a39100 --- /dev/null +++ b/tests/unit/crypto/test_utils.py @@ -0,0 +1,186 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for crypto utils. +""" + +from openstack.common.crypto import utils as cryptoutils +from tests import utils as test_utils + + +class CryptoUtilsTestCase(test_utils.BaseTestCase): + + # Uses Tests from RFC5869 + def _test_HKDF(self, ikm, prk, okm, length, + salt=None, info='', hashtype='SHA256'): + hkdf = cryptoutils.HKDF(hashtype=hashtype) + + tprk = hkdf.extract(ikm, salt=salt) + self.assertEqual(prk, tprk) + + tokm = hkdf.expand(prk, info, length) + self.assertEqual(okm, tokm) + + def test_HKDF_1(self): + ikm = '\x0b' * 22 + salt = ''.join(map(lambda x: chr(x), range(0x00, 0x0d))) + info = ''.join(map(lambda x: chr(x), range(0xf0, 0xfa))) + length = 42 + + prk = ('\x07\x77\x09\x36\x2c\x2e\x32\xdf\x0d\xdc\x3f\x0d\xc4\x7b' + '\xba\x63\x90\xb6\xc7\x3b\xb5\x0f\x9c\x31\x22\xec\x84\x4a' + '\xd7\xc2\xb3\xe5') + + okm = ('\x3c\xb2\x5f\x25\xfa\xac\xd5\x7a\x90\x43\x4f\x64\xd0\x36' + '\x2f\x2a\x2d\x2d\x0a\x90\xcf\x1a\x5a\x4c\x5d\xb0\x2d\x56' + '\xec\xc4\xc5\xbf\x34\x00\x72\x08\xd5\xb8\x87\x18\x58\x65') + + self._test_HKDF(ikm, prk, okm, length, salt, info) + + def test_HKDF_2(self): + ikm = ''.join(map(lambda x: chr(x), range(0x00, 0x50))) + salt = ''.join(map(lambda x: chr(x), range(0x60, 0xb0))) + info = ''.join(map(lambda x: chr(x), range(0xb0, 0x100))) + length = 82 + + prk = ('\x06\xa6\xb8\x8c\x58\x53\x36\x1a\x06\x10\x4c\x9c\xeb\x35' + '\xb4\x5c\xef\x76\x00\x14\x90\x46\x71\x01\x4a\x19\x3f\x40' + '\xc1\x5f\xc2\x44') + + okm = ('\xb1\x1e\x39\x8d\xc8\x03\x27\xa1\xc8\xe7\xf7\x8c\x59\x6a' + '\x49\x34\x4f\x01\x2e\xda\x2d\x4e\xfa\xd8\xa0\x50\xcc\x4c' + '\x19\xaf\xa9\x7c\x59\x04\x5a\x99\xca\xc7\x82\x72\x71\xcb' + '\x41\xc6\x5e\x59\x0e\x09\xda\x32\x75\x60\x0c\x2f\x09\xb8' + '\x36\x77\x93\xa9\xac\xa3\xdb\x71\xcc\x30\xc5\x81\x79\xec' + '\x3e\x87\xc1\x4c\x01\xd5\xc1\xf3\x43\x4f\x1d\x87') + + self._test_HKDF(ikm, prk, okm, length, salt, info) + + def test_HKDF_3(self): + ikm = '\x0b' * 22 + length = 42 + + prk = ('\x19\xef\x24\xa3\x2c\x71\x7b\x16\x7f\x33\xa9\x1d\x6f\x64' + '\x8b\xdf\x96\x59\x67\x76\xaf\xdb\x63\x77\xac\x43\x4c\x1c' + '\x29\x3c\xcb\x04') + + okm = ('\x8d\xa4\xe7\x75\xa5\x63\xc1\x8f\x71\x5f\x80\x2a\x06\x3c' + '\x5a\x31\xb8\xa1\x1f\x5c\x5e\xe1\x87\x9e\xc3\x45\x4e\x5f' + '\x3c\x73\x8d\x2d\x9d\x20\x13\x95\xfa\xa4\xb6\x1a\x96\xc8') + + self._test_HKDF(ikm, prk, okm, length) + + def test_HKDF_4(self): + ikm = '\x0b' * 11 + salt = ''.join(map(lambda x: chr(x), range(0x00, 0x0d))) + info = ''.join(map(lambda x: chr(x), range(0xf0, 0xfa))) + length = 42 + + prk = ('\x9b\x6c\x18\xc4\x32\xa7\xbf\x8f\x0e\x71\xc8\xeb\x88\xf4' + '\xb3\x0b\xaa\x2b\xa2\x43') + + okm = ('\x08\x5a\x01\xea\x1b\x10\xf3\x69\x33\x06\x8b\x56\xef\xa5' + '\xad\x81\xa4\xf1\x4b\x82\x2f\x5b\x09\x15\x68\xa9\xcd\xd4' + '\xf1\x55\xfd\xa2\xc2\x2e\x42\x24\x78\xd3\x05\xf3\xf8\x96') + + self._test_HKDF(ikm, prk, okm, length, salt, info, hashtype='SHA') + + def test_HKDF_5(self): + ikm = ''.join(map(lambda x: chr(x), range(0x00, 0x50))) + salt = ''.join(map(lambda x: chr(x), range(0x60, 0xb0))) + info = ''.join(map(lambda x: chr(x), range(0xb0, 0x100))) + length = 82 + + prk = ('\x8a\xda\xe0\x9a\x2a\x30\x70\x59\x47\x8d\x30\x9b\x26\xc4' + '\x11\x5a\x22\x4c\xfa\xf6') + + okm = ('\x0b\xd7\x70\xa7\x4d\x11\x60\xf7\xc9\xf1\x2c\xd5\x91\x2a' + '\x06\xeb\xff\x6a\xdc\xae\x89\x9d\x92\x19\x1f\xe4\x30\x56' + '\x73\xba\x2f\xfe\x8f\xa3\xf1\xa4\xe5\xad\x79\xf3\xf3\x34' + '\xb3\xb2\x02\xb2\x17\x3c\x48\x6e\xa3\x7c\xe3\xd3\x97\xed' + '\x03\x4c\x7f\x9d\xfe\xb1\x5c\x5e\x92\x73\x36\xd0\x44\x1f' + '\x4c\x43\x00\xe2\xcf\xf0\xd0\x90\x0b\x52\xd3\xb4') + + self._test_HKDF(ikm, prk, okm, length, salt, info, hashtype='SHA') + + def test_HKDF_6(self): + ikm = '\x0b' * 22 + length = 42 + + prk = ('\xda\x8c\x8a\x73\xc7\xfa\x77\x28\x8e\xc6\xf5\xe7\xc2\x97' + '\x78\x6a\xa0\xd3\x2d\x01') + + okm = ('\x0a\xc1\xaf\x70\x02\xb3\xd7\x61\xd1\xe5\x52\x98\xda\x9d' + '\x05\x06\xb9\xae\x52\x05\x72\x20\xa3\x06\xe0\x7b\x6b\x87' + '\xe8\xdf\x21\xd0\xea\x00\x03\x3d\xe0\x39\x84\xd3\x49\x18') + + self._test_HKDF(ikm, prk, okm, length, hashtype='SHA') + + def test_HKDF_7(self): + ikm = '\x0c' * 22 + length = 42 + + prk = ('\x2a\xdc\xca\xda\x18\x77\x9e\x7c\x20\x77\xad\x2e\xb1\x9d' + '\x3f\x3e\x73\x13\x85\xdd') + + okm = ('\x2c\x91\x11\x72\x04\xd7\x45\xf3\x50\x0d\x63\x6a\x62\xf6' + '\x4f\x0a\xb3\xba\xe5\x48\xaa\x53\xd4\x23\xb0\xd1\xf2\x7e' + '\xbb\xa6\xf5\xe5\x67\x3a\x08\x1d\x70\xcc\xe7\xac\xfc\x48') + + self._test_HKDF(ikm, prk, okm, length, hashtype='SHA') + + def test_HKDF_8(self): + ikm = '\x0b' * 22 + prk = ('\x19\xef\x24\xa3\x2c\x71\x7b\x16\x7f\x33\xa9\x1d\x6f\x64' + '\x8b\xdf\x96\x59\x67\x76\xaf\xdb\x63\x77\xac\x43\x4c\x1c' + '\x29\x3c\xcb\x04') + + # Just testing HKDFOutputLengthTooLong is returned + try: + self._test_HKDF(ikm, prk, None, 1000000) + except cryptoutils.HKDFOutputLengthTooLong: + pass + + def test_SymmetricCrypto_encrypt_string(self): + msg = 'Plain Text' + + skc = cryptoutils.SymmetricCrypto() + key = skc.new_key(16) + cipher = skc.encrypt(key, msg) + plain = skc.decrypt(key, cipher) + self.assertEqual(msg, plain) + + def test_SymmetricCrypto_encrypt_blocks(self): + cb = 16 + et = 'AES' + + skc = cryptoutils.SymmetricCrypto(enctype=et) + key = skc.new_key(16) + msg = skc.new_key(cb * 2) + + for i in range(0, cb * 2): + cipher = skc.encrypt(key, msg[0:i], b64encode=False) + plain = skc.decrypt(key, cipher, b64decode=False) + self.assertEqual(msg[0:i], plain) + + def test_SymmetricCrypto_signing(self): + msg = 'Authenticated Message' + signature = 'KWjl6i30RMjc5PjnaccRwTPKTRCWM6sPpmGS2bxm5fQ=' + skey = 'L\xdd0\xf3\xb4\xc6\xe2p\xef\xc7\xbd\xaa\xc9eNC' + + skc = cryptoutils.SymmetricCrypto() + validate = skc.sign(skey, msg) + self.assertEqual(signature, validate) diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py index ac178b8..48d6cf7 100644 --- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py +++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py @@ -52,15 +52,15 @@ sql_max_overflow=50 sql_connection_debug=60 sql_connection_trace=True """)]) - test_utils.CONF(['--config-file', paths[0]]) - self.assertEquals(test_utils.CONF.database.connection, 'x://y.z') - self.assertEquals(test_utils.CONF.database.min_pool_size, 10) - self.assertEquals(test_utils.CONF.database.max_pool_size, 20) - self.assertEquals(test_utils.CONF.database.max_retries, 30) - self.assertEquals(test_utils.CONF.database.retry_interval, 40) - self.assertEquals(test_utils.CONF.database.max_overflow, 50) - self.assertEquals(test_utils.CONF.database.connection_debug, 60) - self.assertEquals(test_utils.CONF.database.connection_trace, True) + self.conf(['--config-file', paths[0]]) + self.assertEquals(self.conf.database.connection, 'x://y.z') + self.assertEquals(self.conf.database.min_pool_size, 10) + self.assertEquals(self.conf.database.max_pool_size, 20) + self.assertEquals(self.conf.database.max_retries, 30) + self.assertEquals(self.conf.database.retry_interval, 40) + self.assertEquals(self.conf.database.max_overflow, 50) + self.assertEquals(self.conf.database.connection_debug, 60) + self.assertEquals(self.conf.database.connection_trace, True) def test_session_parameters(self): paths = self.create_tempfiles([('test', """[database] @@ -74,16 +74,39 @@ connection_debug=60 connection_trace=True pool_timeout=7 """)]) - test_utils.CONF(['--config-file', paths[0]]) - self.assertEquals(test_utils.CONF.database.connection, 'x://y.z') - self.assertEquals(test_utils.CONF.database.min_pool_size, 10) - self.assertEquals(test_utils.CONF.database.max_pool_size, 20) - self.assertEquals(test_utils.CONF.database.max_retries, 30) - self.assertEquals(test_utils.CONF.database.retry_interval, 40) - self.assertEquals(test_utils.CONF.database.max_overflow, 50) - self.assertEquals(test_utils.CONF.database.connection_debug, 60) - self.assertEquals(test_utils.CONF.database.connection_trace, True) - self.assertEquals(test_utils.CONF.database.pool_timeout, 7) + self.conf(['--config-file', paths[0]]) + self.assertEquals(self.conf.database.connection, 'x://y.z') + self.assertEquals(self.conf.database.min_pool_size, 10) + self.assertEquals(self.conf.database.max_pool_size, 20) + self.assertEquals(self.conf.database.max_retries, 30) + self.assertEquals(self.conf.database.retry_interval, 40) + self.assertEquals(self.conf.database.max_overflow, 50) + self.assertEquals(self.conf.database.connection_debug, 60) + self.assertEquals(self.conf.database.connection_trace, True) + self.assertEquals(self.conf.database.pool_timeout, 7) + + def test_dbapi_database_deprecated_parameters(self): + paths = self.create_tempfiles([('test', + '[DATABASE]\n' + 'sql_connection=fake_connection\n' + 'sql_idle_timeout=100\n' + 'sql_min_pool_size=99\n' + 'sql_max_pool_size=199\n' + 'sql_max_retries=22\n' + 'reconnect_interval=17\n' + 'sqlalchemy_max_overflow=101\n' + 'sqlalchemy_pool_timeout=5\n' + )]) + self.conf(['--config-file', paths[0]]) + self.assertEquals(self.conf.database.connection, + 'fake_connection') + self.assertEquals(self.conf.database.idle_timeout, 100) + self.assertEquals(self.conf.database.min_pool_size, 99) + self.assertEquals(self.conf.database.max_pool_size, 199) + self.assertEquals(self.conf.database.max_retries, 22) + self.assertEquals(self.conf.database.retry_interval, 17) + self.assertEquals(self.conf.database.max_overflow, 101) + self.assertEquals(self.conf.database.pool_timeout, 5) class SessionErrorWrapperTestCase(test_base.DbTestCase): diff --git a/tests/unit/db/test_api.py b/tests/unit/db/test_api.py index f6e0d4c..2a8db3b 100644 --- a/tests/unit/db/test_api.py +++ b/tests/unit/db/test_api.py @@ -40,9 +40,9 @@ class DBAPITestCase(test_utils.BaseTestCase): 'dbapi_use_tpool=True\n' )]) - test_utils.CONF(['--config-file', paths[0]]) - self.assertEquals(test_utils.CONF.database.backend, 'test_123') - self.assertEquals(test_utils.CONF.database.use_tpool, True) + self.conf(['--config-file', paths[0]]) + self.assertEquals(self.conf.database.backend, 'test_123') + self.assertEquals(self.conf.database.use_tpool, True) def test_dbapi_parameters(self): paths = self.create_tempfiles([('test', @@ -51,9 +51,9 @@ class DBAPITestCase(test_utils.BaseTestCase): 'use_tpool=True\n' )]) - test_utils.CONF(['--config-file', paths[0]]) - self.assertEquals(test_utils.CONF.database.backend, 'test_123') - self.assertEquals(test_utils.CONF.database.use_tpool, True) + self.conf(['--config-file', paths[0]]) + self.assertEquals(self.conf.database.backend, 'test_123') + self.assertEquals(self.conf.database.use_tpool, True) def test_dbapi_api_class_method_and_tpool_false(self): backend_mapping = {'test_known': 'tests.unit.db.test_api'} diff --git a/tests/unit/plugin/__init__.py b/tests/unit/plugin/__init__.py deleted file mode 100644 index b706747..0000000 --- a/tests/unit/plugin/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. diff --git a/tests/unit/plugin/test_callback_plugin.py b/tests/unit/plugin/test_callback_plugin.py deleted file mode 100644 index 3f3fd63..0000000 --- a/tests/unit/plugin/test_callback_plugin.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import pkg_resources - -from openstack.common.notifier import api as notifier_api -from openstack.common.plugin import callbackplugin -from openstack.common.plugin import pluginmanager -from tests import utils as test_utils - -userdatastring = "magic user data string" - - -class TestCBP(callbackplugin.CallbackPlugin): - - def callback1(self, context, message, userdata): - self.callback1count += 1 - - def callback2(self, context, message, userdata): - self.callback2count += 1 - - def callback3(self, context, message, userdata): - self.callback3count += 1 - self.userdata = userdata - - def __init__(self, service_name): - super(TestCBP, self).__init__(service_name) - self.callback1count = 0 - self.callback2count = 0 - self.callback3count = 0 - - self._add_callback(self.callback1, 'type1', None) - self._add_callback(self.callback2, 'type1', None) - self._add_callback(self.callback3, 'type2', 'magic user data string') - - -class CallbackTestCase(test_utils.BaseTestCase): - """Tests for the callback plugin convenience class.""" - - def test_callback_plugin_subclass(self): - - class MockEntrypoint(pkg_resources.EntryPoint): - def load(self): - return TestCBP - - def mock_iter_entry_points(_t): - return [MockEntrypoint("fake", "fake", ["fake"])] - - self.stubs.Set(pkg_resources, 'iter_entry_points', - mock_iter_entry_points) - - plugmgr = pluginmanager.PluginManager("testproject", "testservice") - plugmgr.load_plugins() - self.assertEqual(len(plugmgr.plugins), 1) - plugin = plugmgr.plugins[0] - self.assertEqual(len(plugin.notifiers), 1) - - notifier_api.notify('contextarg', 'publisher_id', 'type1', - notifier_api.WARN, dict(a=3)) - - self.assertEqual(plugin.callback1count, 1) - self.assertEqual(plugin.callback2count, 1) - self.assertEqual(plugin.callback3count, 0) - - notifier_api.notify('contextarg', 'publisher_id', 'type2', - notifier_api.WARN, dict(a=3)) - - self.assertEqual(plugin.callback1count, 1) - self.assertEqual(plugin.callback2count, 1) - self.assertEqual(plugin.callback3count, 1) - self.assertEqual(plugin.userdata, userdatastring) - - plugin._remove_callback(plugin.callback1) - - notifier_api.notify('contextarg', 'publisher_id', 'type1', - notifier_api.WARN, dict(a=3)) - - self.assertEqual(plugin.callback1count, 1) - self.assertEqual(plugin.callback2count, 2) - self.assertEqual(plugin.callback3count, 1) diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py index c2432f4..6f32005 100644 --- a/tests/unit/rpc/test_common.py +++ b/tests/unit/rpc/test_common.py @@ -108,7 +108,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): '__unicode__': str_override}) new_ex_type.__module__ = '%s_Remote' % e.__class__.__module__ e.__class__ = new_ex_type - raise e + raise try: raise_remote_exception() diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py index 159fefb..cbe948d 100644 --- a/tests/unit/rpc/test_kombu.py +++ b/tests/unit/rpc/test_kombu.py @@ -23,11 +23,15 @@ import eventlet eventlet.monkey_patch() import contextlib +import functools import logging +import weakref +import fixtures import mock from oslo.config import cfg import six +import time from openstack.common import exception from openstack.common.rpc import amqp as rpc_amqp @@ -37,6 +41,8 @@ from tests import utils try: import kombu + import kombu.connection + import kombu.entity from openstack.common.rpc import impl_kombu except ImportError: kombu = None @@ -65,25 +71,50 @@ def _raise_exc_stub(stubs, times, obj, method, exc_msg, return info -class KombuStubs: - @staticmethod +class KombuStubs(fixtures.Fixture): + def __init__(self, test): + super(KombuStubs, self).__init__() + + # NOTE(rpodolyaka): use a weak ref here to prevent ref cycles + self.test = weakref.ref(test) + def setUp(self): + super(KombuStubs, self).setUp() + + test = self.test() if kombu: - self.config(fake_rabbit=True) - self.config(rpc_response_timeout=5) - self.rpc = impl_kombu + test.conf = FLAGS + test.config(fake_rabbit=True) + test.config(rpc_response_timeout=5) + test.rpc = impl_kombu self.addCleanup(impl_kombu.cleanup) else: - self.rpc = None + test.rpc = None + + +class FakeMessage(object): + acked = False + rejected = False + + def __init__(self, payload): + self.payload = payload + + def ack(self): + self.acked = True + + def reject(self): + self.rejected = True class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): def setUp(self): - KombuStubs.setUp(self) - super(RpcKombuTestCase, self).setUp() if kombu is None: self.skipTest("Test requires kombu") + self.useFixture(KombuStubs(self)) + + super(RpcKombuTestCase, self).setUp() + def test_reusing_connection(self): """Test that reusing a connection returns same one.""" conn_context = self.rpc.create_connection(FLAGS, new=False) @@ -112,6 +143,74 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): self.assertEqual(self.received_message, message) + def test_callback_handler_ack_on_error(self): + """The default case will ack on error. Same as before. + """ + def _callback(msg): + pass + + conn = self.rpc.create_connection(FLAGS) + consumer = conn.declare_consumer(functools.partial( + impl_kombu.TopicConsumer, + name=None, + exchange_name=None), + "a_topic", _callback) + message = FakeMessage("some message") + consumer._callback_handler(message, _callback) + self.assertTrue(message.acked) + self.assertFalse(message.rejected) + + def test_callback_handler_ack_on_error_exception(self): + + def _callback(msg): + raise MyException() + + conn = self.rpc.create_connection(FLAGS) + consumer = conn.declare_consumer(functools.partial( + impl_kombu.TopicConsumer, + name=None, + exchange_name=None, + ack_on_error=True), + "a_topic", _callback) + message = FakeMessage("some message") + consumer._callback_handler(message, _callback) + self.assertTrue(message.acked) + self.assertFalse(message.rejected) + + def test_callback_handler_no_ack_on_error_exception(self): + + def _callback(msg): + raise MyException() + + conn = self.rpc.create_connection(FLAGS) + consumer = conn.declare_consumer(functools.partial( + impl_kombu.TopicConsumer, + name=None, + exchange_name=None, + ack_on_error=False), + "a_topic", _callback) + message = FakeMessage("some message") + consumer._callback_handler(message, _callback) + self.assertFalse(message.acked) + self.assertTrue(message.rejected) + + def test_callback_handler_no_ack_on_error(self): + + def _callback(msg): + pass + + conn = self.rpc.create_connection(FLAGS) + consumer = conn.declare_consumer(functools.partial( + impl_kombu.TopicConsumer, + name=None, + exchange_name=None, + ack_on_error=False), + "a_topic", _callback) + message = FakeMessage("some message") + consumer._callback_handler(message, _callback) + self.assertTrue(message.acked) + self.assertFalse(message.rejected) + def test_message_ttl_on_timeout(self): """Test message ttl being set by request timeout. The message should die on the vine and never arrive. @@ -308,6 +407,22 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): impl_kombu.cast_to_server(FLAGS, ctxt, server_params, 'fake_topic', {'msg': 'fake'}) + def test_fanout_success(self): + # Overriding the method in the base class because the test + # seems to fail for the same reason as + # test_fanout_send_receive(). + self.skipTest("kombu memory transport seems buggy with " + "fanout queues as this test passes when " + "you use rabbit (fake_rabbit=False)") + + def test_cast_success_despite_missing_args(self): + # Overriding the method in the base class because the test + # seems to fail for the same reason as + # test_fanout_send_receive(). + self.skipTest("kombu memory transport seems buggy with " + "fanout queues as this test passes when " + "you use rabbit (fake_rabbit=False)") + def test_fanout_send_receive(self): """Test sending to a fanout exchange and consuming from 2 queues.""" @@ -514,7 +629,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): 'pool.name', ) - def test_join_consumer_pool(self): + def test_join_consumer_pool_default(self): meth = 'declare_topic_consumer' with mock.patch.object(self.rpc.Connection, meth) as p: conn = self.rpc.create_connection(FLAGS) @@ -529,13 +644,108 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): queue_name='pool.name', exchange_name='exchange.name', topic='topic.name', + ack_on_error=True, ) + def test_join_consumer_pool_no_ack(self): + meth = 'declare_topic_consumer' + with mock.patch.object(self.rpc.Connection, meth) as p: + conn = self.rpc.create_connection(FLAGS) + conn.join_consumer_pool( + callback=lambda *a, **k: (a, k), + pool_name='pool.name', + topic='topic.name', + exchange_name='exchange.name', + ack_on_error=False, + ) + p.assert_called_with( + callback=mock.ANY, # the callback wrapper + queue_name='pool.name', + exchange_name='exchange.name', + topic='topic.name', + ack_on_error=False, + ) + + # used to make unexpected exception tests run faster + def my_time_sleep(self, sleep_time): + return + + def test_service_consume_thread_unexpected_exceptions(self): + + def my_TopicConsumer_consume(myself, *args, **kwargs): + self.consume_calls += 1 + # see if it can sustain three failures + if self.consume_calls < 3: + raise Exception('unexpected exception') + else: + self.orig_TopicConsumer_consume(myself, *args, **kwargs) + + self.consume_calls = 0 + self.orig_TopicConsumer_consume = impl_kombu.TopicConsumer.consume + self.stubs.Set(impl_kombu.TopicConsumer, 'consume', + my_TopicConsumer_consume) + self.stubs.Set(time, 'sleep', self.my_time_sleep) + + value = 42 + result = self.rpc.call(FLAGS, self.context, self.topic, + {"method": "echo", + "args": {"value": value}}) + self.assertEqual(value, result) + + def test_replyproxy_consume_thread_unexpected_exceptions(self): + + def my_DirectConsumer_consume(myself, *args, **kwargs): + self.consume_calls += 1 + # see if it can sustain three failures + if self.consume_calls < 3: + raise Exception('unexpected exception') + else: + self.orig_DirectConsumer_consume(myself, *args, **kwargs) + + self.consume_calls = 1 + self.orig_DirectConsumer_consume = impl_kombu.DirectConsumer.consume + self.stubs.Set(impl_kombu.DirectConsumer, 'consume', + my_DirectConsumer_consume) + self.stubs.Set(time, 'sleep', self.my_time_sleep) + + value = 42 + result = self.rpc.call(FLAGS, self.context, self.topic, + {"method": "echo", + "args": {"value": value}}) + self.assertEqual(value, result) + + def test_reconnect_max_retries(self): + self.config(rabbit_hosts=[ + 'host1:1234', 'host2:5678', '[::1]:2345', + '[2001:0db8:85a3:0042:0000:8a2e:0370:7334]'], + rabbit_max_retries=2, + rabbit_retry_interval=0.1, + rabbit_retry_backoff=0.1) + + info = {'attempt': 0} + + class MyConnection(kombu.connection.BrokerConnection): + def __init__(self, *args, **params): + super(MyConnection, self).__init__(*args, **params) + info['attempt'] += 1 + + def connect(self): + if info['attempt'] < 3: + # the word timeout is important (see impl_kombu.py:486) + raise Exception('connection timeout') + super(kombu.connection.BrokerConnection, self).connect() + + self.stubs.Set(kombu.connection, 'BrokerConnection', MyConnection) + + self.assertRaises(rpc_common.RPCException, self.rpc.Connection, FLAGS) + self.assertEqual(info['attempt'], 2) + class RpcKombuHATestCase(utils.BaseTestCase): def setUp(self): super(RpcKombuHATestCase, self).setUp() - KombuStubs.setUp(self) + + self.useFixture(KombuStubs(self)) self.addCleanup(FLAGS.reset) def test_roundrobin_reconnect(self): @@ -576,15 +786,13 @@ class RpcKombuHATestCase(utils.BaseTestCase): ] } - import kombu.connection - class MyConnection(kombu.connection.BrokerConnection): def __init__(myself, *args, **params): super(MyConnection, myself).__init__(*args, **params) self.assertEqual(params, info['params_list'][info['attempt'] % len(info['params_list'])]) - info['attempt'] = info['attempt'] + 1 + info['attempt'] += 1 def connect(myself): if info['attempt'] < 5: @@ -601,8 +809,6 @@ class RpcKombuHATestCase(utils.BaseTestCase): def test_queue_not_declared_ha_if_ha_off(self): self.config(rabbit_ha_queues=False) - import kombu.entity - def my_declare(myself): self.assertEqual(None, (myself.queue_arguments or {}).get('x-ha-policy')) @@ -615,8 +821,6 @@ class RpcKombuHATestCase(utils.BaseTestCase): def test_queue_declared_ha_if_ha_on(self): self.config(rabbit_ha_queues=True) - import kombu.entity - def my_declare(myself): self.assertEqual('all', (myself.queue_arguments or {}).get('x-ha-policy')) diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py index 0bad387..5d51a4b 100644 --- a/tests/unit/rpc/test_qpid.py +++ b/tests/unit/rpc/test_qpid.py @@ -26,6 +26,7 @@ eventlet.monkey_patch() import fixtures import mox from oslo.config import cfg +import time import uuid from openstack.common import context @@ -218,7 +219,7 @@ class RpcQpidTestCase(utils.BaseTestCase): ) connection.close() - def test_topic_consumer(self): + def test_topic_consumer(self, consume_thread_exc=False): self.mock_connection = self.mox.CreateMock(self.orig_connection) self.mock_session = self.mox.CreateMock(self.orig_session) self.mock_receiver = self.mox.CreateMock(self.orig_receiver) @@ -235,6 +236,9 @@ class RpcQpidTestCase(utils.BaseTestCase): self.mock_session.receiver(expected_address).AndReturn( self.mock_receiver) self.mock_receiver.capacity = 1 + if consume_thread_exc: + self.mock_session.next_receiver(timeout=None).AndRaise( + Exception('unexpected exception')) self.mock_connection.close() self.mox.ReplayAll() @@ -244,8 +248,14 @@ class RpcQpidTestCase(utils.BaseTestCase): lambda *_x, **_y: None, queue_name='impl.qpid.test.workers', exchange_name='foobar') + if consume_thread_exc: + connection.consume_in_thread() + time.sleep(0) connection.close() + def test_consume_thread_exception(self): + self.test_topic_consumer(consume_thread_exc=True) + def _test_cast(self, fanout, server_params=None): self.mock_connection = self.mox.CreateMock(self.orig_connection) self.mock_session = self.mox.CreateMock(self.orig_session) @@ -338,7 +348,11 @@ class RpcQpidTestCase(utils.BaseTestCase): self._setup_to_server_tests(server_params) self._test_cast(fanout=True, server_params=server_params) + def my_time_sleep(self, arg): + pass + def _test_call_mock_common(self): + self.stubs.Set(time, 'sleep', self.my_time_sleep) self.mock_connection = self.mox.CreateMock(self.orig_connection) self.mock_session = self.mox.CreateMock(self.orig_session) self.mock_sender = self.mox.CreateMock(self.orig_sender) @@ -367,9 +381,12 @@ class RpcQpidTestCase(utils.BaseTestCase): self.mock_session.close() self.mock_connection.session().AndReturn(self.mock_session) - def _test_call(self, multi): + def _test_call(self, multi, reply_proxy_exc): self._test_call_mock_common() + if reply_proxy_exc: + self.mock_session.next_receiver(timeout=None).AndRaise( + Exception('unexpected exception')) self.mock_session.next_receiver(timeout=None).AndReturn( self.mock_receiver) self.mock_receiver.fetch().AndReturn(qpid.messaging.Message( @@ -393,6 +410,9 @@ class RpcQpidTestCase(utils.BaseTestCase): "failure": False, "ending": False})) self.mock_session.acknowledge(mox.IgnoreArg()) + if reply_proxy_exc: + self.mock_session.next_receiver(timeout=None).AndRaise( + Exception('unexpected exception')) self.mock_session.next_receiver(timeout=None).AndReturn( self.mock_receiver) self.mock_receiver.fetch().AndReturn(qpid.messaging.Message( @@ -425,7 +445,10 @@ class RpcQpidTestCase(utils.BaseTestCase): self.uuid4 = uuid.uuid4() def test_call(self): - self._test_call(multi=False) + self._test_call(multi=False, reply_proxy_exc=False) + + def test_replyproxy_consume_thread_unexpected_exceptions(self): + self._test_call(multi=False, reply_proxy_exc=True) def _test_call_with_timeout(self, timeout, expect_failure): self._test_call_mock_common() @@ -483,7 +506,7 @@ class RpcQpidTestCase(utils.BaseTestCase): self._test_call_with_timeout(timeout=0.1, expect_failure=True) def test_multicall(self): - self._test_call(multi=True) + self._test_call(multi=True, reply_proxy_exc=False) def _test_publisher(self, message=True): """Test that messages containing long strings are correctly serialized diff --git a/tests/unit/rpc/test_zmq.py b/tests/unit/rpc/test_zmq.py index b0f0262..c87a040 100644 --- a/tests/unit/rpc/test_zmq.py +++ b/tests/unit/rpc/test_zmq.py @@ -60,6 +60,7 @@ class _RpcZmqBaseTestCase(common.BaseRpcTestCase): self.reactor = None self.rpc = impl_zmq + self.conf = FLAGS self.config(rpc_zmq_bind_address='127.0.0.1') self.config(rpc_zmq_host='127.0.0.1') self.config(rpc_response_timeout=5) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 0db7aaa..2f9a3de 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -24,3 +24,7 @@ class ContextTest(utils.BaseTestCase): def test_context(self): ctx = context.RequestContext() self.assertTrue(ctx) + + def test_admin_context_show_deleted_flag_default(self): + ctx = context.get_admin_context() + self.assertFalse(ctx.show_deleted) diff --git a/tests/unit/test_eventlet_backdoor.py b/tests/unit/test_eventlet_backdoor.py new file mode 100644 index 0000000..986678e --- /dev/null +++ b/tests/unit/test_eventlet_backdoor.py @@ -0,0 +1,71 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Unit Tests for eventlet backdoor +""" +import errno +import eventlet +import mox +import socket + +from openstack.common import eventlet_backdoor +from tests import utils + + +class BackdoorPortTest(utils.BaseTestCase): + + def common_backdoor_port_setup(self): + self.sock = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(eventlet, 'listen') + self.mox.StubOutWithMock(eventlet, 'spawn_n') + + def test_backdoor_port_inuse(self): + self.config(backdoor_port=2345) + self.common_backdoor_port_setup() + eventlet.listen(('localhost', 2345)).AndRaise( + socket.error(errno.EADDRINUSE, '')) + self.mox.ReplayAll() + self.assertRaises(socket.error, + eventlet_backdoor.initialize_if_enabled) + + def test_backdoor_port_range(self): + self.config(backdoor_port='8800:8899') + self.common_backdoor_port_setup() + eventlet.listen(('localhost', 8800)).AndReturn(self.sock) + self.sock.getsockname().AndReturn(('127.0.0.1', 8800)) + eventlet.spawn_n(eventlet.backdoor.backdoor_server, self.sock, + locals=mox.IsA(dict)) + self.mox.ReplayAll() + port = eventlet_backdoor.initialize_if_enabled() + self.assertEqual(port, 8800) + + def test_backdoor_port_range_all_inuse(self): + self.config(backdoor_port='8800:8899') + self.common_backdoor_port_setup() + for i in range(8800, 8900): + eventlet.listen(('localhost', i)).AndRaise( + socket.error(errno.EADDRINUSE, '')) + self.mox.ReplayAll() + self.assertRaises(socket.error, + eventlet_backdoor.initialize_if_enabled) + + def test_backdoor_port_bad(self): + self.config(backdoor_port='abc') + self.assertRaises(eventlet_backdoor.EventletBackdoorConfigValueError, + eventlet_backdoor.initialize_if_enabled) diff --git a/tests/unit/test_excutils.py b/tests/unit/test_excutils.py index 8c8137a..1386eaa 100644 --- a/tests/unit/test_excutils.py +++ b/tests/unit/test_excutils.py @@ -14,6 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. +import logging +import mox +import time + from openstack.common import excutils from tests import utils @@ -47,3 +51,118 @@ class SaveAndReraiseTest(utils.BaseTestCase): e = _e self.assertEqual(str(e), msg) + + def test_save_and_reraise_exception_no_reraise(self): + """Test that suppressing the reraise works.""" + try: + raise Exception('foo') + except Exception: + with excutils.save_and_reraise_exception() as ctxt: + ctxt.reraise = False + + +class ForeverRetryUncaughtExceptionsTest(utils.BaseTestCase): + + @excutils.forever_retry_uncaught_exceptions + def exception_generator(self): + exc = self.exception_to_raise() + while exc is not None: + raise exc + exc = self.exception_to_raise() + + def exception_to_raise(self): + return None + + def my_time_sleep(self, arg): + pass + + def exc_retrier_common_start(self): + self.stubs.Set(time, 'sleep', self.my_time_sleep) + self.mox.StubOutWithMock(logging, 'exception') + self.mox.StubOutWithMock(time, 'time') + self.mox.StubOutWithMock(self, 'exception_to_raise') + + def exc_retrier_sequence(self, exc_id=None, timestamp=None, + exc_count=None): + self.exception_to_raise().AndReturn( + Exception('unexpected %d' % exc_id)) + time.time().AndReturn(timestamp) + if exc_count != 0: + logging.exception(mox.In( + 'Unexpected exception occurred %d time(s)' % exc_count)) + + def exc_retrier_common_end(self): + self.exception_to_raise().AndReturn(None) + self.mox.ReplayAll() + self.exception_generator() + self.addCleanup(self.stubs.UnsetAll) + + def test_exc_retrier_1exc_gives_1log(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + self.exc_retrier_common_end() + + def test_exc_retrier_same_10exc_1min_gives_1log(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + # By design, the following exception don't get logged because they + # are within the same minute. + for i in range(2, 11): + self.exc_retrier_sequence(exc_id=1, timestamp=i, exc_count=0) + self.exc_retrier_common_end() + + def test_exc_retrier_same_2exc_2min_gives_2logs(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + self.exc_retrier_sequence(exc_id=1, timestamp=65, exc_count=1) + self.exc_retrier_common_end() + + def test_exc_retrier_same_10exc_2min_gives_2logs(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + self.exc_retrier_sequence(exc_id=1, timestamp=12, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=23, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=34, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=45, exc_count=0) + # The previous 4 exceptions are counted here + self.exc_retrier_sequence(exc_id=1, timestamp=106, exc_count=5) + # Again, the following are not logged due to being within + # the same minute + self.exc_retrier_sequence(exc_id=1, timestamp=117, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=128, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=139, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=150, exc_count=0) + self.exc_retrier_common_end() + + def test_exc_retrier_mixed_4exc_1min_gives_2logs(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + # By design, this second 'unexpected 1' exception is not counted. This + # is likely a rare thing and is a sacrifice for code simplicity. + self.exc_retrier_sequence(exc_id=1, timestamp=10, exc_count=0) + self.exc_retrier_sequence(exc_id=2, timestamp=20, exc_count=1) + # Again, trailing exceptions within a minute are not counted. + self.exc_retrier_sequence(exc_id=2, timestamp=30, exc_count=0) + self.exc_retrier_common_end() + + def test_exc_retrier_mixed_4exc_2min_gives_2logs(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + # Again, this second exception of the same type is not counted + # for the sake of code simplicity. + self.exc_retrier_sequence(exc_id=1, timestamp=10, exc_count=0) + # The difference between this and the previous case is the log + # is also triggered by more than a minute expiring. + self.exc_retrier_sequence(exc_id=2, timestamp=100, exc_count=1) + self.exc_retrier_sequence(exc_id=2, timestamp=110, exc_count=0) + self.exc_retrier_common_end() + + def test_exc_retrier_mixed_4exc_2min_gives_3logs(self): + self.exc_retrier_common_start() + self.exc_retrier_sequence(exc_id=1, timestamp=1, exc_count=1) + # This time the second 'unexpected 1' exception is counted due + # to the same exception occurring same when the minute expires. + self.exc_retrier_sequence(exc_id=1, timestamp=10, exc_count=0) + self.exc_retrier_sequence(exc_id=1, timestamp=100, exc_count=2) + self.exc_retrier_sequence(exc_id=2, timestamp=110, exc_count=1) + self.exc_retrier_common_end() diff --git a/tests/unit/test_jsonutils.py b/tests/unit/test_jsonutils.py index 758455b..28d588e 100644 --- a/tests/unit/test_jsonutils.py +++ b/tests/unit/test_jsonutils.py @@ -18,6 +18,7 @@ import datetime import xmlrpclib +import netaddr from six import StringIO from openstack.common import jsonutils @@ -170,3 +171,8 @@ class ToPrimitiveTestCase(utils.BaseTestCase): ret = jsonutils.to_primitive(l4_obj, max_depth=4) self.assertEquals(ret, json_l4) + + def test_ipaddr(self): + thing = {'ip_addr': netaddr.IPAddress('1.2.3.4')} + ret = jsonutils.to_primitive(thing) + self.assertEquals({'ip_addr': '1.2.3.4'}, ret) diff --git a/tests/unit/test_network_utils.py b/tests/unit/test_network_utils.py index 2783e70..4ac0222 100644 --- a/tests/unit/test_network_utils.py +++ b/tests/unit/test_network_utils.py @@ -40,3 +40,20 @@ class NetworkUtilsTest(utils.BaseTestCase): network_utils.parse_host_port( '2001:db8:85a3::8a2e:370:7334', default_port=1234)) + + def test_urlsplit(self): + result = network_utils.urlsplit('rpc://myhost?someparam#somefragment') + self.assertEqual(result.scheme, 'rpc') + self.assertEqual(result.netloc, 'myhost') + self.assertEqual(result.path, '') + self.assertEqual(result.query, 'someparam') + self.assertEqual(result.fragment, 'somefragment') + + result = network_utils.urlsplit( + 'rpc://myhost/mypath?someparam#somefragment', + allow_fragments=False) + self.assertEqual(result.scheme, 'rpc') + self.assertEqual(result.netloc, 'myhost') + self.assertEqual(result.path, '/mypath') + self.assertEqual(result.query, 'someparam#somefragment') + self.assertEqual(result.fragment, '') diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py deleted file mode 100644 index fd653d7..0000000 --- a/tests/unit/test_plugin.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2012 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import pkg_resources - -from openstack.common.notifier import api as notifier_api -from openstack.common.plugin import plugin -from openstack.common.plugin import pluginmanager -from tests import utils - - -class SimpleNotifier(object): - def __init__(self): - self.message_list = [] - - def notify(self, context, message): - self.context = context - self.message_list.append(message) - - -class ManagerTestCase(utils.BaseTestCase): - def test_constructs(self): - manager1 = pluginmanager.PluginManager("testproject", "testservice") - self.assertNotEqual(manager1, False) - - -class NotifyTestCase(utils.BaseTestCase): - """Test case for the plugin notification interface.""" - - def test_add_notifier(self): - notifier1 = SimpleNotifier() - notifier2 = SimpleNotifier() - notifier3 = SimpleNotifier() - - testplugin = plugin.Plugin('service') - testplugin._add_notifier(notifier1) - testplugin._add_notifier(notifier2) - self.assertEqual(len(testplugin.notifiers), 2) - - testplugin._add_notifier(notifier3) - self.assertEqual(len(testplugin.notifiers), 3) - - def test_notifier_action(self): - def mock_iter_entry_points(_t): - return [MockEntrypoint("fake", "fake", ["fake"])] - - self.stubs.Set(pkg_resources, 'iter_entry_points', - mock_iter_entry_points) - - plugmgr = pluginmanager.PluginManager("testproject", "testservice") - plugmgr.load_plugins() - self.assertEqual(len(plugmgr.plugins), 1) - self.assertEqual(len(plugmgr.plugins[0].notifiers), 1) - notifier = plugmgr.plugins[0].notifiers[0] - - notifier_api.notify('contextarg', 'publisher_id', 'event_type', - notifier_api.WARN, dict(a=3)) - - self.assertEqual(len(notifier.message_list), 1) - - -class StubControllerExtension(object): - name = 'stubextension' - alias = 'stubby' - - -class TestPluginClass(plugin.Plugin): - - def __init__(self, service_name): - super(TestPluginClass, self).__init__(service_name) - self._add_api_extension_descriptor(StubControllerExtension) - notifier1 = SimpleNotifier() - self._add_notifier(notifier1) - - -class MockEntrypoint(pkg_resources.EntryPoint): - def load(self): - return TestPluginClass - - -class MockExtManager(): - def __init__(self): - self.descriptors = [] - - def load_extension(self, descriptor): - self.descriptors.append(descriptor) - - -class APITestCase(utils.BaseTestCase): - """Test case for the plugin api extension interface.""" - def test_add_extension(self): - def mock_load(_s): - return TestPluginClass() - - def mock_iter_entry_points(_t): - return [MockEntrypoint("fake", "fake", ["fake"])] - - self.stubs.Set(pkg_resources, 'iter_entry_points', - mock_iter_entry_points) - - mgr = MockExtManager() - plugmgr = pluginmanager.PluginManager("testproject", "testservice") - plugmgr.load_plugins() - plugmgr.plugin_extension_factory(mgr) - - self.assertTrue(StubControllerExtension in mgr.descriptors) diff --git a/tests/unit/test_rootwrap.py b/tests/unit/test_rootwrap.py index 0e08b5e..6e1e6e6 100644 --- a/tests/unit/test_rootwrap.py +++ b/tests/unit/test_rootwrap.py @@ -61,10 +61,11 @@ class RootwrapTestCase(utils.BaseTestCase): self.assertRaises(wrapper.NoFilterMatched, wrapper.match_filter, self.filters, invalid) - def _test_DnsmasqFilter(self, filter_class, config_file_arg): + def _test_EnvFilter_as_DnsMasq(self, config_file_arg): usercmd = ['env', config_file_arg + '=A', 'NETWORK_ID=foobar', 'dnsmasq', 'foo'] - f = filter_class("/usr/bin/dnsmasq", "root") + f = filters.EnvFilter("env", "root", config_file_arg + '=A', + 'NETWORK_ID=', "/usr/bin/dnsmasq") self.assertTrue(f.match(usercmd)) self.assertEqual(f.get_command(usercmd), ['/usr/bin/dnsmasq', 'foo']) env = f.get_environment(usercmd) @@ -72,10 +73,68 @@ class RootwrapTestCase(utils.BaseTestCase): self.assertEqual(env.get('NETWORK_ID'), 'foobar') def test_DnsmasqFilter(self): - self._test_DnsmasqFilter(filters.DnsmasqFilter, 'CONFIG_FILE') + self._test_EnvFilter_as_DnsMasq('CONFIG_FILE') def test_DeprecatedDnsmasqFilter(self): - self._test_DnsmasqFilter(filters.DeprecatedDnsmasqFilter, 'FLAGFILE') + self._test_EnvFilter_as_DnsMasq('FLAGFILE') + + def test_EnvFilter(self): + envset = ['A=/some/thing', 'B=somethingelse'] + envcmd = ['env'] + envset + realcmd = ['sleep', '10'] + usercmd = envcmd + realcmd + + f = filters.EnvFilter("env", "root", "A=", "B=ignored", "sleep") + # accept with leading env + self.assertTrue(f.match(envcmd + ["sleep"])) + # accept without leading env + self.assertTrue(f.match(envset + ["sleep"])) + + # any other command does not match + self.assertFalse(f.match(envcmd + ["sleep2"])) + self.assertFalse(f.match(envset + ["sleep2"])) + + # accept any trailing arguments + self.assertTrue(f.match(usercmd)) + + # require given environment variables to match + self.assertFalse(f.match([envcmd, 'C=ELSE'])) + self.assertFalse(f.match(['env', 'C=xx'])) + self.assertFalse(f.match(['env', 'A=xx'])) + + # require env command to be given + # (otherwise CommandFilters should match + self.assertFalse(f.match(realcmd)) + # require command to match + self.assertFalse(f.match(envcmd)) + self.assertFalse(f.match(envcmd[1:])) + + # ensure that the env command is stripped when executing + self.assertEqual(f.exec_args(usercmd), realcmd) + env = f.get_environment(usercmd) + # check that environment variables are set + self.assertEqual(env.get('A'), '/some/thing') + self.assertEqual(env.get('B'), 'somethingelse') + self.assertFalse('sleep' in env.keys()) + + def test_EnvFilter_without_leading_env(self): + envset = ['A=/some/thing', 'B=somethingelse'] + envcmd = ['env'] + envset + realcmd = ['sleep', '10'] + + f = filters.EnvFilter("sleep", "root", "A=", "B=ignored") + + # accept without leading env + self.assertTrue(f.match(envset + ["sleep"])) + + self.assertEqual(f.get_command(envcmd + realcmd), realcmd) + self.assertEqual(f.get_command(envset + realcmd), realcmd) + + env = f.get_environment(envset + realcmd) + # check that environment variables are set + self.assertEqual(env.get('A'), '/some/thing') + self.assertEqual(env.get('B'), 'somethingelse') + self.assertFalse('sleep' in env.keys()) def test_KillFilter(self): if not os.path.exists("/proc/%d" % os.getpid()): @@ -119,8 +178,9 @@ class RootwrapTestCase(utils.BaseTestCase): # Filter shouldn't be able to find binary in $PATH, so fail with fixtures.EnvironmentVariable("PATH", "/foo:/bar"): self.assertFalse(f.match(usercmd)) - pass - + # ensure that unset $PATH is not causing an exception + with fixtures.EnvironmentVariable("PATH"): + self.assertFalse(f.match(usercmd)) finally: # Terminate the "cat" process and wait for it to finish p.terminate() @@ -169,6 +229,66 @@ class RootwrapTestCase(utils.BaseTestCase): self.assertEqual(f.get_command(usercmd), ['/bin/cat', goodfn]) self.assertTrue(f.match(usercmd)) + def test_IpFilter_non_netns(self): + f = filters.IpFilter('/sbin/ip', 'root') + self.assertTrue(f.match(['ip', 'link', 'list'])) + + def _test_IpFilter_netns_helper(self, action): + f = filters.IpFilter('/sbin/ip', 'root') + self.assertTrue(f.match(['ip', 'link', action])) + + def test_IpFilter_netns_add(self): + self._test_IpFilter_netns_helper('add') + + def test_IpFilter_netns_delete(self): + self._test_IpFilter_netns_helper('delete') + + def test_IpFilter_netns_list(self): + self._test_IpFilter_netns_helper('list') + + def test_IpNetnsExecFilter_match(self): + f = filters.IpNetnsExecFilter('/sbin/ip', 'root') + self.assertTrue( + f.match(['ip', 'netns', 'exec', 'foo', 'ip', 'link', 'list'])) + + def test_IpNetnsExecFilter_nomatch(self): + f = filters.IpNetnsExecFilter('/sbin/ip', 'root') + self.assertFalse(f.match(['ip', 'link', 'list'])) + + # verify that at least a NS is given + self.assertFalse(f.match(['ip', 'netns', 'exec'])) + + def test_IpNetnsExecFilter_nomatch_nonroot(self): + f = filters.IpNetnsExecFilter('/sbin/ip', 'user') + self.assertFalse( + f.match(['ip', 'netns', 'exec', 'foo', 'ip', 'link', 'list'])) + + def test_match_filter_recurses_exec_command_filter_matches(self): + filter_list = [filters.IpNetnsExecFilter('/sbin/ip', 'root'), + filters.IpFilter('/sbin/ip', 'root')] + args = ['ip', 'netns', 'exec', 'foo', 'ip', 'link', 'list'] + + self.assertIsNotNone(wrapper.match_filter(filter_list, args)) + + def test_match_filter_recurses_exec_command_matches_user(self): + filter_list = [filters.IpNetnsExecFilter('/sbin/ip', 'root'), + filters.IpFilter('/sbin/ip', 'user')] + args = ['ip', 'netns', 'exec', 'foo', 'ip', 'link', 'list'] + + # Currently ip netns exec requires root, so verify that + # no non-root filter is matched, as that would escalate privileges + self.assertRaises(wrapper.NoFilterMatched, + wrapper.match_filter, filter_list, args) + + def test_match_filter_recurses_exec_command_filter_does_not_match(self): + filter_list = [filters.IpNetnsExecFilter('/sbin/ip', 'root'), + filters.IpFilter('/sbin/ip', 'root')] + args = ['ip', 'netns', 'exec', 'foo', 'ip', 'netns', 'exec', 'bar', + 'ip', 'link', 'list'] + + self.assertRaises(wrapper.NoFilterMatched, + wrapper.match_filter, filter_list, args) + def test_exec_dirs_search(self): # This test supposes you have /bin/cat or /usr/bin/cat locally f = filters.CommandFilter("cat", "root") @@ -195,6 +315,11 @@ class RootwrapTestCase(utils.BaseTestCase): config = wrapper.RootwrapConfig(raw) self.assertEqual(config.filters_path, ['/a', '/b']) self.assertEqual(config.exec_dirs, os.environ["PATH"].split(':')) + + with fixtures.EnvironmentVariable("PATH"): + c = wrapper.RootwrapConfig(raw) + self.assertEqual(c.exec_dirs, []) + self.assertFalse(config.use_syslog) self.assertEqual(config.syslog_log_facility, logging.handlers.SysLogHandler.LOG_SYSLOG) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 7e07f28..20007de 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -17,18 +17,24 @@ # under the License. """ -Unit Tests for remote procedure calls using queue +Unit Tests for service class """ from __future__ import print_function +import errno +import eventlet +import mox import os import signal +import socket import time import traceback +from eventlet import event from oslo.config import cfg +from openstack.common import eventlet_backdoor from openstack.common import log as logging from openstack.common.notifier import api as notifier_api from openstack.common import service @@ -190,11 +196,86 @@ class ServiceLauncherTest(utils.BaseTestCase): self.assertEqual(os.WEXITSTATUS(status), 0) +class _Service(service.Service): + def __init__(self): + super(_Service, self).__init__() + self.init = event.Event() + self.cleaned_up = False + + def start(self): + self.init.send() + + def stop(self): + self.cleaned_up = True + super(_Service, self).stop() + + class LauncherTest(utils.BaseTestCase): + def test_backdoor_port(self): + self.config(backdoor_port='1234') + + sock = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(eventlet, 'listen') + self.mox.StubOutWithMock(eventlet, 'spawn_n') + + eventlet.listen(('localhost', 1234)).AndReturn(sock) + sock.getsockname().AndReturn(('127.0.0.1', 1234)) + eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, + locals=mox.IsA(dict)) + + self.mox.ReplayAll() + + svc = service.Service() + launcher = service.launch(svc) + self.assertEqual(svc.backdoor_port, 1234) + launcher.stop() + + def test_backdoor_inuse(self): + sock = eventlet.listen(('localhost', 0)) + port = sock.getsockname()[1] + self.config(backdoor_port=port) + svc = service.Service() + self.assertRaises(socket.error, + service.launch, svc) + sock.close() + + def test_backdoor_port_range_one_inuse(self): + self.config(backdoor_port='8800:8900') + + sock = self.mox.CreateMockAnything() + self.mox.StubOutWithMock(eventlet, 'listen') + self.mox.StubOutWithMock(eventlet, 'spawn_n') + + eventlet.listen(('localhost', 8800)).AndRaise( + socket.error(errno.EADDRINUSE, '')) + eventlet.listen(('localhost', 8801)).AndReturn(sock) + sock.getsockname().AndReturn(('127.0.0.1', 8801)) + eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, + locals=mox.IsA(dict)) + + self.mox.ReplayAll() + + svc = service.Service() + launcher = service.launch(svc) + self.assertEqual(svc.backdoor_port, 8801) + launcher.stop() + + def test_backdoor_port_reverse_range(self): # backdoor port should get passed to the service being launched - self.config(backdoor_port=1234) + self.config(backdoor_port='8888:7777') svc = service.Service() + self.assertRaises(eventlet_backdoor.EventletBackdoorConfigValueError, + service.launch, svc) + + def test_graceful_shutdown(self): + # test that services are given a chance to clean up: + svc = _Service() + launcher = service.launch(svc) - self.assertEqual(1234, svc.backdoor_port) + # wait on 'init' so we know the service had time to start: + svc.init.wait() + launcher.stop() + self.assertTrue(svc.cleaned_up) + self.assertTrue(svc._done.ready()) diff --git a/tests/utils.py b/tests/utils.py index 794a3d2..e93c278 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -27,17 +27,16 @@ import testtools from openstack.common import exception from openstack.common.fixture import moxstubout -CONF = cfg.CONF - class BaseTestCase(testtools.TestCase): - def setUp(self): + def setUp(self, conf=cfg.CONF): super(BaseTestCase, self).setUp() moxfixture = self.useFixture(moxstubout.MoxStubout()) self.mox = moxfixture.mox self.stubs = moxfixture.stubs - self.addCleanup(CONF.reset) + self.conf = conf + self.addCleanup(self.conf.reset) self.useFixture(fixtures.FakeLogger('openstack.common')) self.useFixture(fixtures.Timeout(30, True)) self.stubs.Set(exception, '_FATAL_EXCEPTION_FORMAT_ERRORS', True) @@ -46,7 +45,7 @@ class BaseTestCase(testtools.TestCase): def tearDown(self): super(BaseTestCase, self).tearDown() - CONF.reset() + self.conf.reset() self.stubs.UnsetAll() self.stubs.SmartUnsetAll() @@ -79,4 +78,4 @@ class BaseTestCase(testtools.TestCase): """ group = kw.pop('group', None) for k, v in kw.iteritems(): - CONF.set_override(k, v, group) + self.conf.set_override(k, v, group) diff --git a/tools/patch_tox_venv.py b/tools/patch_tox_venv.py index a3340f2..dc9ce83 100644 --- a/tools/patch_tox_venv.py +++ b/tools/patch_tox_venv.py @@ -17,7 +17,7 @@ import os import sys -import install_venv_common as install_venv +import install_venv_common as install_venv # noqa def first_file(file_list): @@ -2,26 +2,22 @@ envlist = py26,py27,py33,pep8,pylint [testenv] +sitepackages = False setenv = VIRTUAL_ENV={envdir} - NOSE_WITH_OPENSTACK=1 - NOSE_OPENSTACK_COLOR=1 - NOSE_OPENSTACK_RED=0.05 - NOSE_OPENSTACK_YELLOW=0.025 - NOSE_OPENSTACK_SHOW_ELAPSED=1 - NOSE_OPENSTACK_STDOUT=1 deps = -r{toxinidir}/requirements.txt -r{toxinidir}/test-requirements.txt commands = python tools/patch_tox_venv.py - nosetests --with-doctest --exclude-dir=tests/testmods {posargs} + # due to dependencies between tests (bug 1192207) we use `--concurrency=1` option + python setup.py testr --slowest --testr-args='--concurrency=1 {posargs}' [flake8] show-source = True -ignore = H202,H302,H304 +ignore = H302 exclude = .venv,.tox,dist,doc,*.egg,.update-venv [testenv:pep8] -commands = flake8 +commands = flake8 {posargs} [testenv:pylint] deps = pylint>=0.26.0 @@ -30,7 +26,9 @@ commands = python ./tools/lint.py ./openstack [testenv:cover] setenv = VIRTUAL_ENV={envdir} - NOSE_WITH_COVERAGE=1 +commands = + python tools/patch_tox_venv.py + python setup.py testr --coverage --testr-args='{posargs}' [testenv:venv] commands = {posargs} |