summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.mailmap1
-rw-r--r--MAINTAINERS18
-rw-r--r--openstack/common/apiclient/auth.py227
-rw-r--r--openstack/common/apiclient/base.py492
-rw-r--r--openstack/common/apiclient/client.py360
-rw-r--r--openstack/common/apiclient/exceptions.py42
-rw-r--r--openstack/common/apiclient/fake_client.py172
-rw-r--r--openstack/common/config/generator.py32
-rw-r--r--openstack/common/db/exception.py6
-rw-r--r--openstack/common/db/sqlalchemy/migration.py119
-rw-r--r--openstack/common/db/sqlalchemy/session.py16
-rw-r--r--[-rwxr-xr-x]openstack/common/db/sqlalchemy/utils.py61
-rw-r--r--openstack/common/deprecated/wsgi.py31
-rw-r--r--openstack/common/exception.py139
-rw-r--r--openstack/common/excutils.py7
-rw-r--r--openstack/common/fixture/config.py45
-rw-r--r--openstack/common/gettextutils.py136
-rw-r--r--openstack/common/local.py13
-rw-r--r--openstack/common/log.py7
-rw-r--r--openstack/common/middleware/base.py6
-rw-r--r--openstack/common/middleware/sizelimit.py3
-rw-r--r--openstack/common/notifier/log_notifier.py2
-rw-r--r--openstack/common/notifier/rpc_notifier.py2
-rw-r--r--openstack/common/notifier/rpc_notifier2.py2
-rw-r--r--openstack/common/policy.py24
-rw-r--r--openstack/common/processutils.py11
-rw-r--r--openstack/common/py3kcompat/__init__.py17
-rw-r--r--openstack/common/py3kcompat/urlutils.py49
-rw-r--r--openstack/common/quota.py1175
-rw-r--r--[-rwxr-xr-x]openstack/common/rootwrap/cmd.py1
-rw-r--r--openstack/common/rpc/__init__.py3
-rw-r--r--openstack/common/rpc/amqp.py9
-rw-r--r--openstack/common/rpc/impl_kombu.py8
-rw-r--r--openstack/common/rpc/impl_qpid.py2
-rw-r--r--openstack/common/rpc/matchmaker.py12
-rw-r--r--openstack/common/rpc/matchmaker_ring.py4
-rw-r--r--openstack/common/rpc/securemessage.py521
-rw-r--r--[-rwxr-xr-x]openstack/common/rpc/zmq_receiver.py1
-rw-r--r--openstack/common/service.py138
-rw-r--r--openstack/common/test.py52
-rw-r--r--openstack/common/timeutils.py4
-rw-r--r--pypi/setup.py2
-rw-r--r--requirements.txt2
-rwxr-xr-xrun_tests.sh52
-rw-r--r--test-requirements.txt2
-rw-r--r--tests/unit/apiclient/test_auth.py181
-rw-r--r--tests/unit/apiclient/test_base.py239
-rw-r--r--tests/unit/apiclient/test_client.py137
-rw-r--r--tests/unit/apiclient/test_exceptions.py7
-rw-r--r--tests/unit/crypto/test_utils.py4
-rw-r--r--tests/unit/db/sqlalchemy/test_migrate.py8
-rw-r--r--tests/unit/db/sqlalchemy/test_migration_common.py154
-rw-r--r--tests/unit/db/sqlalchemy/test_migrations.py14
-rw-r--r--tests/unit/db/sqlalchemy/test_models.py6
-rw-r--r--tests/unit/db/sqlalchemy/test_sqlalchemy.py51
-rw-r--r--tests/unit/db/sqlalchemy/test_utils.py153
-rw-r--r--tests/unit/db/test_api.py8
-rw-r--r--tests/unit/deprecated/test_wsgi.py17
-rw-r--r--tests/unit/fixture/__init__.py0
-rw-r--r--tests/unit/fixture/test_config.py45
-rw-r--r--tests/unit/middleware/test_context.py6
-rw-r--r--tests/unit/middleware/test_correlation_id.py5
-rw-r--r--tests/unit/middleware/test_sizelimit.py4
-rw-r--r--tests/unit/rpc/amqp.py11
-rw-r--r--tests/unit/rpc/common.py10
-rw-r--r--tests/unit/rpc/test_common.py62
-rw-r--r--tests/unit/rpc/test_kombu.py9
-rw-r--r--tests/unit/rpc/test_qpid.py6
-rw-r--r--tests/unit/rpc/test_securemessage.py134
-rw-r--r--tests/unit/rpc/test_zmq.py4
-rw-r--r--tests/unit/scheduler/test_weights.py4
-rw-r--r--tests/unit/test_authutils.py4
-rw-r--r--tests/unit/test_cfgfilter.py46
-rw-r--r--tests/unit/test_cliutils.py4
-rw-r--r--tests/unit/test_compat.py53
-rw-r--r--tests/unit/test_context.py4
-rw-r--r--tests/unit/test_exception.py99
-rw-r--r--tests/unit/test_fileutils.py14
-rw-r--r--tests/unit/test_funcutils.py5
-rw-r--r--tests/unit/test_gettext.py138
-rw-r--r--tests/unit/test_importutils.py4
-rw-r--r--tests/unit/test_jsonutils.py52
-rw-r--r--tests/unit/test_local.py24
-rw-r--r--tests/unit/test_lockutils.py14
-rw-r--r--tests/unit/test_log.py32
-rw-r--r--tests/unit/test_loopingcall.py4
-rw-r--r--tests/unit/test_memorycache.py4
-rw-r--r--tests/unit/test_network_utils.py4
-rw-r--r--tests/unit/test_pastedeploy.py14
-rw-r--r--tests/unit/test_periodic.py8
-rw-r--r--tests/unit/test_policy.py104
-rw-r--r--tests/unit/test_processutils.py29
-rw-r--r--tests/unit/test_quota.py441
-rw-r--r--tests/unit/test_service.py151
-rw-r--r--tests/unit/test_sslutils.py35
-rw-r--r--tests/unit/test_strutils.py8
-rw-r--r--tests/unit/test_threadgroup.py4
-rw-r--r--tests/unit/test_timeutils.py36
-rw-r--r--tests/unit/test_uuidutils.py4
-rw-r--r--tests/unit/test_xmlutils.py6
-rw-r--r--tests/utils.py21
-rwxr-xr-xtools/colorizer.py333
-rwxr-xr-xtools/config/generate_sample.sh3
-rw-r--r--tools/install_venv.py74
-rw-r--r--tools/install_venv_common.py22
-rwxr-xr-xtools/run_tests_common.sh253
-rwxr-xr-xtools/with_venv.sh7
107 files changed, 6564 insertions, 801 deletions
diff --git a/.mailmap b/.mailmap
index 18221d4..5559cc7 100644
--- a/.mailmap
+++ b/.mailmap
@@ -2,3 +2,4 @@
# <preferred e-mail> <other e-mail 1>
# <preferred e-mail> <other e-mail 2>
Zhongyue Luo <zhongyue.nah@intel.com> <lzyeval@gmail.com>
+<yufang521247@gmail.com> <zhangyufang@360.cn>
diff --git a/MAINTAINERS b/MAINTAINERS
index e069373..a611a5c 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -78,6 +78,12 @@ M:
S: Orphan
F: cliutils.py
+== compat ==
+
+M: Chuck Short <chuck.short@canonical.com>
+S: Maintained
+F: py3kcompat
+
== context ==
M:
@@ -109,12 +115,6 @@ M:
S: Orphan
F: eventlet_backdoor.py
-== exception ==
-
-M:
-S: Obsolete
-F: exception.py
-
== excutils ==
M:
@@ -223,6 +223,12 @@ M: Michael Still <mikal@stillhq.com>
S: Maintained
F: processutils.py
+== quota ==
+
+M: Sergey Skripnick <sskripnick@mirantis.com>
+S: Maintained
+F: quota.py
+
== redhat-eventlet.patch ==
M: Mark McLoughlin <markmc@redhat.com>
diff --git a/openstack/common/apiclient/auth.py b/openstack/common/apiclient/auth.py
new file mode 100644
index 0000000..1744228
--- /dev/null
+++ b/openstack/common/apiclient/auth.py
@@ -0,0 +1,227 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+# Copyright 2013 Spanish National Research Council.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+# E0202: An attribute inherited from %s hide this method
+# pylint: disable=E0202
+
+import abc
+import argparse
+import logging
+import os
+
+from stevedore import extension
+
+from openstack.common.apiclient import exceptions
+
+
+logger = logging.getLogger(__name__)
+
+
+_discovered_plugins = {}
+
+
+def discover_auth_systems():
+ """Discover the available auth-systems.
+
+ This won't take into account the old style auth-systems.
+ """
+ global _discovered_plugins
+ _discovered_plugins = {}
+
+ def add_plugin(ext):
+ _discovered_plugins[ext.name] = ext.plugin
+
+ ep_namespace = "openstack.common.apiclient.auth"
+ mgr = extension.ExtensionManager(ep_namespace)
+ mgr.map(add_plugin)
+
+
+def load_auth_system_opts(parser):
+ """Load options needed by the available auth-systems into a parser.
+
+ This function will try to populate the parser with options from the
+ available plugins.
+ """
+ group = parser.add_argument_group("Common auth options")
+ BaseAuthPlugin.add_common_opts(group)
+ for name, auth_plugin in _discovered_plugins.iteritems():
+ group = parser.add_argument_group(
+ "Auth-system '%s' options" % name,
+ conflict_handler="resolve")
+ auth_plugin.add_opts(group)
+
+
+def load_plugin(auth_system):
+ try:
+ plugin_class = _discovered_plugins[auth_system]
+ except KeyError:
+ raise exceptions.AuthSystemNotFound(auth_system)
+ return plugin_class(auth_system=auth_system)
+
+
+def load_plugin_from_args(args):
+    """Load required plugin and populate it with options.
+
+ Try to guess auth system if it is not specified. Systems are tried in
+ alphabetical order.
+
+ :type args: argparse.Namespace
+ :raises: AuthorizationFailure
+ """
+ auth_system = args.os_auth_system
+ if auth_system:
+ plugin = load_plugin(auth_system)
+ plugin.parse_opts(args)
+ plugin.sufficient_options()
+ return plugin
+
+ for plugin_auth_system in sorted(_discovered_plugins.iterkeys()):
+ plugin_class = _discovered_plugins[plugin_auth_system]
+ plugin = plugin_class()
+ plugin.parse_opts(args)
+ try:
+ plugin.sufficient_options()
+ except exceptions.AuthPluginOptionsMissing:
+ continue
+ return plugin
+ raise exceptions.AuthPluginOptionsMissing(["auth_system"])
+
+
+class BaseAuthPlugin(object):
+ """Base class for authentication plugins.
+
+ An authentication plugin needs to override at least the authenticate
+ method to be a valid plugin.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ auth_system = None
+ opt_names = []
+ common_opt_names = [
+ "auth_system",
+ "username",
+ "password",
+ "tenant_name",
+ "token",
+ "auth_url",
+ ]
+
+ def __init__(self, auth_system=None, **kwargs):
+ self.auth_system = auth_system or self.auth_system
+ self.opts = dict((name, kwargs.get(name))
+ for name in self.opt_names)
+
+ @staticmethod
+ def _parser_add_opt(parser, opt):
+ """Add an option to parser in two variants.
+
+ :param opt: option name (with underscores)
+ """
+ dashed_opt = opt.replace("_", "-")
+ env_var = "OS_%s" % opt.upper()
+ arg_default = os.environ.get(env_var, "")
+ arg_help = "Defaults to env[%s]." % env_var
+ parser.add_argument(
+ "--os-%s" % dashed_opt,
+ metavar="<%s>" % dashed_opt,
+ default=arg_default,
+ help=arg_help)
+ parser.add_argument(
+ "--os_%s" % opt,
+ metavar="<%s>" % dashed_opt,
+ help=argparse.SUPPRESS)
+
+ @classmethod
+ def add_opts(cls, parser):
+ """Populate the parser with the options for this plugin.
+ """
+ for opt in cls.opt_names:
+ # use `BaseAuthPlugin.common_opt_names` since it is never
+ # changed in child classes
+ if opt not in BaseAuthPlugin.common_opt_names:
+ cls._parser_add_opt(parser, opt)
+
+ @classmethod
+ def add_common_opts(cls, parser):
+ """Add options that are common for several plugins.
+ """
+ for opt in cls.common_opt_names:
+ cls._parser_add_opt(parser, opt)
+
+ @staticmethod
+ def get_opt(opt_name, args):
+ """Return option name and value.
+
+ :param opt_name: name of the option, e.g., "username"
+ :param args: parsed arguments
+ """
+ return (opt_name, getattr(args, "os_%s" % opt_name, None))
+
+ def parse_opts(self, args):
+ """Parse the actual auth-system options if any.
+
+ This method is expected to populate the attribute `self.opts` with a
+ dict containing the options and values needed to make authentication.
+ """
+ self.opts.update(dict(self.get_opt(opt_name, args)
+ for opt_name in self.opt_names))
+
+ def authenticate(self, http_client):
+ """Authenticate using plugin defined method.
+
+ The method usually analyses `self.opts` and performs
+ a request to authentication server.
+
+ :param http_client: client object that needs authentication
+ :type http_client: HTTPClient
+ :raises: AuthorizationFailure
+ """
+ self.sufficient_options()
+ self._do_authenticate(http_client)
+
+ @abc.abstractmethod
+ def _do_authenticate(self, http_client):
+ """Protected method for authentication.
+ """
+
+ def sufficient_options(self):
+ """Check if all required options are present.
+
+ :raises: AuthPluginOptionsMissing
+ """
+ missing = [opt
+ for opt in self.opt_names
+ if not self.opts.get(opt)]
+ if missing:
+ raise exceptions.AuthPluginOptionsMissing(missing)
+
+ @abc.abstractmethod
+ def token_and_endpoint(self, endpoint_type, service_type):
+ """Return token and endpoint.
+
+ :param service_type: Service type of the endpoint
+ :type service_type: string
+ :param endpoint_type: Type of endpoint.
+ Possible values: public or publicURL,
+ internal or internalURL,
+ admin or adminURL
+ :type endpoint_type: string
+ :returns: tuple of token and endpoint strings
+ :raises: EndpointException
+ """
diff --git a/openstack/common/apiclient/base.py b/openstack/common/apiclient/base.py
new file mode 100644
index 0000000..0ecf1f5
--- /dev/null
+++ b/openstack/common/apiclient/base.py
@@ -0,0 +1,492 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 Jacob Kaplan-Moss
+# Copyright 2011 OpenStack LLC
+# Copyright 2012 Grid Dynamics
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Base utilities to build API operation managers and objects on top of.
+"""
+
+# E1102: %s is not callable
+# pylint: disable=E1102
+
+import abc
+import urllib
+
+from openstack.common.apiclient import exceptions
+from openstack.common import strutils
+
+
+def getid(obj):
+ """Return id if argument is a Resource.
+
+ Abstracts the common pattern of allowing both an object or an object's ID
+ (UUID) as a parameter when dealing with relationships.
+ """
+ try:
+ if obj.uuid:
+ return obj.uuid
+ except AttributeError:
+ pass
+ try:
+ return obj.id
+ except AttributeError:
+ return obj
+
+
+# TODO(aababilov): call run_hooks() in HookableMixin's child classes
+class HookableMixin(object):
+ """Mixin so classes can register and run hooks."""
+ _hooks_map = {}
+
+ @classmethod
+ def add_hook(cls, hook_type, hook_func):
+ """Add a new hook of specified type.
+
+ :param cls: class that registers hooks
+ :param hook_type: hook type, e.g., '__pre_parse_args__'
+ :param hook_func: hook function
+ """
+ if hook_type not in cls._hooks_map:
+ cls._hooks_map[hook_type] = []
+
+ cls._hooks_map[hook_type].append(hook_func)
+
+ @classmethod
+ def run_hooks(cls, hook_type, *args, **kwargs):
+ """Run all hooks of specified type.
+
+ :param cls: class that registers hooks
+ :param hook_type: hook type, e.g., '__pre_parse_args__'
+ :param **args: args to be passed to every hook function
+ :param **kwargs: kwargs to be passed to every hook function
+ """
+ hook_funcs = cls._hooks_map.get(hook_type) or []
+ for hook_func in hook_funcs:
+ hook_func(*args, **kwargs)
+
+
+class BaseManager(HookableMixin):
+ """Basic manager type providing common operations.
+
+ Managers interact with a particular type of API (servers, flavors, images,
+ etc.) and provide CRUD operations for them.
+ """
+ resource_class = None
+
+ def __init__(self, client):
+ """Initializes BaseManager with `client`.
+
+ :param client: instance of BaseClient descendant for HTTP requests
+ """
+ super(BaseManager, self).__init__()
+ self.client = client
+
+ def _list(self, url, response_key, obj_class=None, json=None):
+ """List the collection.
+
+ :param url: a partial URL, e.g., '/servers'
+ :param response_key: the key to be looked up in response dictionary,
+ e.g., 'servers'
+ :param obj_class: class for constructing the returned objects
+ (self.resource_class will be used by default)
+ :param json: data that will be encoded as JSON and passed in POST
+ request (GET will be sent by default)
+ """
+ if json:
+ body = self.client.post(url, json=json).json()
+ else:
+ body = self.client.get(url).json()
+
+ if obj_class is None:
+ obj_class = self.resource_class
+
+ data = body[response_key]
+ # NOTE(ja): keystone returns values as list as {'values': [ ... ]}
+ # unlike other services which just return the list...
+ try:
+ data = data['values']
+ except (KeyError, TypeError):
+ pass
+
+ return [obj_class(self, res, loaded=True) for res in data if res]
+
+ def _get(self, url, response_key):
+ """Get an object from collection.
+
+ :param url: a partial URL, e.g., '/servers'
+ :param response_key: the key to be looked up in response dictionary,
+ e.g., 'server'
+ """
+ body = self.client.get(url).json()
+ return self.resource_class(self, body[response_key], loaded=True)
+
+ def _head(self, url):
+ """Retrieve request headers for an object.
+
+ :param url: a partial URL, e.g., '/servers'
+ """
+ resp = self.client.head(url)
+ return resp.status_code == 204
+
+ def _post(self, url, json, response_key, return_raw=False):
+ """Create an object.
+
+ :param url: a partial URL, e.g., '/servers'
+ :param json: data that will be encoded as JSON and passed in POST
+ request (GET will be sent by default)
+ :param response_key: the key to be looked up in response dictionary,
+ e.g., 'servers'
+ :param return_raw: flag to force returning raw JSON instead of
+ Python object of self.resource_class
+ """
+ body = self.client.post(url, json=json).json()
+ if return_raw:
+ return body[response_key]
+ return self.resource_class(self, body[response_key])
+
+ def _put(self, url, json=None, response_key=None):
+ """Update an object with PUT method.
+
+ :param url: a partial URL, e.g., '/servers'
+ :param json: data that will be encoded as JSON and passed in POST
+ request (GET will be sent by default)
+ :param response_key: the key to be looked up in response dictionary,
+ e.g., 'servers'
+ """
+ resp = self.client.put(url, json=json)
+ # PUT requests may not return a body
+ if resp.content:
+ body = resp.json()
+ if response_key is not None:
+ return self.resource_class(self, body[response_key])
+ else:
+ return self.resource_class(self, body)
+
+ def _patch(self, url, json=None, response_key=None):
+ """Update an object with PATCH method.
+
+ :param url: a partial URL, e.g., '/servers'
+ :param json: data that will be encoded as JSON and passed in POST
+ request (GET will be sent by default)
+ :param response_key: the key to be looked up in response dictionary,
+ e.g., 'servers'
+ """
+ body = self.client.patch(url, json=json).json()
+ if response_key is not None:
+ return self.resource_class(self, body[response_key])
+ else:
+ return self.resource_class(self, body)
+
+ def _delete(self, url):
+ """Delete an object.
+
+ :param url: a partial URL, e.g., '/servers/my-server'
+ """
+ return self.client.delete(url)
+
+
+class ManagerWithFind(BaseManager):
+ """Manager with additional `find()`/`findall()` methods."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def list(self):
+ pass
+
+ def find(self, **kwargs):
+ """Find a single item with attributes matching ``**kwargs``.
+
+ This isn't very efficient: it loads the entire list then filters on
+ the Python side.
+ """
+ matches = self.findall(**kwargs)
+ num_matches = len(matches)
+ if num_matches == 0:
+ msg = "No %s matching %s." % (self.resource_class.__name__, kwargs)
+ raise exceptions.NotFound(msg)
+ elif num_matches > 1:
+ raise exceptions.NoUniqueMatch()
+ else:
+ return matches[0]
+
+ def findall(self, **kwargs):
+ """Find all items with attributes matching ``**kwargs``.
+
+ This isn't very efficient: it loads the entire list then filters on
+ the Python side.
+ """
+ found = []
+ searches = kwargs.items()
+
+ for obj in self.list():
+ try:
+ if all(getattr(obj, attr) == value
+ for (attr, value) in searches):
+ found.append(obj)
+ except AttributeError:
+ continue
+
+ return found
+
+
+class CrudManager(BaseManager):
+ """Base manager class for manipulating entities.
+
+ Children of this class are expected to define a `collection_key` and `key`.
+
+ - `collection_key`: Usually a plural noun by convention (e.g. `entities`);
+ used to refer collections in both URL's (e.g. `/v3/entities`) and JSON
+ objects containing a list of member resources (e.g. `{'entities': [{},
+ {}, {}]}`).
+ - `key`: Usually a singular noun by convention (e.g. `entity`); used to
+ refer to an individual member of the collection.
+
+ """
+ collection_key = None
+ key = None
+
+ def build_url(self, base_url=None, **kwargs):
+ """Builds a resource URL for the given kwargs.
+
+ Given an example collection where `collection_key = 'entities'` and
+ `key = 'entity'`, the following URL's could be generated.
+
+ By default, the URL will represent a collection of entities, e.g.::
+
+ /entities
+
+ If kwargs contains an `entity_id`, then the URL will represent a
+ specific member, e.g.::
+
+ /entities/{entity_id}
+
+ :param base_url: if provided, the generated URL will be appended to it
+ """
+ url = base_url if base_url is not None else ''
+
+ url += '/%s' % self.collection_key
+
+ # do we have a specific entity?
+ entity_id = kwargs.get('%s_id' % self.key)
+ if entity_id is not None:
+ url += '/%s' % entity_id
+
+ return url
+
+ def _filter_kwargs(self, kwargs):
+ """Drop null values and handle ids."""
+ for key, ref in kwargs.copy().iteritems():
+ if ref is None:
+ kwargs.pop(key)
+ else:
+ if isinstance(ref, Resource):
+ kwargs.pop(key)
+ kwargs['%s_id' % key] = getid(ref)
+ return kwargs
+
+ def create(self, **kwargs):
+ kwargs = self._filter_kwargs(kwargs)
+ return self._post(
+ self.build_url(**kwargs),
+ {self.key: kwargs},
+ self.key)
+
+ def get(self, **kwargs):
+ kwargs = self._filter_kwargs(kwargs)
+ return self._get(
+ self.build_url(**kwargs),
+ self.key)
+
+ def head(self, **kwargs):
+ kwargs = self._filter_kwargs(kwargs)
+ return self._head(self.build_url(**kwargs))
+
+ def list(self, base_url=None, **kwargs):
+ """List the collection.
+
+ :param base_url: if provided, the generated URL will be appended to it
+ """
+ kwargs = self._filter_kwargs(kwargs)
+
+ return self._list(
+ '%(base_url)s%(query)s' % {
+ 'base_url': self.build_url(base_url=base_url, **kwargs),
+ 'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '',
+ },
+ self.collection_key)
+
+ def put(self, base_url=None, **kwargs):
+ """Update an element.
+
+ :param base_url: if provided, the generated URL will be appended to it
+ """
+ kwargs = self._filter_kwargs(kwargs)
+
+ return self._put(self.build_url(base_url=base_url, **kwargs))
+
+ def update(self, **kwargs):
+ kwargs = self._filter_kwargs(kwargs)
+ params = kwargs.copy()
+ params.pop('%s_id' % self.key)
+
+ return self._patch(
+ self.build_url(**kwargs),
+ {self.key: params},
+ self.key)
+
+ def delete(self, **kwargs):
+ kwargs = self._filter_kwargs(kwargs)
+
+ return self._delete(
+ self.build_url(**kwargs))
+
+ def find(self, base_url=None, **kwargs):
+ """Find a single item with attributes matching ``**kwargs``.
+
+ :param base_url: if provided, the generated URL will be appended to it
+ """
+ kwargs = self._filter_kwargs(kwargs)
+
+ rl = self._list(
+ '%(base_url)s%(query)s' % {
+ 'base_url': self.build_url(base_url=base_url, **kwargs),
+ 'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '',
+ },
+ self.collection_key)
+ num = len(rl)
+
+ if num == 0:
+ msg = "No %s matching %s." % (self.resource_class.__name__, kwargs)
+ raise exceptions.NotFound(404, msg)
+ elif num > 1:
+ raise exceptions.NoUniqueMatch
+ else:
+ return rl[0]
+
+
+class Extension(HookableMixin):
+ """Extension descriptor."""
+
+ SUPPORTED_HOOKS = ('__pre_parse_args__', '__post_parse_args__')
+ manager_class = None
+
+ def __init__(self, name, module):
+ super(Extension, self).__init__()
+ self.name = name
+ self.module = module
+ self._parse_extension_module()
+
+ def _parse_extension_module(self):
+ self.manager_class = None
+ for attr_name, attr_value in self.module.__dict__.items():
+ if attr_name in self.SUPPORTED_HOOKS:
+ self.add_hook(attr_name, attr_value)
+ else:
+ try:
+ if issubclass(attr_value, BaseManager):
+ self.manager_class = attr_value
+ except TypeError:
+ pass
+
+ def __repr__(self):
+ return "<Extension '%s'>" % self.name
+
+
+class Resource(object):
+ """Base class for OpenStack resources (tenant, user, etc.).
+
+ This is pretty much just a bag for attributes.
+ """
+
+ HUMAN_ID = False
+ NAME_ATTR = 'name'
+
+ def __init__(self, manager, info, loaded=False):
+ """Populate and bind to a manager.
+
+ :param manager: BaseManager object
+ :param info: dictionary representing resource attributes
+ :param loaded: prevent lazy-loading if set to True
+ """
+ self.manager = manager
+ self._info = info
+ self._add_details(info)
+ self._loaded = loaded
+
+ def __repr__(self):
+ reprkeys = sorted(k
+ for k in self.__dict__.keys()
+ if k[0] != '_' and k != 'manager')
+ info = ", ".join("%s=%s" % (k, getattr(self, k)) for k in reprkeys)
+ return "<%s %s>" % (self.__class__.__name__, info)
+
+ @property
+ def human_id(self):
+ """Human-readable ID which can be used for bash completion.
+ """
+ if self.NAME_ATTR in self.__dict__ and self.HUMAN_ID:
+ return strutils.to_slug(getattr(self, self.NAME_ATTR))
+ return None
+
+ def _add_details(self, info):
+ for (k, v) in info.iteritems():
+ try:
+ setattr(self, k, v)
+ self._info[k] = v
+ except AttributeError:
+ # In this case we already defined the attribute on the class
+ pass
+
+ def __getattr__(self, k):
+ if k not in self.__dict__:
+ #NOTE(bcwaldon): disallow lazy-loading if already loaded once
+ if not self.is_loaded():
+ self.get()
+ return self.__getattr__(k)
+
+ raise AttributeError(k)
+ else:
+ return self.__dict__[k]
+
+ def get(self):
+ # set_loaded() first ... so if we have to bail, we know we tried.
+ self.set_loaded(True)
+ if not hasattr(self.manager, 'get'):
+ return
+
+ new = self.manager.get(self.id)
+ if new:
+ self._add_details(new._info)
+
+ def __eq__(self, other):
+ if not isinstance(other, Resource):
+ return NotImplemented
+ # two resources of different types are not equal
+ if not isinstance(other, self.__class__):
+ return False
+ if hasattr(self, 'id') and hasattr(other, 'id'):
+ return self.id == other.id
+ return self._info == other._info
+
+ def is_loaded(self):
+ return self._loaded
+
+ def set_loaded(self, val):
+ self._loaded = val
diff --git a/openstack/common/apiclient/client.py b/openstack/common/apiclient/client.py
new file mode 100644
index 0000000..cb3e1d7
--- /dev/null
+++ b/openstack/common/apiclient/client.py
@@ -0,0 +1,360 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 Jacob Kaplan-Moss
+# Copyright 2011 OpenStack LLC
+# Copyright 2011 Piston Cloud Computing, Inc.
+# Copyright 2013 Alessio Ababilov
+# Copyright 2013 Grid Dynamics
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+OpenStack Client interface. Handles the REST calls and responses.
+"""
+
+# E0202: An attribute inherited from %s hide this method
+# pylint: disable=E0202
+
+import logging
+import time
+
+try:
+ import simplejson as json
+except ImportError:
+ import json
+
+import requests
+
+from openstack.common.apiclient import exceptions
+from openstack.common import importutils
+
+
+_logger = logging.getLogger(__name__)
+
+
+class HTTPClient(object):
+ """This client handles sending HTTP requests to OpenStack servers.
+
+ Features:
+ - share authentication information between several clients to different
+ services (e.g., for compute and image clients);
+ - reissue authentication request for expired tokens;
+ - encode/decode JSON bodies;
+    - raise exceptions on HTTP errors;
+ - pluggable authentication;
+ - store authentication information in a keyring;
+ - store time spent for requests;
+ - register clients for particular services, so one can use
+ `http_client.identity` or `http_client.compute`;
+ - log requests and responses in a format that is easy to copy-and-paste
+ into terminal and send the same request with curl.
+ """
+
+ user_agent = "openstack.common.apiclient"
+
+ def __init__(self,
+ auth_plugin,
+ region_name=None,
+ endpoint_type="publicURL",
+ original_ip=None,
+ verify=True,
+ cert=None,
+ timeout=None,
+ timings=False,
+ keyring_saver=None,
+ debug=False,
+ user_agent=None,
+ http=None):
+ self.auth_plugin = auth_plugin
+
+ self.endpoint_type = endpoint_type
+ self.region_name = region_name
+
+ self.original_ip = original_ip
+ self.timeout = timeout
+ self.verify = verify
+ self.cert = cert
+
+ self.keyring_saver = keyring_saver
+ self.debug = debug
+ self.user_agent = user_agent or self.user_agent
+
+ self.times = [] # [("item", starttime, endtime), ...]
+ self.timings = timings
+
+ # requests within the same session can reuse TCP connections from pool
+ self.http = http or requests.Session()
+
+ self.cached_token = None
+
+ def _http_log_req(self, method, url, kwargs):
+ if not self.debug:
+ return
+
+ string_parts = [
+ "curl -i",
+ "-X '%s'" % method,
+ "'%s'" % url,
+ ]
+
+ for element in kwargs['headers']:
+ header = "-H '%s: %s'" % (element, kwargs['headers'][element])
+ string_parts.append(header)
+
+ _logger.debug("REQ: %s" % " ".join(string_parts))
+ if 'data' in kwargs:
+ _logger.debug("REQ BODY: %s\n" % (kwargs['data']))
+
+ def _http_log_resp(self, resp):
+ if not self.debug:
+ return
+ _logger.debug(
+ "RESP: [%s] %s\n",
+ resp.status_code,
+ resp.headers)
+ if resp._content_consumed:
+ _logger.debug(
+ "RESP BODY: %s\n",
+ resp.text)
+
+ def serialize(self, kwargs):
+ if kwargs.get('json') is not None:
+ kwargs['headers']['Content-Type'] = 'application/json'
+ kwargs['data'] = json.dumps(kwargs['json'])
+ try:
+ del kwargs['json']
+ except KeyError:
+ pass
+
+ def get_timings(self):
+ return self.times
+
+ def reset_timings(self):
+ self.times = []
+
+ def request(self, method, url, **kwargs):
+ """Send an http request with the specified characteristics.
+
+ Wrapper around `requests.Session.request` to handle tasks such as
+ setting headers, JSON encoding/decoding, and error handling.
+
+ :param method: method of HTTP request
+ :param url: URL of HTTP request
+ :param kwargs: any other parameter that can be passed to
+            requests.Session.request (such as `headers`) or `json`
+ that will be encoded as JSON and used as `data` argument
+ """
+ kwargs.setdefault("headers", kwargs.get("headers", {}))
+ kwargs["headers"]["User-Agent"] = self.user_agent
+ if self.original_ip:
+ kwargs["headers"]["Forwarded"] = "for=%s;by=%s" % (
+ self.original_ip, self.user_agent)
+ if self.timeout is not None:
+ kwargs.setdefault("timeout", self.timeout)
+ kwargs.setdefault("verify", self.verify)
+ if self.cert is not None:
+ kwargs.setdefault("cert", self.cert)
+ self.serialize(kwargs)
+
+ self._http_log_req(method, url, kwargs)
+ if self.timings:
+ start_time = time.time()
+ resp = self.http.request(method, url, **kwargs)
+ if self.timings:
+ self.times.append(("%s %s" % (method, url),
+ start_time, time.time()))
+ self._http_log_resp(resp)
+
+ if resp.status_code >= 400:
+ _logger.debug(
+ "Request returned failure status: %s",
+ resp.status_code)
+ raise exceptions.from_response(resp, method, url)
+
+ return resp
+
+ @staticmethod
+ def concat_url(endpoint, url):
+ """Concatenate endpoint and final URL.
+
+ E.g., "http://keystone/v2.0/" and "/tokens" are concatenated to
+ "http://keystone/v2.0/tokens".
+
+ :param endpoint: the base URL
+ :param url: the final URL
+ """
+ return "%s/%s" % (endpoint.rstrip("/"), url.strip("/"))
+
+ def client_request(self, client, method, url, **kwargs):
+ """Send an http request using `client`'s endpoint and specified `url`.
+
+ If request was rejected as unauthorized (possibly because the token is
+ expired), issue one authorization attempt and send the request once
+ again.
+
+ :param client: instance of BaseClient descendant
+ :param method: method of HTTP request
+ :param url: URL of HTTP request
+ :param kwargs: any other parameter that can be passed to
+ `HTTPClient.request`
+ """
+
+ filter_args = {
+ "endpoint_type": client.endpoint_type or self.endpoint_type,
+ "service_type": client.service_type,
+ }
+ token, endpoint = (self.cached_token, client.cached_endpoint)
+ just_authenticated = False
+ if not (token and endpoint):
+ try:
+ token, endpoint = self.auth_plugin.token_and_endpoint(
+ **filter_args)
+ except exceptions.EndpointException:
+ pass
+ if not (token and endpoint):
+ self.authenticate()
+ just_authenticated = True
+ token, endpoint = self.auth_plugin.token_and_endpoint(
+ **filter_args)
+ if not (token and endpoint):
+ raise exceptions.AuthorizationFailure(
+ "Cannot find endpoint or token for request")
+
+ old_token_endpoint = (token, endpoint)
+ kwargs.setdefault("headers", {})["X-Auth-Token"] = token
+ self.cached_token = token
+ client.cached_endpoint = endpoint
+ # Perform the request once. If we get Unauthorized, then it
+ # might be because the auth token expired, so try to
+ # re-authenticate and try again. If it still fails, bail.
+ try:
+ return self.request(
+ method, self.concat_url(endpoint, url), **kwargs)
+ except exceptions.Unauthorized as unauth_ex:
+ if just_authenticated:
+ raise
+ self.cached_token = None
+ client.cached_endpoint = None
+ self.authenticate()
+ try:
+ token, endpoint = self.auth_plugin.token_and_endpoint(
+ **filter_args)
+ except exceptions.EndpointException:
+ raise unauth_ex
+ if (not (token and endpoint) or
+ old_token_endpoint == (token, endpoint)):
+ raise unauth_ex
+ self.cached_token = token
+ client.cached_endpoint = endpoint
+ kwargs["headers"]["X-Auth-Token"] = token
+ return self.request(
+ method, self.concat_url(endpoint, url), **kwargs)
+
+ def add_client(self, base_client_instance):
+ """Add a new instance of :class:`BaseClient` descendant.
+
+ `self` will store a reference to `base_client_instance`.
+
+ Example:
+
+ >>> def test_clients():
+ ... from keystoneclient.auth import keystone
+ ... from openstack.common.apiclient import client
+ ... auth = keystone.KeystoneAuthPlugin(
+ ... username="user", password="pass", tenant_name="tenant",
+ ... auth_url="http://auth:5000/v2.0")
+ ... openstack_client = client.HTTPClient(auth)
+ ... # create nova client
+ ... from novaclient.v1_1 import client
+ ... client.Client(openstack_client)
+ ... # create keystone client
+ ... from keystoneclient.v2_0 import client
+ ... client.Client(openstack_client)
+ ... # use them
+ ... openstack_client.identity.tenants.list()
+ ... openstack_client.compute.servers.list()
+ """
+ service_type = base_client_instance.service_type
+ if service_type and not hasattr(self, service_type):
+ setattr(self, service_type, base_client_instance)
+
+ def authenticate(self):
+ self.auth_plugin.authenticate(self)
+ # Store the authentication results in the keyring for later requests
+ if self.keyring_saver:
+ self.keyring_saver.save(self)
+
+
+class BaseClient(object):
+ """Top-level object to access the OpenStack API.
+
+ This client uses :class:`HTTPClient` to send requests. :class:`HTTPClient`
+ will handle a bunch of issues such as authentication.
+ """
+
+ service_type = None
+ endpoint_type = None # "publicURL" will be used
+ cached_endpoint = None
+
+ def __init__(self, http_client, extensions=None):
+ self.http_client = http_client
+ http_client.add_client(self)
+
+ # Add in any extensions...
+ if extensions:
+ for extension in extensions:
+ if extension.manager_class:
+ setattr(self, extension.name,
+ extension.manager_class(self))
+
+ def client_request(self, method, url, **kwargs):
+ return self.http_client.client_request(
+ self, method, url, **kwargs)
+
+ def head(self, url, **kwargs):
+ return self.client_request("HEAD", url, **kwargs)
+
+ def get(self, url, **kwargs):
+ return self.client_request("GET", url, **kwargs)
+
+ def post(self, url, **kwargs):
+ return self.client_request("POST", url, **kwargs)
+
+ def put(self, url, **kwargs):
+ return self.client_request("PUT", url, **kwargs)
+
+ def delete(self, url, **kwargs):
+ return self.client_request("DELETE", url, **kwargs)
+
+ def patch(self, url, **kwargs):
+ return self.client_request("PATCH", url, **kwargs)
+
+ @staticmethod
+ def get_class(api_name, version, version_map):
+ """Returns the client class for the requested API version
+
+ :param api_name: the name of the API, e.g. 'compute', 'image', etc
+ :param version: the requested API version
+ :param version_map: a dict of client classes keyed by version
+ :rtype: a client class for the requested API version
+ """
+ try:
+ client_path = version_map[str(version)]
+ except (KeyError, ValueError):
+ msg = "Invalid %s client version '%s'. must be one of: %s" % (
+ (api_name, version, ', '.join(version_map.keys())))
+ raise exceptions.UnsupportedVersion(msg)
+
+ return importutils.import_class(client_path)
diff --git a/openstack/common/apiclient/exceptions.py b/openstack/common/apiclient/exceptions.py
index e70d37a..b03def7 100644
--- a/openstack/common/apiclient/exceptions.py
+++ b/openstack/common/apiclient/exceptions.py
@@ -121,7 +121,7 @@ class HttpError(ClientException):
super(HttpError, self).__init__(formatted_string)
-class HttpClientError(HttpError):
+class HTTPClientError(HttpError):
"""Client-side HTTP error.
Exception for cases in which the client seems to have erred.
@@ -138,7 +138,7 @@ class HttpServerError(HttpError):
message = "HTTP Server Error"
-class BadRequest(HttpClientError):
+class BadRequest(HTTPClientError):
"""HTTP 400 - Bad Request.
The request cannot be fulfilled due to bad syntax.
@@ -147,7 +147,7 @@ class BadRequest(HttpClientError):
message = "Bad Request"
-class Unauthorized(HttpClientError):
+class Unauthorized(HTTPClientError):
"""HTTP 401 - Unauthorized.
Similar to 403 Forbidden, but specifically for use when authentication
@@ -157,7 +157,7 @@ class Unauthorized(HttpClientError):
message = "Unauthorized"
-class PaymentRequired(HttpClientError):
+class PaymentRequired(HTTPClientError):
"""HTTP 402 - Payment Required.
Reserved for future use.
@@ -166,7 +166,7 @@ class PaymentRequired(HttpClientError):
message = "Payment Required"
-class Forbidden(HttpClientError):
+class Forbidden(HTTPClientError):
"""HTTP 403 - Forbidden.
The request was a valid request, but the server is refusing to respond
@@ -176,7 +176,7 @@ class Forbidden(HttpClientError):
message = "Forbidden"
-class NotFound(HttpClientError):
+class NotFound(HTTPClientError):
"""HTTP 404 - Not Found.
The requested resource could not be found but may be available again
@@ -186,7 +186,7 @@ class NotFound(HttpClientError):
message = "Not Found"
-class MethodNotAllowed(HttpClientError):
+class MethodNotAllowed(HTTPClientError):
"""HTTP 405 - Method Not Allowed.
A request was made of a resource using a request method not supported
@@ -196,7 +196,7 @@ class MethodNotAllowed(HttpClientError):
message = "Method Not Allowed"
-class NotAcceptable(HttpClientError):
+class NotAcceptable(HTTPClientError):
"""HTTP 406 - Not Acceptable.
The requested resource is only capable of generating content not
@@ -206,7 +206,7 @@ class NotAcceptable(HttpClientError):
message = "Not Acceptable"
-class ProxyAuthenticationRequired(HttpClientError):
+class ProxyAuthenticationRequired(HTTPClientError):
"""HTTP 407 - Proxy Authentication Required.
The client must first authenticate itself with the proxy.
@@ -215,7 +215,7 @@ class ProxyAuthenticationRequired(HttpClientError):
message = "Proxy Authentication Required"
-class RequestTimeout(HttpClientError):
+class RequestTimeout(HTTPClientError):
"""HTTP 408 - Request Timeout.
The server timed out waiting for the request.
@@ -224,7 +224,7 @@ class RequestTimeout(HttpClientError):
message = "Request Timeout"
-class Conflict(HttpClientError):
+class Conflict(HTTPClientError):
"""HTTP 409 - Conflict.
Indicates that the request could not be processed because of conflict
@@ -234,7 +234,7 @@ class Conflict(HttpClientError):
message = "Conflict"
-class Gone(HttpClientError):
+class Gone(HTTPClientError):
"""HTTP 410 - Gone.
Indicates that the resource requested is no longer available and will
@@ -244,7 +244,7 @@ class Gone(HttpClientError):
message = "Gone"
-class LengthRequired(HttpClientError):
+class LengthRequired(HTTPClientError):
"""HTTP 411 - Length Required.
The request did not specify the length of its content, which is
@@ -254,7 +254,7 @@ class LengthRequired(HttpClientError):
message = "Length Required"
-class PreconditionFailed(HttpClientError):
+class PreconditionFailed(HTTPClientError):
"""HTTP 412 - Precondition Failed.
The server does not meet one of the preconditions that the requester
@@ -264,7 +264,7 @@ class PreconditionFailed(HttpClientError):
message = "Precondition Failed"
-class RequestEntityTooLarge(HttpClientError):
+class RequestEntityTooLarge(HTTPClientError):
"""HTTP 413 - Request Entity Too Large.
The request is larger than the server is willing or able to process.
@@ -281,7 +281,7 @@ class RequestEntityTooLarge(HttpClientError):
super(RequestEntityTooLarge, self).__init__(*args, **kwargs)
-class RequestUriTooLong(HttpClientError):
+class RequestUriTooLong(HTTPClientError):
"""HTTP 414 - Request-URI Too Long.
The URI provided was too long for the server to process.
@@ -290,7 +290,7 @@ class RequestUriTooLong(HttpClientError):
message = "Request-URI Too Long"
-class UnsupportedMediaType(HttpClientError):
+class UnsupportedMediaType(HTTPClientError):
"""HTTP 415 - Unsupported Media Type.
The request entity has a media type which the server or resource does
@@ -300,7 +300,7 @@ class UnsupportedMediaType(HttpClientError):
message = "Unsupported Media Type"
-class RequestedRangeNotSatisfiable(HttpClientError):
+class RequestedRangeNotSatisfiable(HTTPClientError):
"""HTTP 416 - Requested Range Not Satisfiable.
The client has asked for a portion of the file, but the server cannot
@@ -310,7 +310,7 @@ class RequestedRangeNotSatisfiable(HttpClientError):
message = "Requested Range Not Satisfiable"
-class ExpectationFailed(HttpClientError):
+class ExpectationFailed(HTTPClientError):
"""HTTP 417 - Expectation Failed.
The server cannot meet the requirements of the Expect request-header field.
@@ -319,7 +319,7 @@ class ExpectationFailed(HttpClientError):
message = "Expectation Failed"
-class UnprocessableEntity(HttpClientError):
+class UnprocessableEntity(HTTPClientError):
"""HTTP 422 - Unprocessable Entity.
The request was well-formed but was unable to be followed due to semantic
@@ -440,7 +440,7 @@ def from_response(response, method, url):
if 500 <= response.status_code < 600:
cls = HttpServerError
elif 400 <= response.status_code < 500:
- cls = HttpClientError
+ cls = HTTPClientError
else:
cls = HttpError
return cls(**kwargs)
diff --git a/openstack/common/apiclient/fake_client.py b/openstack/common/apiclient/fake_client.py
new file mode 100644
index 0000000..da125e2
--- /dev/null
+++ b/openstack/common/apiclient/fake_client.py
@@ -0,0 +1,172 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A fake server that "responds" to API methods with pre-canned responses.
+
+All of these responses come from the spec, so if for some reason the spec's
+wrong the tests might raise AssertionError. I've indicated in comments the
+places where actual behavior differs from the spec.
+"""
+
+# W0102: Dangerous default value %s as argument
+# pylint: disable=W0102
+
+import json
+import urlparse
+
+import requests
+
+from openstack.common.apiclient import client
+
+
+def assert_has_keys(dct, required=[], optional=[]):
+ for k in required:
+ try:
+ assert k in dct
+ except AssertionError:
+ extra_keys = set(dct.keys()).difference(set(required + optional))
+ raise AssertionError("found unexpected keys: %s" %
+ list(extra_keys))
+
+
+class TestResponse(requests.Response):
+ """Wrap requests.Response and provide a convenient initialization.
+ """
+
+ def __init__(self, data):
+ super(TestResponse, self).__init__()
+ self._content_consumed = True
+ if isinstance(data, dict):
+ self.status_code = data.get('status_code', 200)
+ # Fake the text attribute to streamline Response creation
+ text = data.get('text', "")
+ if isinstance(text, (dict, list)):
+ self._content = json.dumps(text)
+ default_headers = {
+ "Content-Type": "application/json",
+ }
+ else:
+ self._content = text
+ default_headers = {}
+ self.headers = data.get('headers') or default_headers
+ else:
+ self.status_code = data
+
+ def __eq__(self, other):
+ return (self.status_code == other.status_code and
+ self.headers == other.headers and
+ self._content == other._content)
+
+
+class FakeHTTPClient(client.HTTPClient):
+
+ def __init__(self, *args, **kwargs):
+ self.callstack = []
+ self.fixtures = kwargs.pop("fixtures", None) or {}
+ if not args and not "auth_plugin" in kwargs:
+ args = (None, )
+ super(FakeHTTPClient, self).__init__(*args, **kwargs)
+
+ def assert_called(self, method, url, body=None, pos=-1):
+ """Assert that an API method was just called.
+ """
+ expected = (method, url)
+ called = self.callstack[pos][0:2]
+ assert self.callstack, \
+ "Expected %s %s but no calls were made." % expected
+
+ assert expected == called, 'Expected %s %s; got %s %s' % \
+ (expected + called)
+
+ if body is not None:
+ if self.callstack[pos][3] != body:
+ raise AssertionError('%r != %r' %
+ (self.callstack[pos][3], body))
+
+ def assert_called_anytime(self, method, url, body=None):
+ """Assert that an API method was called anytime in the test.
+ """
+ expected = (method, url)
+
+ assert self.callstack, \
+ "Expected %s %s but no calls were made." % expected
+
+ found = False
+ entry = None
+ for entry in self.callstack:
+ if expected == entry[0:2]:
+ found = True
+ break
+
+ assert found, 'Expected %s %s; got %s' % \
+ (method, url, self.callstack)
+ if body is not None:
+ assert entry[3] == body, "%s != %s" % (entry[3], body)
+
+ self.callstack = []
+
+ def clear_callstack(self):
+ self.callstack = []
+
+ def authenticate(self):
+ pass
+
+ def client_request(self, client, method, url, **kwargs):
+ # Check that certain things are called correctly
+ if method in ["GET", "DELETE"]:
+ assert "json" not in kwargs
+
+ # Note the call
+ self.callstack.append(
+ (method,
+ url,
+ kwargs.get("headers") or {},
+ kwargs.get("json") or kwargs.get("data")))
+ try:
+ fixture = self.fixtures[url][method]
+ except KeyError:
+ pass
+ else:
+ return TestResponse({"headers": fixture[0],
+ "text": fixture[1]})
+
+ # Call the method
+ args = urlparse.parse_qsl(urlparse.urlparse(url)[4])
+ kwargs.update(args)
+ munged_url = url.rsplit('?', 1)[0]
+ munged_url = munged_url.strip('/').replace('/', '_').replace('.', '_')
+ munged_url = munged_url.replace('-', '_')
+
+ callback = "%s_%s" % (method.lower(), munged_url)
+
+ if not hasattr(self, callback):
+ raise AssertionError('Called unknown API method: %s %s, '
+ 'expected fakes method name: %s' %
+ (method, url, callback))
+
+ resp = getattr(self, callback)(**kwargs)
+ if len(resp) == 3:
+ status, headers, body = resp
+ else:
+ status, body = resp
+ headers = {}
+ return TestResponse({
+ "status_code": status,
+ "text": body,
+ "headers": headers,
+ })
diff --git a/openstack/common/config/generator.py b/openstack/common/config/generator.py
index 3d2809b..1ae6c4d 100644
--- a/openstack/common/config/generator.py
+++ b/openstack/common/config/generator.py
@@ -50,7 +50,6 @@ OPT_TYPES = {
MULTISTROPT: 'multi valued',
}
-OPTION_COUNT = 0
OPTION_REGEX = re.compile(r"(%s)" % "|".join([STROPT, BOOLOPT, INTOPT,
FLOATOPT, LISTOPT,
MULTISTROPT]))
@@ -97,8 +96,6 @@ def generate(srcfiles):
for group, opts in opts_by_group.items():
print_group_opts(group, opts)
- print("# Total option count: %d" % OPTION_COUNT)
-
def _import_module(mod_str):
try:
@@ -163,9 +160,7 @@ def _list_opts(obj):
def print_group_opts(group, opts_by_module):
print("[%s]" % group)
print('')
- global OPTION_COUNT
for mod, opts in opts_by_module:
- OPTION_COUNT += len(opts)
print('#')
print('# Options defined in %s' % mod)
print('#')
@@ -186,24 +181,24 @@ def _get_my_ip():
return None
-def _sanitize_default(s):
+def _sanitize_default(name, value):
"""Set up a reasonably sensible default for pybasedir, my_ip and host."""
- if s.startswith(sys.prefix):
+ if value.startswith(sys.prefix):
# NOTE(jd) Don't use os.path.join, because it is likely to think the
# second part is an absolute pathname and therefore drop the first
# part.
- s = os.path.normpath("/usr/" + s[len(sys.prefix):])
- elif s.startswith(BASEDIR):
- return s.replace(BASEDIR, '/usr/lib/python/site-packages')
- elif BASEDIR in s:
- return s.replace(BASEDIR, '')
- elif s == _get_my_ip():
+ value = os.path.normpath("/usr/" + value[len(sys.prefix):])
+ elif value.startswith(BASEDIR):
+ return value.replace(BASEDIR, '/usr/lib/python/site-packages')
+ elif BASEDIR in value:
+ return value.replace(BASEDIR, '')
+ elif value == _get_my_ip():
return '10.0.0.1'
- elif s == socket.gethostname():
+ elif value == socket.gethostname() and 'host' in name:
return 'oslo'
- elif s.strip() != s:
- return '"%s"' % s
- return s
+ elif value.strip() != value:
+ return '"%s"' % value
+ return value
def _print_opt(opt):
@@ -224,7 +219,8 @@ def _print_opt(opt):
print('#%s=<None>' % opt_name)
elif opt_type == STROPT:
assert(isinstance(opt_default, basestring))
- print('#%s=%s' % (opt_name, _sanitize_default(opt_default)))
+ print('#%s=%s' % (opt_name, _sanitize_default(opt_name,
+ opt_default)))
elif opt_type == BOOLOPT:
assert(isinstance(opt_default, bool))
print('#%s=%s' % (opt_name, str(opt_default).lower()))
diff --git a/openstack/common/db/exception.py b/openstack/common/db/exception.py
index 69905da..3627de2 100644
--- a/openstack/common/db/exception.py
+++ b/openstack/common/db/exception.py
@@ -43,3 +43,9 @@ class DBDeadlock(DBError):
class DBInvalidUnicodeParameter(Exception):
message = _("Invalid Parameter: "
"Unicode is not supported by the current database.")
+
+
+class DbMigrationError(DBError):
+ """Wraps migration specific exception."""
+ def __init__(self, message=None):
+ super(DbMigrationError, self).__init__(str(message))
diff --git a/openstack/common/db/sqlalchemy/migration.py b/openstack/common/db/sqlalchemy/migration.py
index e643d8e..4154359 100644
--- a/openstack/common/db/sqlalchemy/migration.py
+++ b/openstack/common/db/sqlalchemy/migration.py
@@ -38,12 +38,53 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+import distutils.version as dist_version
+import os
import re
+import migrate
from migrate.changeset import ansisql
from migrate.changeset.databases import sqlite
+from migrate.versioning import util as migrate_util
+import sqlalchemy
from sqlalchemy.schema import UniqueConstraint
+from openstack.common.db import exception
+from openstack.common.db.sqlalchemy import session as db_session
+from openstack.common.gettextutils import _ # noqa
+
+
+@migrate_util.decorator
+def patched_with_engine(f, *a, **kw):
+ url = a[0]
+ engine = migrate_util.construct_engine(url, **kw)
+
+ try:
+ kw['engine'] = engine
+ return f(*a, **kw)
+ finally:
+ if isinstance(engine, migrate_util.Engine) and engine is not url:
+ migrate_util.log.debug('Disposing SQLAlchemy engine %s', engine)
+ engine.dispose()
+
+
+# TODO(jkoelker) When migrate 0.7.3 is released and nova depends
+# on that version or higher, this can be removed
+MIN_PKG_VERSION = dist_version.StrictVersion('0.7.3')
+if (not hasattr(migrate, '__version__') or
+ dist_version.StrictVersion(migrate.__version__) < MIN_PKG_VERSION):
+ migrate_util.with_engine = patched_with_engine
+
+
+# NOTE(jkoelker) Delay importing migrate until we are patched
+from migrate import exceptions as versioning_exceptions
+from migrate.versioning import api as versioning_api
+from migrate.versioning.repository import Repository
+
+_REPOSITORY = None
+
+get_engine = db_session.get_engine
+
def _get_unique_constraints(self, table):
"""Retrieve information about existing unique constraints of the table
@@ -157,3 +198,81 @@ def patch_migrate():
_visit_migrate_unique_constraint
constraint_cls.__bases__ = (ansisql.ANSIColumnDropper,
sqlite.SQLiteConstraintGenerator)
+
+
+def db_sync(abs_path, version=None, init_version=0):
+ """Upgrade or downgrade a database.
+
+ Function runs the upgrade() or downgrade() functions in change scripts.
+
+ :param abs_path: Absolute path to migrate repository.
+ :param version: Database will upgrade/downgrade until this version.
+ If None - database will update to the latest
+ available version.
+ :param init_version: Initial database version
+ """
+ if version is not None:
+ try:
+ version = int(version)
+ except ValueError:
+ raise exception.DbMigrationError(
+ message=_("version should be an integer"))
+
+ current_version = db_version(abs_path, init_version)
+ repository = _find_migrate_repo(abs_path)
+ if version is None or version > current_version:
+ return versioning_api.upgrade(get_engine(), repository, version)
+ else:
+ return versioning_api.downgrade(get_engine(), repository,
+ version)
+
+
+def db_version(abs_path, init_version):
+ """Show the current version of the repository.
+
+ :param abs_path: Absolute path to migrate repository
+ :param version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+ try:
+ return versioning_api.db_version(get_engine(), repository)
+ except versioning_exceptions.DatabaseNotControlledError:
+ meta = sqlalchemy.MetaData()
+ engine = get_engine()
+ meta.reflect(bind=engine)
+ tables = meta.tables
+ if len(tables) == 0:
+ db_version_control(abs_path, init_version)
+ return versioning_api.db_version(get_engine(), repository)
+ else:
+ # Some pre-Essex DB's may not be version controlled.
+ # Require them to upgrade using Essex first.
+ raise exception.DbMigrationError(
+ message=_("Upgrade DB using Essex release first."))
+
+
+def db_version_control(abs_path, version=None):
+ """Mark a database as under this repository's version control.
+
+ Once a database is under version control, schema changes should
+ only be done via change scripts in this repository.
+
+ :param abs_path: Absolute path to migrate repository
+ :param version: Initial database version
+ """
+ repository = _find_migrate_repo(abs_path)
+ versioning_api.version_control(get_engine(), repository, version)
+ return version
+
+
+def _find_migrate_repo(abs_path):
+ """Get the project's change script repository
+
+ :param abs_path: Absolute path to migrate repository
+ """
+ global _REPOSITORY
+ if not os.path.exists(abs_path):
+ raise exception.DbMigrationError("Path %s not found" % abs_path)
+ if _REPOSITORY is None:
+ _REPOSITORY = Repository(abs_path)
+ return _REPOSITORY
diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py
index 59bcb90..236136e 100644
--- a/openstack/common/db/sqlalchemy/session.py
+++ b/openstack/common/db/sqlalchemy/session.py
@@ -279,13 +279,11 @@ database_opts = [
deprecated_opts=[cfg.DeprecatedOpt('sql_connection',
group='DEFAULT'),
cfg.DeprecatedOpt('sql_connection',
- group='DATABASE')],
- secret=True),
+ group='DATABASE')]),
cfg.StrOpt('slave_connection',
default='',
help='The SQLAlchemy connection string used to connect to the '
- 'slave database',
- secret=True),
+ 'slave database'),
cfg.IntOpt('idle_timeout',
default=3600,
deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout',
@@ -478,6 +476,11 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
if engine_name not in ["mysql", "sqlite", "postgresql"]:
return
+ # FIXME(johannes): The usage of the .message attribute has been
+ # deprecated since Python 2.6. However, the exceptions raised by
+ # SQLAlchemy can differ when using unicode() and accessing .message.
+ # An audit across all three supported engines will be necessary to
+ # ensure there are no regressions.
m = _DUP_KEY_RE_DB[engine_name].match(integrity_error.message)
if not m:
return
@@ -510,6 +513,11 @@ def _raise_if_deadlock_error(operational_error, engine_name):
re = _DEADLOCK_RE_DB.get(engine_name)
if re is None:
return
+ # FIXME(johannes): The usage of the .message attribute has been
+ # deprecated since Python 2.6. However, the exceptions raised by
+ # SQLAlchemy can differ when using unicode() and accessing .message.
+ # An audit across all three supported engines will be necessary to
+ # ensure there are no regressions.
m = re.match(operational_error.message)
if not m:
return
diff --git a/openstack/common/db/sqlalchemy/utils.py b/openstack/common/db/sqlalchemy/utils.py
index 3ff7bdb..102f0e5 100755..100644
--- a/openstack/common/db/sqlalchemy/utils.py
+++ b/openstack/common/db/sqlalchemy/utils.py
@@ -18,6 +18,9 @@
# License for the specific language governing permissions and limitations
# under the License.
+import re
+
+from migrate.changeset import UniqueConstraint
import sqlalchemy
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
@@ -37,13 +40,21 @@ from sqlalchemy.types import NullType
from openstack.common.gettextutils import _ # noqa
-from openstack.common import exception
from openstack.common import log as logging
from openstack.common import timeutils
LOG = logging.getLogger(__name__)
+_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
+
+
+def sanitize_db_url(url):
+ match = _DBURL_REGEX.match(url)
+ if match:
+ return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
+ return url
+
class InvalidSortKey(Exception):
message = _("Sort key supplied was not valid.")
@@ -174,6 +185,10 @@ def visit_insert_from_select(element, compiler, **kw):
compiler.process(element.select))
+class ColumnError(Exception):
+ """Error raised when no column or an invalid column is found."""
+
+
def _get_not_supported_column(col_name_col_instance, column_name):
try:
column = col_name_col_instance[column_name]
@@ -181,16 +196,53 @@ def _get_not_supported_column(col_name_col_instance, column_name):
msg = _("Please specify column %s in col_name_col_instance "
"param. It is required because column has unsupported "
"type by sqlite).")
- raise exception.OpenstackException(message=msg % column_name)
+ raise ColumnError(msg % column_name)
if not isinstance(column, Column):
msg = _("col_name_col_instance param has wrong type of "
"column instance for column %s It should be instance "
"of sqlalchemy.Column.")
- raise exception.OpenstackException(message=msg % column_name)
+ raise ColumnError(msg % column_name)
return column
+def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns,
+ **col_name_col_instance):
+ """Drop unique constraint from table.
+
+ This method drops UC from table and works for mysql, postgresql and sqlite.
+ In mysql and postgresql we are able to use "alter table" construction.
+ Sqlalchemy doesn't support some sqlite column types and replaces their
+ type with NullType in metadata. We process these columns and replace
+ NullType with the correct column type.
+
+ :param migrate_engine: sqlalchemy engine
+ :param table_name: name of table that contains uniq constraint.
+ :param uc_name: name of uniq constraint that will be dropped.
+ :param columns: columns that are in uniq constraint.
+ :param col_name_col_instance: contains pair column_name=column_instance.
+ column_instance is instance of Column. These params
+ are required only for columns that have unsupported
+ types by sqlite. For example BigInteger.
+ """
+
+ meta = MetaData()
+ meta.bind = migrate_engine
+ t = Table(table_name, meta, autoload=True)
+
+ if migrate_engine.name == "sqlite":
+ override_cols = [
+ _get_not_supported_column(col_name_col_instance, col.name)
+ for col in t.columns
+ if isinstance(col.type, NullType)
+ ]
+ for col in override_cols:
+ t.columns.replace(col)
+
+ uc = UniqueConstraint(*columns, table=t, name=uc_name)
+ uc.drop()
+
+
def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
use_soft_delete, *uc_column_names):
"""Drop all old rows having the same values for columns in uc_columns.
@@ -248,8 +300,7 @@ def _get_default_deleted_value(table):
return 0
if isinstance(table.c.id.type, String):
return ""
- raise exception.OpenstackException(
- message=_("Unsupported id columns type"))
+ raise ColumnError(_("Unsupported id columns type"))
def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes):
diff --git a/openstack/common/deprecated/wsgi.py b/openstack/common/deprecated/wsgi.py
index a9530b3..aff4b92 100644
--- a/openstack/common/deprecated/wsgi.py
+++ b/openstack/common/deprecated/wsgi.py
@@ -35,7 +35,6 @@ import webob.exc
from xml.dom import minidom
from xml.parsers import expat
-from openstack.common import exception
from openstack.common.gettextutils import _ # noqa
from openstack.common import jsonutils
from openstack.common import log as logging
@@ -59,6 +58,18 @@ CONF.register_opts(socket_opts)
LOG = logging.getLogger(__name__)
+class MalformedRequestBody(Exception):
+ def __init__(self, reason):
+ super(MalformedRequestBody, self).__init__(
+ "Malformed message body: %s", reason)
+
+
+class InvalidContentType(Exception):
+ def __init__(self, content_type):
+ super(InvalidContentType, self).__init__(
+ "Invalid content type %s", content_type)
+
+
def run_server(application, port, **kwargs):
"""Run a WSGI server with the given application."""
sock = eventlet.listen(('0.0.0.0', port))
@@ -255,7 +266,7 @@ class Request(webob.Request):
self.default_request_content_types)
if content_type not in allowed_content_types:
- raise exception.InvalidContentType(content_type=content_type)
+ raise InvalidContentType(content_type=content_type)
return content_type
@@ -294,10 +305,10 @@ class Resource(object):
try:
action, action_args, accept = self.deserialize_request(request)
- except exception.InvalidContentType:
+ except InvalidContentType:
msg = _("Unsupported Content-Type")
return webob.exc.HTTPUnsupportedMediaType(explanation=msg)
- except exception.MalformedRequestBody:
+ except MalformedRequestBody:
msg = _("Malformed request body")
return webob.exc.HTTPBadRequest(explanation=msg)
@@ -530,7 +541,7 @@ class ResponseSerializer(object):
try:
return self.body_serializers[content_type]
except (KeyError, TypeError):
- raise exception.InvalidContentType(content_type=content_type)
+ raise InvalidContentType(content_type=content_type)
class RequestHeadersDeserializer(ActionDispatcher):
@@ -589,7 +600,7 @@ class RequestDeserializer(object):
try:
content_type = request.get_content_type()
- except exception.InvalidContentType:
+ except InvalidContentType:
LOG.debug(_("Unrecognized Content-Type provided in request"))
raise
@@ -599,7 +610,7 @@ class RequestDeserializer(object):
try:
deserializer = self.get_body_deserializer(content_type)
- except exception.InvalidContentType:
+ except InvalidContentType:
LOG.debug(_("Unable to deserialize body as provided Content-Type"))
raise
@@ -609,7 +620,7 @@ class RequestDeserializer(object):
try:
return self.body_deserializers[content_type]
except (KeyError, TypeError):
- raise exception.InvalidContentType(content_type=content_type)
+ raise InvalidContentType(content_type=content_type)
def get_expected_content_type(self, request):
return request.best_match_content_type(self.supported_content_types)
@@ -651,7 +662,7 @@ class JSONDeserializer(TextDeserializer):
return jsonutils.loads(datastring)
except ValueError:
msg = _("cannot understand JSON")
- raise exception.MalformedRequestBody(reason=msg)
+ raise MalformedRequestBody(reason=msg)
def default(self, datastring):
return {'body': self._from_json(datastring)}
@@ -676,7 +687,7 @@ class XMLDeserializer(TextDeserializer):
return {node.nodeName: self._from_xml_node(node, plurals)}
except expat.ExpatError:
msg = _("cannot understand XML")
- raise exception.MalformedRequestBody(reason=msg)
+ raise MalformedRequestBody(reason=msg)
def _from_xml_node(self, node, listnames):
"""Convert a minidom node to a simple Python type.
diff --git a/openstack/common/exception.py b/openstack/common/exception.py
deleted file mode 100644
index 13c3bff..0000000
--- a/openstack/common/exception.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2011 OpenStack Foundation.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Exceptions common to OpenStack projects
-"""
-
-import logging
-
-from openstack.common.gettextutils import _ # noqa
-
-_FATAL_EXCEPTION_FORMAT_ERRORS = False
-
-
-class Error(Exception):
- def __init__(self, message=None):
- super(Error, self).__init__(message)
-
-
-class ApiError(Error):
- def __init__(self, message='Unknown', code='Unknown'):
- self.api_message = message
- self.code = code
- super(ApiError, self).__init__('%s: %s' % (code, message))
-
-
-class NotFound(Error):
- pass
-
-
-class UnknownScheme(Error):
-
- msg_fmt = "Unknown scheme '%s' found in URI"
-
- def __init__(self, scheme):
- msg = self.msg_fmt % scheme
- super(UnknownScheme, self).__init__(msg)
-
-
-class BadStoreUri(Error):
-
- msg_fmt = "The Store URI %s was malformed. Reason: %s"
-
- def __init__(self, uri, reason):
- msg = self.msg_fmt % (uri, reason)
- super(BadStoreUri, self).__init__(msg)
-
-
-class Duplicate(Error):
- pass
-
-
-class NotAuthorized(Error):
- pass
-
-
-class NotEmpty(Error):
- pass
-
-
-class Invalid(Error):
- pass
-
-
-class BadInputError(Exception):
- """Error resulting from a client sending bad input to a server"""
- pass
-
-
-class MissingArgumentError(Error):
- pass
-
-
-class DatabaseMigrationError(Error):
- pass
-
-
-class ClientConnectionError(Exception):
- """Error resulting from a client connecting to a server"""
- pass
-
-
-def wrap_exception(f):
- def _wrap(*args, **kw):
- try:
- return f(*args, **kw)
- except Exception as e:
- if not isinstance(e, Error):
- logging.exception(_('Uncaught exception'))
- raise Error(str(e))
- raise
- _wrap.func_name = f.func_name
- return _wrap
-
-
-class OpenstackException(Exception):
- """Base Exception class.
-
- To correctly use this class, inherit from it and define
- a 'msg_fmt' property. That message will get printf'd
- with the keyword arguments provided to the constructor.
- """
- msg_fmt = "An unknown exception occurred"
-
- def __init__(self, **kwargs):
- try:
- self._error_string = self.msg_fmt % kwargs
-
- except Exception:
- if _FATAL_EXCEPTION_FORMAT_ERRORS:
- raise
- else:
- # at least get the core message out if something happened
- self._error_string = self.msg_fmt
-
- def __str__(self):
- return self._error_string
-
-
-class MalformedRequestBody(OpenstackException):
- msg_fmt = "Malformed message body: %(reason)s"
-
-
-class InvalidContentType(OpenstackException):
- msg_fmt = "Invalid content type %(content_type)s"
diff --git a/openstack/common/excutils.py b/openstack/common/excutils.py
index abe6f87..664b2e4 100644
--- a/openstack/common/excutils.py
+++ b/openstack/common/excutils.py
@@ -77,7 +77,8 @@ def forever_retry_uncaught_exceptions(infunc):
try:
return infunc(*args, **kwargs)
except Exception as exc:
- if exc.message == last_exc_message:
+ this_exc_message = unicode(exc)
+ if this_exc_message == last_exc_message:
exc_count += 1
else:
exc_count = 1
@@ -85,12 +86,12 @@ def forever_retry_uncaught_exceptions(infunc):
# the exception message changes
cur_time = int(time.time())
if (cur_time - last_log_time > 60 or
- exc.message != last_exc_message):
+ this_exc_message != last_exc_message):
logging.exception(
_('Unexpected exception occurred %d time(s)... '
'retrying.') % exc_count)
last_log_time = cur_time
- last_exc_message = exc.message
+ last_exc_message = this_exc_message
exc_count = 0
# This should be a very rare event. In case it isn't, do
# a sleep.
diff --git a/openstack/common/fixture/config.py b/openstack/common/fixture/config.py
new file mode 100644
index 0000000..cf52a66
--- /dev/null
+++ b/openstack/common/fixture/config.py
@@ -0,0 +1,45 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2013 Mirantis, Inc.
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+import fixtures
+from oslo.config import cfg
+
+
+class Config(fixtures.Fixture):
+ """Override some configuration values.
+
+ The keyword arguments are the names of configuration options to
+ override and their values.
+
+ If a group argument is supplied, the overrides are applied to
+ the specified configuration option group.
+
+ All overrides are automatically cleared at the end of the current
+ test by the reset() method, which is registered by addCleanup().
+ """
+
+ def __init__(self, conf=cfg.CONF):
+ self.conf = conf
+
+ def setUp(self):
+ super(Config, self).setUp()
+ self.addCleanup(self.conf.reset)
+
+ def config(self, **kw):
+ group = kw.pop('group', None)
+ for k, v in kw.iteritems():
+ self.conf.set_override(k, v, group)
diff --git a/openstack/common/gettextutils.py b/openstack/common/gettextutils.py
index 635a434..321fdd0 100644
--- a/openstack/common/gettextutils.py
+++ b/openstack/common/gettextutils.py
@@ -1,8 +1,8 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
-# All Rights Reserved.
# Copyright 2013 IBM Corp.
+# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -31,17 +31,36 @@ import os
import re
import UserString
+from babel import localedata
import six
_localedir = os.environ.get('oslo'.upper() + '_LOCALEDIR')
_t = gettext.translation('oslo', localedir=_localedir, fallback=True)
+_AVAILABLE_LANGUAGES = []
+USE_LAZY = False
+
+
+def enable_lazy():
+ """Convenience function for configuring _() to use lazy gettext
+
+ Call this at the start of execution to enable the gettextutils._
+ function to use lazy gettext functionality. This is useful if
+ your project is importing _ directly instead of using the
+ gettextutils.install() way of importing the _ function.
+ """
+ global USE_LAZY
+ USE_LAZY = True
+
def _(msg):
- return _t.ugettext(msg)
+ if USE_LAZY:
+ return Message(msg, 'oslo')
+ else:
+ return _t.ugettext(msg)
-def install(domain):
+def install(domain, lazy=False):
"""Install a _() function using the given translation domain.
Given a translation domain, install a _() function using gettext's
@@ -51,41 +70,45 @@ def install(domain):
overriding the default localedir (e.g. /usr/share/locale) using
a translation-domain-specific environment variable (e.g.
NOVA_LOCALEDIR).
- """
- gettext.install(domain,
- localedir=os.environ.get(domain.upper() + '_LOCALEDIR'),
- unicode=True)
-
-
-"""
-Lazy gettext functionality.
-
-The following is an attempt to introduce a deferred way
-to do translations on messages in OpenStack. We attempt to
-override the standard _() function and % (format string) operation
-to build Message objects that can later be translated when we have
-more information. Also included is an example LogHandler that
-translates Messages to an associated locale, effectively allowing
-many logs, each with their own locale.
-"""
-
-
-def get_lazy_gettext(domain):
- """Assemble and return a lazy gettext function for a given domain.
- Factory method for a project/module to get a lazy gettext function
- for its own translation domain (i.e. nova, glance, cinder, etc.)
+ :param domain: the translation domain
+ :param lazy: indicates whether or not to install the lazy _() function.
+ The lazy _() introduces a way to do deferred translation
+ of messages by installing a _ that builds Message objects,
+ instead of strings, which can then be lazily translated into
+ any available locale.
"""
-
- def _lazy_gettext(msg):
- """Create and return a Message object.
-
- Message encapsulates a string so that we can translate it later when
- needed.
- """
- return Message(msg, domain)
-
- return _lazy_gettext
+ if lazy:
+ # NOTE(mrodden): Lazy gettext functionality.
+ #
+ # The following introduces a deferred way to do translations on
+ # messages in OpenStack. We override the standard _() function
+ # and % (format string) operation to build Message objects that can
+ # later be translated when we have more information.
+ #
+ # Also included below is an example LocaleHandler that translates
+ # Messages to an associated locale, effectively allowing many logs,
+ # each with their own locale.
+
+ def _lazy_gettext(msg):
+ """Create and return a Message object.
+
+ Lazy gettext function for a given domain; it is a factory method
+ for a project/module to get a lazy gettext function for its own
+ translation domain (i.e. nova, glance, cinder, etc.)
+
+ Message encapsulates a string so that we can translate
+ it later when needed.
+ """
+ return Message(msg, domain)
+
+ import __builtin__
+ __builtin__.__dict__['_'] = _lazy_gettext
+ else:
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ gettext.install(domain,
+ localedir=os.environ.get(localedir),
+ unicode=True)
class Message(UserString.UserString, object):
@@ -130,7 +153,7 @@ class Message(UserString.UserString, object):
# look for %(blah) fields in string;
# ignore %% and deal with the
# case where % is first character on the line
- keys = re.findall('(?:[^%]|^)%\((\w*)\)[a-z]', full_msg)
+ keys = re.findall('(?:[^%]|^)?%\((\w*)\)[a-z]', full_msg)
# if we don't find any %(blah) blocks but have a %s
if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg):
@@ -232,6 +255,45 @@ class Message(UserString.UserString, object):
return UserString.UserString.__getattribute__(self, name)
+def get_available_languages(domain):
+ """Lists the available languages for the given translation domain.
+
+ :param domain: the domain to get languages for
+ """
+ if _AVAILABLE_LANGUAGES:
+ return _AVAILABLE_LANGUAGES
+
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ find = lambda x: gettext.find(domain,
+ localedir=os.environ.get(localedir),
+ languages=[x])
+
+ # NOTE(mrodden): en_US should always be available (and first in case
+ # order matters) since our in-line message strings are en_US
+ _AVAILABLE_LANGUAGES.append('en_US')
+ # NOTE(luisg): Babel <1.0 used a function called list(), which was
+ # renamed to locale_identifiers() in >=1.0, the requirements master list
+ # requires >=0.9.6, uncapped, so defensively work with both. We can remove
+ # this check when the master list updates to >=1.0, and all projects update
+ list_identifiers = (getattr(localedata, 'list', None) or
+ getattr(localedata, 'locale_identifiers'))
+ locale_identifiers = list_identifiers()
+ for i in locale_identifiers:
+ if find(i) is not None:
+ _AVAILABLE_LANGUAGES.append(i)
+ return _AVAILABLE_LANGUAGES
+
+
+def get_localized_message(message, user_locale):
+ """Gets a localized version of the given message in the given locale."""
+ if (isinstance(message, Message)):
+ if user_locale:
+ message.locale = user_locale
+ return unicode(message)
+ else:
+ return message
+
+
class LocaleHandler(logging.Handler):
"""Handler that can have a locale associated to translate Messages.
diff --git a/openstack/common/local.py b/openstack/common/local.py
index f1bfc82..e82f17d 100644
--- a/openstack/common/local.py
+++ b/openstack/common/local.py
@@ -15,16 +15,15 @@
# License for the specific language governing permissions and limitations
# under the License.
-"""Greenthread local storage of variables using weak references"""
+"""Local storage of variables using weak references"""
+import threading
import weakref
-from eventlet import corolocal
-
-class WeakLocal(corolocal.local):
+class WeakLocal(threading.local):
def __getattribute__(self, attr):
- rval = corolocal.local.__getattribute__(self, attr)
+ rval = super(WeakLocal, self).__getattribute__(attr)
if rval:
# NOTE(mikal): this bit is confusing. What is stored is a weak
# reference, not the value itself. We therefore need to lookup
@@ -34,7 +33,7 @@ class WeakLocal(corolocal.local):
def __setattr__(self, attr, value):
value = weakref.ref(value)
- return corolocal.local.__setattr__(self, attr, value)
+ return super(WeakLocal, self).__setattr__(attr, value)
# NOTE(mikal): the name "store" should be deprecated in the future
@@ -45,4 +44,4 @@ store = WeakLocal()
# "strong" store will hold a reference to the object so that it never falls out
# of scope.
weak_store = WeakLocal()
-strong_store = corolocal.local
+strong_store = threading.local()
diff --git a/openstack/common/log.py b/openstack/common/log.py
index 465886b..c7e729c 100644
--- a/openstack/common/log.py
+++ b/openstack/common/log.py
@@ -29,8 +29,6 @@ It also allows setting of formatting information through conf.
"""
-import ConfigParser
-import cStringIO
import inspect
import itertools
import logging
@@ -41,6 +39,7 @@ import sys
import traceback
from oslo.config import cfg
+from six import moves
from openstack.common.gettextutils import _ # noqa
from openstack.common import importutils
@@ -348,7 +347,7 @@ class LogConfigError(Exception):
def _load_log_config(log_config):
try:
logging.config.fileConfig(log_config)
- except ConfigParser.Error as exc:
+ except moves.configparser.Error as exc:
raise LogConfigError(log_config, str(exc))
@@ -521,7 +520,7 @@ class ContextFormatter(logging.Formatter):
if not record:
return logging.Formatter.formatException(self, exc_info)
- stringbuffer = cStringIO.StringIO()
+ stringbuffer = moves.StringIO()
traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
None, stringbuffer)
lines = stringbuffer.getvalue().split('\n')
diff --git a/openstack/common/middleware/base.py b/openstack/common/middleware/base.py
index 7236731..2099549 100644
--- a/openstack/common/middleware/base.py
+++ b/openstack/common/middleware/base.py
@@ -28,11 +28,7 @@ class Middleware(object):
@classmethod
def factory(cls, global_conf, **local_conf):
"""Factory method for paste.deploy."""
-
- def filter(app):
- return cls(app)
-
- return filter
+ return cls
def __init__(self, application):
self.application = application
diff --git a/openstack/common/middleware/sizelimit.py b/openstack/common/middleware/sizelimit.py
index ecbdde1..23ba9b6 100644
--- a/openstack/common/middleware/sizelimit.py
+++ b/openstack/common/middleware/sizelimit.py
@@ -71,9 +71,6 @@ class LimitingReader(object):
class RequestBodySizeLimiter(base.Middleware):
"""Limit the size of incoming requests."""
- def __init__(self, *args, **kwargs):
- super(RequestBodySizeLimiter, self).__init__(*args, **kwargs)
-
@webob.dec.wsgify(RequestClass=wsgi.Request)
def __call__(self, req):
if req.content_length > CONF.max_request_body_size:
diff --git a/openstack/common/notifier/log_notifier.py b/openstack/common/notifier/log_notifier.py
index d3ef0ae..96072ed 100644
--- a/openstack/common/notifier/log_notifier.py
+++ b/openstack/common/notifier/log_notifier.py
@@ -25,7 +25,7 @@ CONF = cfg.CONF
def notify(_context, message):
"""Notifies the recipient of the desired event given the model.
- Log notifications using openstack's default logging system.
+ Log notifications using OpenStack's default logging system.
"""
priority = message.get('priority',
diff --git a/openstack/common/notifier/rpc_notifier.py b/openstack/common/notifier/rpc_notifier.py
index 6bfc333..db47a8a 100644
--- a/openstack/common/notifier/rpc_notifier.py
+++ b/openstack/common/notifier/rpc_notifier.py
@@ -24,7 +24,7 @@ LOG = logging.getLogger(__name__)
notification_topic_opt = cfg.ListOpt(
'notification_topics', default=['notifications', ],
- help='AMQP topic used for openstack notifications')
+ help='AMQP topic used for OpenStack notifications')
CONF = cfg.CONF
CONF.register_opt(notification_topic_opt)
diff --git a/openstack/common/notifier/rpc_notifier2.py b/openstack/common/notifier/rpc_notifier2.py
index 55dd780..505a94e 100644
--- a/openstack/common/notifier/rpc_notifier2.py
+++ b/openstack/common/notifier/rpc_notifier2.py
@@ -26,7 +26,7 @@ LOG = logging.getLogger(__name__)
notification_topic_opt = cfg.ListOpt(
'topics', default=['notifications', ],
- help='AMQP topic(s) used for openstack notifications')
+ help='AMQP topic(s) used for OpenStack notifications')
opt_group = cfg.OptGroup(name='rpc_notifier2',
title='Options for rpc_notifier2')
diff --git a/openstack/common/policy.py b/openstack/common/policy.py
index 00531e5..ffb8668 100644
--- a/openstack/common/policy.py
+++ b/openstack/common/policy.py
@@ -115,12 +115,18 @@ class Rules(dict):
def __missing__(self, key):
"""Implements the default rule handling."""
+ if isinstance(self.default_rule, dict):
+ raise KeyError(key)
+
# If the default rule isn't actually defined, do something
# reasonably intelligent
if not self.default_rule or self.default_rule not in self:
raise KeyError(key)
- return self[self.default_rule]
+ if isinstance(self.default_rule, BaseCheck):
+ return self.default_rule
+ elif isinstance(self.default_rule, six.string_types):
+ return self[self.default_rule]
def __str__(self):
"""Dumps a string representation of the rules."""
@@ -153,7 +159,7 @@ class Enforcer(object):
"""
def __init__(self, policy_file=None, rules=None, default_rule=None):
- self.rules = Rules(rules)
+ self.rules = Rules(rules, default_rule)
self.default_rule = default_rule or CONF.policy_default_rule
self.policy_path = None
@@ -172,13 +178,14 @@ class Enforcer(object):
"got %s instead") % type(rules))
if overwrite:
- self.rules = Rules(rules)
+ self.rules = Rules(rules, self.default_rule)
else:
- self.update(rules)
+ self.rules.update(rules)
def clear(self):
"""Clears Enforcer rules, policy's cache and policy's path."""
self.set_rules({})
+ self.default_rule = None
self.policy_path = None
def load_rules(self, force_reload=False):
@@ -194,8 +201,7 @@ class Enforcer(object):
reloaded, data = fileutils.read_cached_file(self.policy_path,
force_reload=force_reload)
-
- if reloaded:
+ if reloaded or not self.rules:
rules = Rules.load_json(data, self.default_rule)
self.set_rules(rules)
LOG.debug(_("Rules successfully reloaded"))
@@ -215,7 +221,7 @@ class Enforcer(object):
if policy_file:
return policy_file
- raise cfg.ConfigFilesNotFoundError(path=CONF.policy_file)
+ raise cfg.ConfigFilesNotFoundError((CONF.policy_file,))
def enforce(self, rule, target, creds, do_raise=False,
exc=None, *args, **kwargs):
@@ -398,7 +404,7 @@ class AndCheck(BaseCheck):
"""
for rule in self.rules:
- if not rule(target, cred):
+ if not rule(target, cred, enforcer):
return False
return True
@@ -441,7 +447,7 @@ class OrCheck(BaseCheck):
"""
for rule in self.rules:
- if rule(target, cred):
+ if rule(target, cred, enforcer):
return True
return False
diff --git a/openstack/common/processutils.py b/openstack/common/processutils.py
index 13f6222..06fe411 100644
--- a/openstack/common/processutils.py
+++ b/openstack/common/processutils.py
@@ -19,6 +19,7 @@
System-level utilities and helper functions.
"""
+import logging as stdlib_logging
import os
import random
import shlex
@@ -102,6 +103,9 @@ def execute(*cmd, **kwargs):
:param shell: whether or not there should be a shell used to
execute this command. Defaults to false.
:type shell: boolean
+ :param loglevel: log level for execute commands.
+ :type loglevel: int. (Should be stdlib_logging.DEBUG or
+ stdlib_logging.INFO)
:returns: (stdout, stderr) from process execution
:raises: :class:`UnknownArgumentError` on
receiving unknown arguments
@@ -116,6 +120,7 @@ def execute(*cmd, **kwargs):
run_as_root = kwargs.pop('run_as_root', False)
root_helper = kwargs.pop('root_helper', '')
shell = kwargs.pop('shell', False)
+ loglevel = kwargs.pop('loglevel', stdlib_logging.DEBUG)
if isinstance(check_exit_code, bool):
ignore_exit_code = not check_exit_code
@@ -139,7 +144,7 @@ def execute(*cmd, **kwargs):
while attempts > 0:
attempts -= 1
try:
- LOG.debug(_('Running cmd (subprocess): %s'), ' '.join(cmd))
+ LOG.log(loglevel, _('Running cmd (subprocess): %s'), ' '.join(cmd))
_PIPE = subprocess.PIPE # pylint: disable=E1101
if os.name == 'nt':
@@ -164,7 +169,7 @@ def execute(*cmd, **kwargs):
obj.stdin.close() # pylint: disable=E1101
_returncode = obj.returncode # pylint: disable=E1101
if _returncode:
- LOG.debug(_('Result was %s') % _returncode)
+ LOG.log(loglevel, _('Result was %s') % _returncode)
if not ignore_exit_code and _returncode not in check_exit_code:
(stdout, stderr) = result
raise ProcessExecutionError(exit_code=_returncode,
@@ -176,7 +181,7 @@ def execute(*cmd, **kwargs):
if not attempts:
raise
else:
- LOG.debug(_('%r failed. Retrying.'), cmd)
+ LOG.log(loglevel, _('%r failed. Retrying.'), cmd)
if delay_on_retry:
greenthread.sleep(random.randint(20, 200) / 100.0)
finally:
diff --git a/openstack/common/py3kcompat/__init__.py b/openstack/common/py3kcompat/__init__.py
new file mode 100644
index 0000000..be894cf
--- /dev/null
+++ b/openstack/common/py3kcompat/__init__.py
@@ -0,0 +1,17 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2013 Canonical Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
diff --git a/openstack/common/py3kcompat/urlutils.py b/openstack/common/py3kcompat/urlutils.py
new file mode 100644
index 0000000..04b3418
--- /dev/null
+++ b/openstack/common/py3kcompat/urlutils.py
@@ -0,0 +1,49 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2013 Canonical Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+"""
+Python2/Python3 compatibility layer for OpenStack
+"""
+
+import six
+
+if six.PY3:
+ # python3
+ import urllib.parse
+
+ urlencode = urllib.parse.urlencode
+ urljoin = urllib.parse.urljoin
+ quote = urllib.parse.quote
+ parse_qsl = urllib.parse.parse_qsl
+ urlparse = urllib.parse.urlparse
+ urlsplit = urllib.parse.urlsplit
+ urlunsplit = urllib.parse.urlunsplit
+else:
+ # python2
+ import urllib
+ import urlparse
+
+ urlencode = urllib.urlencode
+ quote = urllib.quote
+
+ parse = urlparse
+ parse_qsl = parse.parse_qsl
+ urljoin = parse.urljoin
+ urlparse = parse.urlparse
+ urlsplit = parse.urlsplit
+ urlunsplit = parse.urlunsplit
diff --git a/openstack/common/quota.py b/openstack/common/quota.py
new file mode 100644
index 0000000..43f8b01
--- /dev/null
+++ b/openstack/common/quota.py
@@ -0,0 +1,1175 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Common quotas"""
+
+import datetime
+
+from oslo.config import cfg
+
+from openstack.common.gettextutils import _ # noqa
+from openstack.common import importutils
+from openstack.common import log as logging
+from openstack.common import timeutils
+
+LOG = logging.getLogger(__name__)
+
+common_quota_opts = [
+ cfg.BoolOpt('use_default_quota_class',
+ default=True,
+ help='whether to use default quota class for default quota'),
+ cfg.StrOpt('quota_driver',
+ default='openstack.common.quota.DbQuotaDriver',
+ help='default driver to use for quota checks'),
+ cfg.IntOpt('until_refresh',
+ default=0,
+ help='count of reservations until usage is refreshed'),
+ cfg.IntOpt('max_age',
+ default=0,
+ help='number of seconds between subsequent usage refreshes'),
+ cfg.IntOpt('reservation_expire',
+ default=86400,
+ help='number of seconds until a reservation expires'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(common_quota_opts)
+
+
+class QuotaException(Exception):
+ """Base exception for quota.
+
+ To correctly use this class, inherit from it and define
+ a 'msg_fmt' property. That msg_fmt will get printf'd
+ with the keyword arguments provided to the constructor.
+
+ """
+ msg_fmt = _("Quota exception occurred.")
+ code = 500
+ headers = {}
+ safe = False
+
+ def __init__(self, message=None, **kwargs):
+ self.kwargs = {'code': self.code}
+ self.kwargs.update(kwargs)
+ if not message:
+ try:
+ message = self.msg_fmt % self.kwargs
+ except Exception:
+ # kwargs doesn't match a variable in the message
+ # log the issue and the kwargs
+ LOG.exception(_('Exception in string format operation'))
+ for name, value in kwargs.iteritems():
+ LOG.error("%s: %s" % (name, value))
+ # at least get the core message out if something happened
+ message = self.msg_fmt
+ super(QuotaException, self).__init__(message)
+
+ def format_message(self):
+ return unicode(self)
+
+
+class QuotaError(QuotaException):
+ msg_fmt = _("Quota exceeded") + ": code=%(code)s"
+ code = 413
+ headers = {'Retry-After': 0}
+ safe = True
+
+
+class InvalidQuotaValue(QuotaException):
+ msg_fmt = _("Change would make usage less than 0 for the following "
+ "resources: %(unders)s")
+
+
+class OverQuota(QuotaException):
+ msg_fmt = _("Quota exceeded for resources: %(overs)s")
+
+
+class QuotaExists(QuotaException):
+ message = _("Quota exists")
+
+
+class QuotaResourceUnknown(QuotaException):
+ msg_fmt = _("Unknown quota resources %(unknown)s.")
+
+
+class QuotaNotFound(QuotaException):
+ code = 404
+ message = _("Quota could not be found")
+
+
+class QuotaUsageNotFound(QuotaNotFound):
+ msg_fmt = _("Quota usage for project %(project_id)s could not be found.")
+
+
+class ProjectQuotaNotFound(QuotaNotFound):
+ msg_fmt = _("Quota for project %(project_id)s could not be found.")
+
+
+class ProjectUserQuotaNotFound(QuotaNotFound):
+ msg_fmt = _("Quota for user %(user_id)s in project %(project_id)s "
+ "could not be found.")
+
+
+class QuotaClassNotFound(QuotaNotFound):
+ msg_fmt = _("Quota class %(class_name)s could not be found.")
+
+
+class ReservationNotFound(QuotaNotFound):
+ msg_fmt = _("Quota reservation %(uuid)s could not be found.")
+
+
+class InvalidReservationExpiration(QuotaException):
+ code = 400
+ msg_fmt = _("Invalid reservation expiration %(expire)s.")
+
+
+class DbQuotaDriver(object):
+ """Database quota driver.
+
+ Driver to perform necessary checks to enforce quotas and obtain
+ quota information. The default driver utilizes the local
+ database.
+ """
+
+ def __init__(self, db):
+ self.db = db
+
+ def get_by_project_and_user(self, context, project_id, user_id, resource):
+ """Get a specific quota by project and user."""
+
+ return self.db.quota_get(context, project_id, user_id, resource)
+
+ def get_by_project(self, context, project_id, resource_name):
+ """Get a specific quota by project."""
+
+ return self.db.quota_get(context, project_id, resource_name)
+
+ def get_by_class(self, context, quota_class, resource_name):
+ """Get a specific quota by quota class."""
+
+ return self.db.quota_class_get(context, quota_class, resource_name)
+
+ def get_default(self, context, resource):
+ """Get a specific default quota for a resource."""
+
+ default_quotas = self.db.quota_class_get_default(context)
+ return default_quotas.get(resource.name, resource.default)
+
+ def get_defaults(self, context, resources):
+ """Given a list of resources, retrieve the default quotas.
+
+ Use the class quotas named `_DEFAULT_QUOTA_NAME` as default quotas,
+ if it exists.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ """
+
+ quotas = {}
+ default_quotas = {}
+ if CONF.use_default_quota_class:
+ default_quotas = self.db.quota_class_get_default(context)
+ for resource in resources.values():
+ if resource.name not in default_quotas:
+ LOG.deprecated(_("Default quota for resource: %(res)s is set "
+ "by the default quota flag: quota_%(res)s, "
+ "it is now deprecated. Please use the "
+ "the default quota class for default "
+ "quota.") % {'res': resource.name})
+ quotas[resource.name] = default_quotas.get(resource.name,
+ resource.default)
+
+ return quotas
+
+ def get_class_quotas(self, context, resources, quota_class,
+ defaults=True):
+ """Given a list of resources, get quotas for the given quota class.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ :param quota_class: The name of the quota class to return
+ quotas for.
+ :param defaults: If True, the default value will be reported
+ if there is no specific value for the
+ resource.
+ """
+
+ quotas = {}
+ default_quotas = {}
+ class_quotas = self.db.quota_class_get_all_by_name(context,
+ quota_class)
+ if defaults:
+ default_quotas = self.db.quota_class_get_default(context)
+ for resource in resources.values():
+ if resource.name in class_quotas:
+ quotas[resource.name] = class_quotas[resource.name]
+ continue
+
+ if defaults:
+ quotas[resource.name] = default_quotas.get(resource.name,
+ resource.default)
+
+ return quotas
+
    def _process_quotas(self, context, resources, project_id, quotas,
                        quota_class=None, defaults=True, usages=None,
                        remains=False):
        """Get the quotas for the appropriate class.

        If the project ID matches the one in the context, we use the
        quota_class from the context, otherwise, we use the provided
        quota_class (if any)

        :param context: The request context, for access checks.
        :param resources: A dictionary of the registered resources.
        :param project_id: The ID of the project the quotas apply to.
        :param quotas: Per-project (or per-user) quota overrides, keyed
                       by resource name.
        :param quota_class: Quota class consulted when there is no
                            override; ignored when project_id matches
                            the context's project.
        :param defaults: If True, resources with no override are still
                         reported using the class/default value.
        :param usages: Optional mapping of usage records keyed by
                       resource name; when given, in_use/reserved are
                       included in the result.
        :param remains: If True, include a 'remains' entry per resource:
                        the limit minus all per-project hard limits
                        already recorded for it.
        :returns: A dict keyed by resource name; each value is a dict
                  with at least a 'limit' key.
        """

        modified_quotas = {}
        # The caller's own quota class takes precedence for its own
        # project.
        if project_id == context.project_id:
            quota_class = context.quota_class
        if quota_class:
            class_quotas = self.db.quota_class_get_all_by_name(context,
                                                               quota_class)
        else:
            class_quotas = {}

        default_quotas = self.get_defaults(context, resources)

        for resource in resources.values():
            # Omit default/quota class values
            if not defaults and resource.name not in quotas:
                continue
            # Precedence: explicit override, then quota class, then
            # registered default.
            class_quota = class_quotas.get(resource.name,
                                           default_quotas[resource.name])
            limit = quotas.get(resource.name, class_quota)
            modified_quotas[resource.name] = dict(limit=limit)

            # Include usages if desired.  This is optional because one
            # internal consumer of this interface wants to access the
            # usages directly from inside a transaction.
            if usages:
                usage = usages.get(resource.name, {})
                modified_quotas[resource.name].update(
                    in_use=usage.get('in_use', 0),
                    reserved=usage.get('reserved', 0),
                )
            # Initialize remains quotas.
            if remains:
                modified_quotas[resource.name].update(remains=limit)

        if remains:
            # Subtract every existing per-project hard limit from the
            # 'remains' value seeded with the limit above.
            all_quotas = self.db.quota_get_all(context, project_id)
            for quota in all_quotas:
                if quota['resource'] in modified_quotas:
                    modified_quotas[quota['resource']]['remains'] -= \
                        quota['hard_limit']

        return modified_quotas
+
+ def get_user_quotas(self, context, resources, project_id, user_id,
+ quota_class=None, defaults=True,
+ usages=True):
+ """Get user quotas for given user and project.
+
+ Given a list of resources, retrieve the quotas for the given
+ user and project.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ :param project_id: The ID of the project to return quotas for.
+ :param user_id: The ID of the user to return quotas for.
+ :param quota_class: If project_id != context.project_id, the
+ quota class cannot be determined. This
+ parameter allows it to be specified. It
+ will be ignored if project_id ==
+ context.project_id.
+ :param defaults: If True, the quota class value (or the
+ default value, if there is no value from the
+ quota class) will be reported if there is no
+ specific value for the resource.
+ :param usages: If True, the current in_use and reserved counts
+ will also be returned.
+ """
+ user_quotas = self.db.quota_get_all_by_project_and_user(
+ context, project_id, user_id)
+ user_usages = None
+ if usages:
+ user_usages = self.db.quota_usage_get_all_by_project_and_user(
+ context, project_id, user_id)
+ return self._process_quotas(context, resources, project_id,
+ user_quotas, quota_class,
+ defaults=defaults, usages=user_usages)
+
+ def get_project_quotas(self, context, resources, project_id,
+ quota_class=None, defaults=True,
+ usages=True, remains=False):
+ """Given a list of resources, get the quotas for the given project.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ :param project_id: The ID of the project to return quotas for.
+ :param quota_class: If project_id != context.project_id, the
+ quota class cannot be determined. This
+ parameter allows it to be specified. It
+ will be ignored if project_id ==
+ context.project_id.
+ :param defaults: If True, the quota class value (or the
+ default value, if there is no value from the
+ quota class) will be reported if there is no
+ specific value for the resource.
+ :param usages: If True, the current in_use and reserved counts
+ will also be returned.
+ :param remains: If True, the current remains of the project will
+ will be returned.
+ """
+ project_quotas = self.db.quota_get_all_by_project(
+ context, project_id)
+ project_usages = None
+ if usages:
+ project_usages = self.db.quota_usage_get_all_by_project(
+ context, project_id)
+ return self._process_quotas(context, resources, project_id,
+ project_quotas, quota_class,
+ defaults=defaults, usages=project_usages,
+ remains=remains)
+
    def get_settable_quotas(self, context, resources, project_id,
                            user_id=None):
        """Get settable quotas for given user and project.

        Given a list of resources, retrieve the range of settable quotas for
        the given user or project.

        :param context: The request context, for access checks.
        :param resources: A dictionary of the registered resources.
        :param project_id: The ID of the project to return quotas for.
        :param user_id: The ID of the user to return quotas for.  When
                        omitted, project-level ranges are computed.
        :returns: A dict keyed by resource name; each value has
                  'minimum' and 'maximum' keys.  maximum == -1 means
                  unlimited.
        """
        settable_quotas = {}
        # remains=True makes each entry carry how much of the project
        # limit is not yet allotted to individual users.
        project_quotas = self.get_project_quotas(context, resources,
                                                 project_id, remains=True)
        if user_id:
            user_quotas = self.get_user_quotas(context, resources,
                                               project_id, user_id)
            # Quotas explicitly set for this user; their current value
            # is added back since re-setting them frees that amount.
            setted_quotas = self.db.quota_get_all_by_project_and_user(
                context, project_id, user_id)
            for key, value in user_quotas.items():
                maximum = project_quotas[key]['remains'] +\
                    setted_quotas.get(key, 0)
                # A user quota can't go below what is already consumed.
                settable_quotas[key] = dict(
                    minimum=value['in_use'] + value['reserved'],
                    maximum=maximum,
                )
        else:
            for key, value in project_quotas.items():
                # The project limit can't drop below either the amount
                # already handed out to users or the amount in use.
                minimum = max(int(value['limit'] - value['remains']),
                              int(value['in_use'] + value['reserved']))
                settable_quotas[key] = dict(minimum=minimum, maximum=-1)
        return settable_quotas
+
+ def _get_quotas(self, context, resources, keys, has_sync, project_id=None,
+ user_id=None):
+ """Get guotas for resources identified by keys.
+
+ A helper method which retrieves the quotas for the specific
+ resources identified by keys, and which apply to the current
+ context.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ :param keys: A list of the desired quotas to retrieve.
+ :param has_sync: If True, indicates that the resource must
+ have a sync attribute; if False, indicates
+ that the resource must NOT have a sync
+ attribute.
+ :param project_id: Specify the project_id if current context
+ is admin and admin wants to impact on
+ common user's tenant.
+ :param user_id: Specify the user_id if current context
+ is admin and admin wants to impact on
+ common user.
+ """
+
+ # Filter resources
+ if has_sync:
+ sync_filt = lambda x: hasattr(x, 'sync')
+ else:
+ sync_filt = lambda x: not hasattr(x, 'sync')
+ desired = set(keys)
+ sub_resources = dict((k, v) for k, v in resources.items()
+ if k in desired and sync_filt(v))
+
+ # Make sure we accounted for all of them...
+ if len(keys) != len(sub_resources):
+ unknown = desired - set(sub_resources.keys())
+ raise QuotaResourceUnknown(unknown=sorted(unknown))
+
+ if user_id:
+ # Grab and return the quotas (without usages)
+ quotas = self.get_user_quotas(context, sub_resources,
+ project_id, user_id,
+ context.quota_class, usages=False)
+ else:
+ # Grab and return the quotas (without usages)
+ quotas = self.get_project_quotas(context, sub_resources,
+ project_id,
+ context.quota_class,
+ usages=False)
+
+ return dict((k, v['limit']) for k, v in quotas.items())
+
+ def limit_check(self, context, resources, values, project_id=None,
+ user_id=None):
+ """Check simple quota limits.
+
+ For limits--those quotas for which there is no usage
+ synchronization function--this method checks that a set of
+ proposed values are permitted by the limit restriction.
+
+ This method will raise a QuotaResourceUnknown exception if a
+ given resource is unknown or if it is not a simple limit
+ resource.
+
+ If any of the proposed values is over the defined quota, an
+ OverQuota exception will be raised with the sorted list of the
+ resources which are too high. Otherwise, the method returns
+ nothing.
+
+ :param context: The request context, for access checks.
+ :param resources: A dictionary of the registered resources.
+ :param values: A dictionary of the values to check against the
+ quota.
+ :param project_id: Specify the project_id if current context
+ is admin and admin wants to impact on
+ common user's tenant.
+ :param user_id: Specify the user_id if current context
+ is admin and admin wants to impact on
+ common user.
+ """
+
+ # Ensure no value is less than zero
+ unders = [key for key, val in values.items() if val < 0]
+ if unders:
+ raise InvalidQuotaValue(unders=sorted(unders))
+
+ # If project_id is None, then we use the project_id in context
+ if project_id is None:
+ project_id = context.project_id
+ # If user id is None, then we use the user_id in context
+ if user_id is None:
+ user_id = context.user_id
+
+ # Get the applicable quotas
+ quotas = self._get_quotas(context, resources, values.keys(),
+ has_sync=False, project_id=project_id)
+ user_quotas = self._get_quotas(context, resources, values.keys(),
+ has_sync=False, project_id=project_id,
+ user_id=user_id)
+
+ # Check the quotas and construct a list of the resources that
+ # would be put over limit by the desired values
+ overs = [key for key, val in values.items()
+ if (quotas[key] >= 0 and quotas[key] < val) or
+ (user_quotas[key] >= 0 and user_quotas[key] < val)]
+ if overs:
+ raise OverQuota(overs=sorted(overs), quotas=quotas,
+ usages={})
+
    def reserve(self, context, resources, deltas, expire=None,
                project_id=None, user_id=None):
        """Check quotas and reserve resources.

        For counting quotas--those quotas for which there is a usage
        synchronization function--this method checks quotas against
        current usage and the desired deltas.

        This method will raise a QuotaResourceUnknown exception if a
        given resource is unknown or if it does not have a usage
        synchronization function.

        If any of the proposed values is over the defined quota, an
        OverQuota exception will be raised with the sorted list of the
        resources which are too high.  Otherwise, the method returns a
        list of reservation UUIDs which were created.

        :param context: The request context, for access checks.
        :param resources: A dictionary of the registered resources.
        :param deltas: A dictionary of the proposed delta changes.
        :param expire: An optional parameter specifying an expiration
                       time for the reservations.  If it is a simple
                       number, it is interpreted as a number of
                       seconds and added to the current time; if it is
                       a datetime.timedelta object, it will also be
                       added to the current time.  A datetime.datetime
                       object will be interpreted as the absolute
                       expiration time.  If None is specified, the
                       default expiration time set by
                       --default-reservation-expire will be used (this
                       value will be treated as a number of seconds).
        :param project_id: Specify the project_id if current context
                           is admin and admin wants to impact on
                           common user's tenant.
        :param user_id: Specify the user_id if current context
                        is admin and admin wants to impact on
                        common user.
        """

        # Set up the reservation expiration: normalize number ->
        # timedelta -> absolute datetime; anything else is invalid.
        if expire is None:
            expire = CONF.reservation_expire
        if isinstance(expire, (int, long)):
            expire = datetime.timedelta(seconds=expire)
        if isinstance(expire, datetime.timedelta):
            expire = timeutils.utcnow() + expire
        if not isinstance(expire, datetime.datetime):
            raise InvalidReservationExpiration(expire=expire)

        # If project_id is None, then we use the project_id in context
        if project_id is None:
            project_id = context.project_id
        # If user_id is None, then we use the user_id in context
        if user_id is None:
            user_id = context.user_id

        # Get the applicable quotas.
        # NOTE(Vek): We're not worried about races at this point.
        #            Yes, the admin may be in the process of reducing
        #            quotas, but that's a pretty rare thing.
        quotas = self._get_quotas(context, resources, deltas.keys(),
                                  has_sync=True, project_id=project_id)
        user_quotas = self._get_quotas(context, resources, deltas.keys(),
                                       has_sync=True, project_id=project_id,
                                       user_id=user_id)

        # NOTE(Vek): Most of the work here has to be done in the DB
        #            API, because we have to do it in a transaction,
        #            which means access to the session.  Since the
        #            session isn't available outside the DBAPI, we
        #            have to do the work there.
        return self.db.quota_reserve(context, resources, quotas, user_quotas,
                                     deltas, expire,
                                     CONF.until_refresh, CONF.max_age,
                                     project_id=project_id, user_id=user_id)
+
+ def commit(self, context, reservations, project_id=None, user_id=None):
+ """Commit reservations.
+
+ :param context: The request context, for access checks.
+ :param reservations: A list of the reservation UUIDs, as
+ returned by the reserve() method.
+ :param project_id: Specify the project_id if current context
+ is admin and admin wants to impact on
+ common user's tenant.
+ :param user_id: Specify the user_id if current context
+ is admin and admin wants to impact on
+ common user.
+ """
+ # If project_id is None, then we use the project_id in context
+ if project_id is None:
+ project_id = context.project_id
+ # If user_id is None, then we use the user_id in context
+ if user_id is None:
+ user_id = context.user_id
+
+ self.db.reservation_commit(context, reservations,
+ project_id=project_id, user_id=user_id)
+
+ def rollback(self, context, reservations, project_id=None, user_id=None):
+ """Roll back reservations.
+
+ :param context: The request context, for access checks.
+ :param reservations: A list of the reservation UUIDs, as
+ returned by the reserve() method.
+ :param project_id: Specify the project_id if current context
+ is admin and admin wants to impact on
+ common user's tenant.
+ :param user_id: Specify the user_id if current context
+ is admin and admin wants to impact on
+ common user.
+ """
+ # If project_id is None, then we use the project_id in context
+ if project_id is None:
+ project_id = context.project_id
+ # If user_id is None, then we use the user_id in context
+ if user_id is None:
+ user_id = context.user_id
+
+ self.db.reservation_rollback(context, reservations,
+ project_id=project_id, user_id=user_id)
+
+ def usage_reset(self, context, resources):
+ """Reset the usage records.
+
+ Reset usages for a particular user on a list of resources.
+ This will force that user's usage records to be refreshed
+ the next time a reservation is made.
+
+ Note: this does not affect the currently outstanding
+ reservations the user has; those reservations must be
+ committed or rolled back (or expired).
+
+ :param context: The request context, for access checks.
+ :param resources: A list of the resource names for which the
+ usage must be reset.
+ """
+
+ # We need an elevated context for the calls to
+ # quota_usage_update()
+ elevated = context.elevated()
+
+ for resource in resources:
+ try:
+ # Reset the usage to -1, which will force it to be
+ # refreshed
+ self.db.quota_usage_update(elevated, context.project_id,
+ context.user_id,
+ resource, in_use=-1)
+ except QuotaUsageNotFound:
+ # That means it'll be refreshed anyway
+ pass
+
+ def destroy_all_by_project_and_user(self, context, project_id, user_id):
+ """Destroy objects by project and user.
+
+ Destroy all quotas, usages, and reservations associated with a
+ project and user.
+
+ :param context: The request context, for access checks.
+ :param project_id: The ID of the project being deleted.
+ :param user_id: The ID of the user being deleted.
+ """
+
+ self.db.quota_destroy_all_by_project_and_user(context, project_id,
+ user_id)
+
+ def destroy_all_by_project(self, context, project_id):
+ """Destroy quotas, usages, and reservations for given project.
+
+ :param context: The request context, for access checks.
+ :param project_id: The ID of the project being deleted.
+ """
+
+ self.db.quota_destroy_all_by_project(context, project_id)
+
+ def expire(self, context):
+ """Expire reservations.
+
+ Explores all currently existing reservations and rolls back
+ any that have expired.
+
+ :param context: The request context, for access checks.
+ """
+
+ self.db.reservation_expire(context)
+
+
class BaseResource(object):
    """Describe a single resource for quota checking."""

    def __init__(self, name, flag=None):
        """Initializes a Resource.

        :param name: The name of the resource, i.e., "volumes".
        :param flag: The name of the flag or configuration option
                     which specifies the default value of the quota
                     for this resource.
        """
        self.name = name
        self.flag = flag

    def quota(self, driver, context, **kwargs):
        """Given a driver and context, obtain the quota for this resource.

        Resolution order: project-specific quota, then quota
        class-specific quota, then the driver's default.

        :param driver: A quota driver.
        :param context: The request context.
        :param project_id: The project to obtain the quota value for.
                           If not provided, it is taken from the
                           context.  If given as None, no
                           project-specific quota will be searched for.
        :param quota_class: The quota class corresponding to the
                            project, or for which the quota is to be
                            looked up.  If not provided, it is taken
                            from the context.  If given as None, no
                            quota class-specific quota will be searched
                            for.  Note that the quota class defaults to
                            the value in the context, which may not
                            correspond to the project if project_id is
                            not the same as the one in the context.
        """
        project_id = kwargs.get('project_id', context.project_id)
        quota_class = kwargs.get('quota_class', context.quota_class)

        # Most specific first: the project's own quota...
        if project_id:
            try:
                return driver.get_by_project(context, project_id, self.name)
            except ProjectQuotaNotFound:
                pass

        # ...then the quota class...
        if quota_class:
            try:
                return driver.get_by_class(context, quota_class, self.name)
            except QuotaClassNotFound:
                pass

        # ...and finally the driver default.
        return driver.get_default(context, self)

    @property
    def default(self):
        """Return the default value of the quota (-1 when no flag is set)."""
        return CONF[self.flag] if self.flag else -1
+
+
class ReservableResource(BaseResource):
    """Describe a reservable resource."""

    def __init__(self, name, sync, flag=None):
        """Initializes a ReservableResource.

        Reservable resources are those resources which directly
        correspond to objects in the database, i.e., instances,
        cores, etc.

        Usage synchronization function must be associated with each
        object.  This function will be called to determine the current
        counts of one or more resources.  This association is done in
        database backend.  See QUOTA_SYNC_FUNCTIONS in
        db/sqlalchemy/api.py.

        The usage synchronization function will be passed three
        arguments: an admin context, the project ID, and an opaque
        session object, which should in turn be passed to the
        underlying database function.  Synchronization functions
        should return a dictionary mapping resource names to the
        current in_use count for those resources; more than one
        resource and resource count may be returned.  Note that
        synchronization functions may be associated with more than one
        ReservableResource.

        :param name: The name of the resource, i.e., "volumes".
        :param sync: A dbapi methods name which returns a dictionary
                     to resynchronize the in_use count for one or more
                     resources, as described above.
        :param flag: The name of the flag or configuration option
                     which specifies the default value of the quota
                     for this resource.
        """
        super(ReservableResource, self).__init__(name, flag=flag)
        # The presence of this attribute is what marks a resource as
        # reservable (see the has_sync filter in the driver).
        self.sync = sync
+
+
class AbsoluteResource(BaseResource):
    """Describe a non-reservable resource.

    Behaves exactly like BaseResource; the distinct type exists so
    non-reservable resources can be told apart from reservable ones.
    """
+
+
class CountableResource(AbsoluteResource):
    """Countable resource.

    Describe a resource where the counts aren't based solely on the
    project ID.
    """

    def __init__(self, name, count, flag=None):
        """Initializes a CountableResource.

        Countable resources are those resources which directly
        correspond to objects in the database, i.e., volumes,
        gigabytes, etc., but for which a count by project ID is
        inappropriate.  A CountableResource must be constructed with a
        counting function, which will be called to determine the
        current counts of the resource.

        The counting function will be passed the context, along with
        the extra positional and keyword arguments that are passed to
        Quota.count().  It should return an integer specifying the
        count.

        Note that this counting is not performed in a transaction-safe
        manner.  This resource class is a temporary measure to provide
        required functionality, until a better approach to solving
        this problem can be evolved.

        :param name: The name of the resource, i.e., "volumes".
        :param count: A callable which returns the count of the
                      resource.  The arguments passed are as described
                      above.
        :param flag: The name of the flag or configuration option
                     which specifies the default value of the quota
                     for this resource.
        """
        super(CountableResource, self).__init__(name, flag=flag)
        # QuotaEngine.count() dispatches to this callable.
        self.count = count
+
+
class QuotaEngine(object):
    """Represent the set of recognized quotas."""

    def __init__(self, db, quota_driver_class=None):
        """Initialize a Quota object.

        :param db: The database API used by the quota driver.
        :param quota_driver_class: The quota driver to use: either an
                                   already-constructed driver object or
                                   an importable class path string.
                                   Defaults to the quota_driver config
                                   option.
        """
        self.db = db
        self._resources = {}
        self._driver_cls = quota_driver_class
        self.__driver = None

    @property
    def _driver(self):
        # Lazily resolve the driver: an already-constructed driver is
        # used as-is; a class path string is imported and instantiated
        # with the db handle.
        if self.__driver:
            return self.__driver
        if not self._driver_cls:
            self._driver_cls = CONF.quota_driver
        if isinstance(self._driver_cls, basestring):
            self._driver_cls = importutils.import_object(self._driver_cls,
                                                         self.db)
        self.__driver = self._driver_cls
        return self.__driver

    def __contains__(self, resource):
        # Membership is by resource name.
        return resource in self._resources

    def register_resource(self, resource):
        """Register a resource."""

        self._resources[resource.name] = resource

    def register_resources(self, resources):
        """Register a list of resources."""

        for resource in resources:
            self.register_resource(resource)

    def get_by_project_and_user(self, context, project_id, user_id, resource):
        """Get a specific quota by project and user."""

        return self._driver.get_by_project_and_user(context, project_id,
                                                    user_id, resource)

    def get_by_project(self, context, project_id, resource_name):
        """Get a specific quota by project."""

        return self._driver.get_by_project(context, project_id, resource_name)

    def get_by_class(self, context, quota_class, resource_name):
        """Get a specific quota by quota class."""

        return self._driver.get_by_class(context, quota_class, resource_name)

    def get_default(self, context, resource):
        """Get a specific default quota for a resource."""

        return self._driver.get_default(context, resource)

    def get_defaults(self, context):
        """Retrieve the default quotas.

        :param context: The request context, for access checks.
        """

        return self._driver.get_defaults(context, self._resources)

    def get_class_quotas(self, context, quota_class, defaults=True):
        """Retrieve the quotas for the given quota class.

        :param context: The request context, for access checks.
        :param quota_class: The name of the quota class to return
                            quotas for.
        :param defaults: If True, the default value will be reported
                         if there is no specific value for the
                         resource.
        """

        return self._driver.get_class_quotas(context, self._resources,
                                             quota_class, defaults=defaults)

    def get_user_quotas(self, context, project_id, user_id, quota_class=None,
                        defaults=True, usages=True):
        """Retrieve the quotas for the given user and project.

        :param context: The request context, for access checks.
        :param project_id: The ID of the project to return quotas for.
        :param user_id: The ID of the user to return quotas for.
        :param quota_class: If project_id != context.project_id, the
                            quota class cannot be determined.  This
                            parameter allows it to be specified.
        :param defaults: If True, the quota class value (or the
                         default value, if there is no value from the
                         quota class) will be reported if there is no
                         specific value for the resource.
        :param usages: If True, the current in_use and reserved counts
                       will also be returned.
        """

        return self._driver.get_user_quotas(context, self._resources,
                                            project_id, user_id,
                                            quota_class=quota_class,
                                            defaults=defaults,
                                            usages=usages)

    def get_project_quotas(self, context, project_id, quota_class=None,
                           defaults=True, usages=True, remains=False):
        """Retrieve the quotas for the given project.

        :param context: The request context, for access checks.
        :param project_id: The ID of the project to return quotas for.
        :param quota_class: If project_id != context.project_id, the
                            quota class cannot be determined.  This
                            parameter allows it to be specified.
        :param defaults: If True, the quota class value (or the
                         default value, if there is no value from the
                         quota class) will be reported if there is no
                         specific value for the resource.
        :param usages: If True, the current in_use and reserved counts
                       will also be returned.
        :param remains: If True, the current remains of the project
                        will be returned.
        """

        return self._driver.get_project_quotas(context, self._resources,
                                               project_id,
                                               quota_class=quota_class,
                                               defaults=defaults,
                                               usages=usages,
                                               remains=remains)

    def get_settable_quotas(self, context, project_id, user_id=None):
        """Get settable quotas for given user and project.

        Retrieve the range of settable quotas for the given user or
        project, over all registered resources.

        :param context: The request context, for access checks.
        :param project_id: The ID of the project to return quotas for.
        :param user_id: The ID of the user to return quotas for.
        """

        return self._driver.get_settable_quotas(context, self._resources,
                                                project_id,
                                                user_id=user_id)

    def count(self, context, resource, *args, **kwargs):
        """Count a resource.

        For countable resources, invokes the count() function and
        returns its result.  Arguments following the context and
        resource are passed directly to the count function declared by
        the resource.

        :param context: The request context, for access checks.
        :param resource: The name of the resource, as a string.
        """

        # Get the resource; it must be registered and countable.
        res = self._resources.get(resource)
        if not res or not hasattr(res, 'count'):
            raise QuotaResourceUnknown(unknown=[resource])

        return res.count(context, *args, **kwargs)

    def limit_check(self, context, project_id=None, user_id=None, **values):
        """Check simple quota limits.

        For limits--those quotas for which there is no usage
        synchronization function--this method checks that a set of
        proposed values are permitted by the limit restriction.  The
        values to check are given as keyword arguments, where the key
        identifies the specific quota limit to check, and the value is
        the proposed value.

        This method will raise a QuotaResourceUnknown exception if a
        given resource is unknown or if it is not a simple limit
        resource.

        If any of the proposed values is over the defined quota, an
        OverQuota exception will be raised with the sorted list of the
        resources which are too high.  Otherwise, the method returns
        nothing.

        :param context: The request context, for access checks.
        :param project_id: Specify the project_id if current context
                           is admin and admin wants to impact on
                           common user's tenant.
        :param user_id: Specify the user_id if current context
                        is admin and admin wants to impact on
                        common user.
        """

        return self._driver.limit_check(context, self._resources, values,
                                        project_id=project_id,
                                        user_id=user_id)

    def reserve(self, context, expire=None, project_id=None, user_id=None,
                **deltas):
        """Check quotas and reserve resources.

        For counting quotas--those quotas for which there is a usage
        synchronization function--this method checks quotas against
        current usage and the desired deltas.  The deltas are given as
        keyword arguments, and current usage and other reservations
        are factored into the quota check.

        This method will raise a QuotaResourceUnknown exception if a
        given resource is unknown or if it does not have a usage
        synchronization function.

        If any of the proposed values is over the defined quota, an
        OverQuota exception will be raised with the sorted list of the
        resources which are too high.  Otherwise, the method returns a
        list of reservation UUIDs which were created.

        :param context: The request context, for access checks.
        :param expire: An optional parameter specifying an expiration
                       time for the reservations.  If it is a simple
                       number, it is interpreted as a number of
                       seconds and added to the current time; if it is
                       a datetime.timedelta object, it will also be
                       added to the current time.  A datetime.datetime
                       object will be interpreted as the absolute
                       expiration time.  If None is specified, the
                       default expiration time set by
                       --default-reservation-expire will be used (this
                       value will be treated as a number of seconds).
        :param project_id: Specify the project_id if current context
                           is admin and admin wants to impact on
                           common user's tenant.
        :param user_id: Specify the user_id if current context
                        is admin and admin wants to impact on
                        common user.
        """

        reservations = self._driver.reserve(context, self._resources, deltas,
                                            expire=expire,
                                            project_id=project_id,
                                            user_id=user_id)

        LOG.debug(_("Created reservations %s"), reservations)

        return reservations

    def commit(self, context, reservations, project_id=None, user_id=None):
        """Commit reservations.

        :param context: The request context, for access checks.
        :param reservations: A list of the reservation UUIDs, as
                             returned by the reserve() method.
        :param project_id: Specify the project_id if current context
                           is admin and admin wants to impact on
                           common user's tenant.
        :param user_id: Specify the user_id if current context
                        is admin and admin wants to impact on
                        common user.
        """

        try:
            self._driver.commit(context, reservations, project_id=project_id,
                                user_id=user_id)
        except Exception:
            # NOTE(Vek): Ignoring exceptions here is safe, because the
            # usage resynchronization and the reservation expiration
            # mechanisms will resolve the issue.  The exception is
            # logged, however, because this is less than optimal.
            LOG.exception(_("Failed to commit reservations %s"), reservations)
            return
        LOG.debug(_("Committed reservations %s"), reservations)

    def rollback(self, context, reservations, project_id=None, user_id=None):
        """Roll back reservations.

        :param context: The request context, for access checks.
        :param reservations: A list of the reservation UUIDs, as
                             returned by the reserve() method.
        :param project_id: Specify the project_id if current context
                           is admin and admin wants to impact on
                           common user's tenant.
        :param user_id: Specify the user_id if current context
                        is admin and admin wants to impact on
                        common user.
        """

        try:
            self._driver.rollback(context, reservations,
                                  project_id=project_id,
                                  user_id=user_id)
        except Exception:
            # NOTE(Vek): Ignoring exceptions here is safe, because the
            # usage resynchronization and the reservation expiration
            # mechanisms will resolve the issue.  The exception is
            # logged, however, because this is less than optimal.
            LOG.exception(_("Failed to roll back reservations %s"),
                          reservations)
            return
        LOG.debug(_("Rolled back reservations %s"), reservations)

    def usage_reset(self, context, resources):
        """Reset the usage records.

        Reset usages for a particular user on a list of resources.
        This will force that user's usage records to be refreshed
        the next time a reservation is made.

        Note: this does not affect the currently outstanding
        reservations the user has; those reservations must be
        committed or rolled back (or expired).

        :param context: The request context, for access checks.
        :param resources: A list of the resource names for which the
                          usage must be reset.
        """

        self._driver.usage_reset(context, resources)

    def destroy_all_by_project_and_user(self, context, project_id, user_id):
        """Destroy all objects for given user and project.

        Destroy all quotas, usages, and reservations associated with a
        project and user.

        :param context: The request context, for access checks.
        :param project_id: The ID of the project being deleted.
        :param user_id: The ID of the user being deleted.
        """

        self._driver.destroy_all_by_project_and_user(context,
                                                     project_id, user_id)

    def destroy_all_by_project(self, context, project_id):
        """Destroy all quotas, usages, and reservations for given project.

        :param context: The request context, for access checks.
        :param project_id: The ID of the project being deleted.
        """

        self._driver.destroy_all_by_project(context, project_id)

    def expire(self, context):
        """Expire reservations.

        Explores all currently existing reservations and rolls back
        any that have expired.

        :param context: The request context, for access checks.
        """

        self._driver.expire(context)

    @property
    def resource_names(self):
        # Sorted for stable presentation to callers.
        return sorted(self._resources.keys())

    @property
    def resources(self):
        return self._resources
diff --git a/openstack/common/rootwrap/cmd.py b/openstack/common/rootwrap/cmd.py
index 500f6c9..473dafb 100755..100644
--- a/openstack/common/rootwrap/cmd.py
+++ b/openstack/common/rootwrap/cmd.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2011 OpenStack Foundation.
diff --git a/openstack/common/rpc/__init__.py b/openstack/common/rpc/__init__.py
index e39f294..104b059 100644
--- a/openstack/common/rpc/__init__.py
+++ b/openstack/common/rpc/__init__.py
@@ -56,8 +56,7 @@ rpc_opts = [
help='Seconds to wait before a cast expires (TTL). '
'Only supported by impl_zmq.'),
cfg.ListOpt('allowed_rpc_exception_modules',
- default=['openstack.common.exception',
- 'nova.exception',
+ default=['nova.exception',
'cinder.exception',
'exceptions',
],
diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py
index 1afd2ab..38f2515 100644
--- a/openstack/common/rpc/amqp.py
+++ b/openstack/common/rpc/amqp.py
@@ -300,8 +300,13 @@ def pack_context(msg, context):
for args at some point.
"""
- context_d = dict([('_context_%s' % key, value)
- for (key, value) in context.to_dict().iteritems()])
+ if isinstance(context, dict):
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.iteritems()])
+ else:
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.to_dict().iteritems()])
+
msg.update(context_d)
diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py
index 6b1ae93..61ab415 100644
--- a/openstack/common/rpc/impl_kombu.py
+++ b/openstack/common/rpc/impl_kombu.py
@@ -490,12 +490,8 @@ class Connection(object):
# future with this?
ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
- if not ssl_params:
- # Just have the default behavior
- return True
- else:
- # Return the extended behavior
- return ssl_params
+ # Return the extended behavior or just have the default behavior
+ return ssl_params or True
def _connect(self, params):
"""Connect to rabbit. Re-establish any queues that may have
diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py
index 99c4619..e54beb4 100644
--- a/openstack/common/rpc/impl_qpid.py
+++ b/openstack/common/rpc/impl_qpid.py
@@ -320,7 +320,7 @@ class DirectPublisher(Publisher):
def __init__(self, conf, session, msg_id):
"""Init a 'direct' publisher."""
super(DirectPublisher, self).__init__(session, msg_id,
- {"type": "Direct"})
+ {"type": "direct"})
class TopicPublisher(Publisher):
diff --git a/openstack/common/rpc/matchmaker.py b/openstack/common/rpc/matchmaker.py
index e80ab37..a94f542 100644
--- a/openstack/common/rpc/matchmaker.py
+++ b/openstack/common/rpc/matchmaker.py
@@ -248,9 +248,7 @@ class DirectBinding(Binding):
that it maps directly to a host, thus direct.
"""
def test(self, key):
- if '.' in key:
- return True
- return False
+ return '.' in key
class TopicBinding(Binding):
@@ -262,17 +260,13 @@ class TopicBinding(Binding):
matches that of a direct exchange.
"""
def test(self, key):
- if '.' not in key:
- return True
- return False
+ return '.' not in key
class FanoutBinding(Binding):
"""Match on fanout keys, where key starts with 'fanout.' string."""
def test(self, key):
- if key.startswith('fanout~'):
- return True
- return False
+ return key.startswith('fanout~')
class StubExchange(Exchange):
diff --git a/openstack/common/rpc/matchmaker_ring.py b/openstack/common/rpc/matchmaker_ring.py
index 45a095f..6b488ce 100644
--- a/openstack/common/rpc/matchmaker_ring.py
+++ b/openstack/common/rpc/matchmaker_ring.py
@@ -63,9 +63,7 @@ class RingExchange(mm.Exchange):
self.ring0[k] = itertools.cycle(self.ring[k])
def _ring_has(self, key):
- if key in self.ring0:
- return True
- return False
+ return key in self.ring0
class RoundRobinRingExchange(RingExchange):
diff --git a/openstack/common/rpc/securemessage.py b/openstack/common/rpc/securemessage.py
new file mode 100644
index 0000000..c5530a6
--- /dev/null
+++ b/openstack/common/rpc/securemessage.py
@@ -0,0 +1,521 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import base64
+import collections
+import os
+import struct
+import time
+
+import requests
+
+from oslo.config import cfg
+
+from openstack.common.crypto import utils as cryptoutils
+from openstack.common import jsonutils
+from openstack.common import log as logging
+
+secure_message_opts = [
+ cfg.BoolOpt('enabled', default=True,
+ help='Whether Secure Messaging (Signing) is enabled,'
+ ' defaults to enabled'),
+ cfg.BoolOpt('enforced', default=False,
+ help='Whether Secure Messaging (Signing) is enforced,'
+ ' defaults to not enforced'),
+ cfg.BoolOpt('encrypt', default=False,
+ help='Whether Secure Messaging (Encryption) is enabled,'
+ ' defaults to not enabled'),
+ cfg.StrOpt('secret_keys_file',
+ help='Path to the file containing the keys, takes precedence'
+ ' over secret_key'),
+ cfg.MultiStrOpt('secret_key',
+ help='A list of keys: (ex: name:<base64 encoded key>),'
+ ' ignored if secret_keys_file is set'),
+ cfg.StrOpt('kds_endpoint',
+ help='KDS endpoint (ex: http://kds.example.com:35357/v3)'),
+]
+secure_message_group = cfg.OptGroup('secure_messages',
+ title='Secure Messaging options')
+
+LOG = logging.getLogger(__name__)
+
+
+class SecureMessageException(Exception):
+ """Generic Exception for Secure Messages."""
+
+ msg = "An unknown Secure Message related exception occurred."
+
+ def __init__(self, msg=None):
+ if msg is None:
+ msg = self.msg
+ super(SecureMessageException, self).__init__(msg)
+
+
+class SharedKeyNotFound(SecureMessageException):
+ """No shared key was found and no other external authentication mechanism
+ is available.
+ """
+
+ msg = "Shared Key for [%s] Not Found. (%s)"
+
+ def __init__(self, name, errmsg):
+ super(SharedKeyNotFound, self).__init__(self.msg % (name, errmsg))
+
+
+class InvalidMetadata(SecureMessageException):
+ """The metadata is invalid."""
+
+ msg = "Invalid metadata: %s"
+
+ def __init__(self, err):
+ super(InvalidMetadata, self).__init__(self.msg % err)
+
+
+class InvalidSignature(SecureMessageException):
+ """Signature validation failed."""
+
+ msg = "Failed to validate signature (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidSignature, self).__init__(self.msg % (src, dst))
+
+
+class UnknownDestinationName(SecureMessageException):
+ """The Destination name is unknown to us."""
+
+ msg = "Invalid destination name (%s)"
+
+ def __init__(self, name):
+ super(UnknownDestinationName, self).__init__(self.msg % name)
+
+
+class InvalidEncryptedTicket(SecureMessageException):
+ """The Encrypted Ticket could not be successfully handled."""
+
+ msg = "Invalid Ticket (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidEncryptedTicket, self).__init__(self.msg % (src, dst))
+
+
+class InvalidExpiredTicket(SecureMessageException):
+ """The ticket received is already expired."""
+
+ msg = "Expired ticket (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidExpiredTicket, self).__init__(self.msg % (src, dst))
+
+
+class CommunicationError(SecureMessageException):
+ """The Communication with the KDS failed."""
+
+ msg = "Communication Error (target=%s): %s"
+
+ def __init__(self, target, errmsg):
+ super(CommunicationError, self).__init__(self.msg % (target, errmsg))
+
+
+class InvalidArgument(SecureMessageException):
+ """Bad initialization argument."""
+
+ msg = "Invalid argument: %s"
+
+ def __init__(self, errmsg):
+ super(InvalidArgument, self).__init__(self.msg % errmsg)
+
+
+Ticket = collections.namedtuple('Ticket', ['skey', 'ekey', 'esek'])
+
+
+class KeyStore(object):
+ """A storage class for Signing and Encryption Keys.
+
+ This class creates an object that holds Generic Keys like Signing
+ Keys, Encryption Keys, Encrypted SEK Tickets ...
+ """
+
+ def __init__(self):
+ self._kvps = dict()
+
+ def _get_key_name(self, source, target, ktype):
+ return (source, target, ktype)
+
+ def _put(self, src, dst, ktype, expiration, data):
+ name = self._get_key_name(src, dst, ktype)
+ self._kvps[name] = (expiration, data)
+
+ def _get(self, src, dst, ktype):
+ name = self._get_key_name(src, dst, ktype)
+ if name in self._kvps:
+ expiration, data = self._kvps[name]
+ if expiration > time.time():
+ return data
+ else:
+ del self._kvps[name]
+
+ return None
+
+ def clear(self):
+ """Wipes the store clear of all data."""
+ self._kvps.clear()
+
+ def put_ticket(self, source, target, skey, ekey, esek, expiration):
+ """Puts a sek pair in the cache.
+
+ :param source: Client name
+ :param target: Target name
+ :param skey: The Signing Key
+ :param ekey: The Encryption Key
+ :param esek: The token encrypted with the target key
+ :param expiration: Expiration time in seconds since Epoch
+ """
+ keys = Ticket(skey, ekey, esek)
+ self._put(source, target, 'ticket', expiration, keys)
+
+ def get_ticket(self, source, target):
+ """Returns a Ticket (skey, ekey, esek) namedtuple for the
+ source/target pair.
+ """
+ return self._get(source, target, 'ticket')
+
+
+_KEY_STORE = KeyStore()
+
+
+class _KDSClient(object):
+
+ USER_AGENT = 'oslo-incubator/rpc'
+
+ def __init__(self, endpoint=None, timeout=None):
+ """A KDS Client class."""
+
+ self._endpoint = endpoint
+ if timeout is not None:
+ self.timeout = float(timeout)
+ else:
+ self.timeout = None
+
+ def _do_get(self, url, request):
+ req_kwargs = dict()
+ req_kwargs['headers'] = dict()
+ req_kwargs['headers']['User-Agent'] = self.USER_AGENT
+ req_kwargs['headers']['Content-Type'] = 'application/json'
+ req_kwargs['data'] = jsonutils.dumps({'request': request})
+ if self.timeout is not None:
+ req_kwargs['timeout'] = self.timeout
+
+ try:
+ resp = requests.get(url, **req_kwargs)
+ except requests.ConnectionError as e:
+ err = "Unable to establish connection. %s" % e
+ raise CommunicationError(url, err)
+
+ return resp
+
+ def _get_reply(self, url, resp):
+ if resp.text:
+ try:
+ body = jsonutils.loads(resp.text)
+ reply = body['reply']
+ except (KeyError, TypeError, ValueError):
+ msg = "Failed to decode reply: %s" % resp.text
+ raise CommunicationError(url, msg)
+ else:
+ msg = "No reply data was returned."
+ raise CommunicationError(url, msg)
+
+ return reply
+
+ def _get_ticket(self, request, url=None, redirects=10):
+ """Send an HTTP request.
+
+ Wraps around 'requests' to handle redirects and common errors.
+ """
+ if url is None:
+ if not self._endpoint:
+ raise CommunicationError(url, 'Endpoint not configured')
+ url = self._endpoint + '/kds/ticket'
+
+ while redirects:
+ resp = self._do_get(url, request)
+ if resp.status_code in (301, 302, 305):
+ # Redirected. Reissue the request to the new location.
+ url = resp.headers['location']
+ redirects -= 1
+ continue
+ elif resp.status_code != 200:
+ msg = "Request returned failure status: %s (%s)"
+ err = msg % (resp.status_code, resp.text)
+ raise CommunicationError(url, err)
+
+ return self._get_reply(url, resp)
+
+ raise CommunicationError(url, "Too many redirections, giving up!")
+
+ def get_ticket(self, source, target, crypto, key):
+
+ # prepare metadata
+ md = {'requestor': source,
+ 'target': target,
+ 'timestamp': time.time(),
+ 'nonce': struct.unpack('Q', os.urandom(8))[0]}
+ metadata = base64.b64encode(jsonutils.dumps(md))
+
+ # sign metadata
+ signature = crypto.sign(key, metadata)
+
+ # HTTP request
+ reply = self._get_ticket({'metadata': metadata,
+ 'signature': signature})
+
+ # verify reply
+ signature = crypto.sign(key, (reply['metadata'] + reply['ticket']))
+ if signature != reply['signature']:
+ raise InvalidEncryptedTicket(md['source'], md['destination'])
+ md = jsonutils.loads(base64.b64decode(reply['metadata']))
+ if ((md['source'] != source or
+ md['destination'] != target or
+ md['expiration'] < time.time())):
+ raise InvalidEncryptedTicket(md['source'], md['destination'])
+
+ # return ticket data
+ tkt = jsonutils.loads(crypto.decrypt(key, reply['ticket']))
+
+ return tkt, md['expiration']
+
+
+# we need to keep a global nonce, as this value should never repeat no
+# matter how many SecureMessage objects we create
+_NONCE = None
+
+
+def _get_nonce():
+ """We keep a single counter per instance, as it is so huge we can't
+ possibly cycle through within 1/100 of a second anyway.
+ """
+
+ global _NONCE
+ # Lazy initialize, for now get a random value, multiply by 2^32 and
+ # use it as the nonce base. The counter itself will rotate after
+ # 2^32 increments.
+ if _NONCE is None:
+ _NONCE = [struct.unpack('I', os.urandom(4))[0], 0]
+
+ # Increment counter and wrap at 2^32
+ _NONCE[1] += 1
+ if _NONCE[1] > 0xffffffff:
+ _NONCE[1] = 0
+
+ # Return base + counter
+ return long((_NONCE[0] * 0xffffffff)) + _NONCE[1]
+
+
+class SecureMessage(object):
+ """A Secure Message object.
+
+ This class creates a signing/encryption facility for RPC messages.
+ It encapsulates all the necessary crypto primitives to insulate
+ regular code from the intricacies of message authentication, validation
+ and optionally encryption.
+
+ :param topic: The topic name of the queue
+ :param host: The server name, together with the topic it forms a unique
+ name that is used to source signing keys, and verify
+ incoming messages.
+ :param conf: a ConfigOpts object
+ :param key: (optional) explicitly pass in endpoint private key.
+ If not provided it will be sourced from the service config
+ :param key_store: (optional) Storage class for local caching
+ :param encrypt: (defaults to False) Whether to encrypt messages
+ :param enctype: (defaults to AES) Cipher to use
+ :param hashtype: (defaults to SHA256) Hash function to use for signatures
+ """
+
+ def __init__(self, topic, host, conf, key=None, key_store=None,
+ encrypt=None, enctype='AES', hashtype='SHA256'):
+
+ conf.register_group(secure_message_group)
+ conf.register_opts(secure_message_opts, group='secure_messages')
+
+ self._name = '%s.%s' % (topic, host)
+ self._key = key
+ self._conf = conf.secure_messages
+ self._encrypt = self._conf.encrypt if (encrypt is None) else encrypt
+ self._crypto = cryptoutils.SymmetricCrypto(enctype, hashtype)
+ self._hkdf = cryptoutils.HKDF(hashtype)
+ self._kds = _KDSClient(self._conf.kds_endpoint)
+
+ if self._key is None:
+ self._key = self._init_key(topic, self._name)
+ if self._key is None:
+ err = "Secret Key (or key file) is missing or malformed"
+ raise SharedKeyNotFound(self._name, err)
+
+ self._key_store = key_store or _KEY_STORE
+
+ def _init_key(self, topic, name):
+ keys = None
+ if self._conf.secret_keys_file:
+ with open(self._conf.secret_keys_file, 'r') as f:
+ keys = f.readlines()
+ elif self._conf.secret_key:
+ keys = self._conf.secret_key
+
+ if keys is None:
+ return None
+
+ for k in keys:
+ if k[0] == '#':
+ continue
+ if ':' not in k:
+ break
+ svc, key = k.split(':', 1)
+ if svc == topic or svc == name:
+ return base64.b64decode(key)
+
+ return None
+
+ def _split_key(self, key, size):
+ sig_key = key[:size]
+ enc_key = key[size:]
+ return sig_key, enc_key
+
+ def _decode_esek(self, key, source, target, timestamp, esek):
+ """This function decrypts the esek buffer passed in and returns a
+ KeyStore to be used to check and decrypt the received message.
+
+ :param key: The key to use to decrypt the ticket (esek)
+ :param source: The name of the source service
+ :param target: The name of the target service
+ :param timestamp: The incoming message timestamp
+ :param esek: a base64 encoded encrypted block containing a JSON string
+ """
+ rkey = None
+
+ try:
+ s = self._crypto.decrypt(key, esek)
+ j = jsonutils.loads(s)
+
+ rkey = base64.b64decode(j['key'])
+ expiration = j['timestamp'] + j['ttl']
+ if j['timestamp'] > timestamp or timestamp > expiration:
+ raise InvalidExpiredTicket(source, target)
+
+ except Exception:
+ raise InvalidEncryptedTicket(source, target)
+
+ info = '%s,%s,%s' % (source, target, str(j['timestamp']))
+
+ sek = self._hkdf.expand(rkey, info, len(key) * 2)
+
+ return self._split_key(sek, len(key))
+
+ def _get_ticket(self, target):
+ """This function will check if we already have a SEK for the specified
+ target in the cache, or will go and try to fetch a new SEK from the key
+ server.
+
+ :param target: The name of the target service
+ """
+ ticket = self._key_store.get_ticket(self._name, target)
+
+ if ticket is not None:
+ return ticket
+
+ tkt, expiration = self._kds.get_ticket(self._name, target,
+ self._crypto, self._key)
+
+ self._key_store.put_ticket(self._name, target,
+ base64.b64decode(tkt['skey']),
+ base64.b64decode(tkt['ekey']),
+ tkt['esek'], expiration)
+ return self._key_store.get_ticket(self._name, target)
+
+ def encode(self, version, target, json_msg):
+ """This is the main encoding function.
+
+ It takes a target and a message and returns a tuple consisting of a
+ JSON serialized metadata object, a JSON serialized (and optionally
+ encrypted) message, and a signature.
+
+ :param version: the current envelope version
+ :param target: The name of the target service (usually with hostname)
+ :param json_msg: a serialized json message object
+ """
+ ticket = self._get_ticket(target)
+
+ metadata = jsonutils.dumps({'source': self._name,
+ 'destination': target,
+ 'timestamp': time.time(),
+ 'nonce': _get_nonce(),
+ 'esek': ticket.esek,
+ 'encryption': self._encrypt})
+
+ message = json_msg
+ if self._encrypt:
+ message = self._crypto.encrypt(ticket.ekey, message)
+
+ signature = self._crypto.sign(ticket.skey,
+ version + metadata + message)
+
+ return (metadata, message, signature)
+
+ def decode(self, version, metadata, message, signature):
+ """This is the main decoding function.
+
+ It takes a version, metadata, message and signature strings and
+ returns a tuple with a (decrypted) message and metadata or raises
+ an exception in case of error.
+
+ :param version: the current envelope version
+ :param metadata: a JSON serialized object with metadata for validation
+ :param message: a JSON serialized (base64 encoded encrypted) message
+ :param signature: a base64 encoded signature
+ """
+ md = jsonutils.loads(metadata)
+
+ check_args = ('source', 'destination', 'timestamp',
+ 'nonce', 'esek', 'encryption')
+ for arg in check_args:
+ if arg not in md:
+ raise InvalidMetadata('Missing metadata "%s"' % arg)
+
+ if md['destination'] != self._name:
+ # TODO(simo) handle group keys by checking target
+ raise UnknownDestinationName(md['destination'])
+
+ try:
+ skey, ekey = self._decode_esek(self._key,
+ md['source'], md['destination'],
+ md['timestamp'], md['esek'])
+ except InvalidExpiredTicket:
+ raise
+ except Exception:
+ raise InvalidMetadata('Failed to decode ESEK for %s/%s' % (
+ md['source'], md['destination']))
+
+ sig = self._crypto.sign(skey, version + metadata + message)
+
+ if sig != signature:
+ raise InvalidSignature(md['source'], md['destination'])
+
+ if md['encryption'] is True:
+ msg = self._crypto.decrypt(ekey, message)
+ else:
+ msg = message
+
+ return (md, msg)
diff --git a/openstack/common/rpc/zmq_receiver.py b/openstack/common/rpc/zmq_receiver.py
index e74da22..000a7dd 100755..100644
--- a/openstack/common/rpc/zmq_receiver.py
+++ b/openstack/common/rpc/zmq_receiver.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack Foundation
diff --git a/openstack/common/service.py b/openstack/common/service.py
index cb71af2..6da9751 100644
--- a/openstack/common/service.py
+++ b/openstack/common/service.py
@@ -81,6 +81,15 @@ class Launcher(object):
"""
self.services.wait()
+ def restart(self):
+ """Reload config files and restart service.
+
+ :returns: None
+
+ """
+ cfg.CONF.reload_config_files()
+ self.services.restart()
+
class SignalExit(SystemExit):
def __init__(self, signo, exccode=1):
@@ -93,24 +102,31 @@ class ServiceLauncher(Launcher):
# Allow the process to be killed again and die from natural causes
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
raise SignalExit(signo)
- def wait(self):
+ def handle_signal(self):
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
+ signal.signal(signal.SIGHUP, self._handle_signal)
+
+ def _wait_for_exit_or_signal(self):
+ status = None
+ signo = 0
LOG.debug(_('Full set of CONF:'))
CONF.log_opt_values(LOG, std_logging.DEBUG)
- status = None
try:
super(ServiceLauncher, self).wait()
except SignalExit as exc:
signame = {signal.SIGTERM: 'SIGTERM',
- signal.SIGINT: 'SIGINT'}[exc.signo]
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[exc.signo]
LOG.info(_('Caught %s, exiting'), signame)
status = exc.code
+ signo = exc.signo
except SystemExit as exc:
status = exc.code
finally:
@@ -121,7 +137,16 @@ class ServiceLauncher(Launcher):
except Exception:
# We're shutting down, so it doesn't matter at this point.
LOG.exception(_('Exception during rpc cleanup.'))
- return status
+
+ return status, signo
+
+ def wait(self):
+ while True:
+ self.handle_signal()
+ status, signo = self._wait_for_exit_or_signal()
+ if signo != signal.SIGHUP:
+ return status
+ self.restart()
class ServiceWrapper(object):
@@ -139,9 +164,12 @@ class ProcessLauncher(object):
self.running = True
rfd, self.writepipe = os.pipe()
self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r')
+ self.handle_signal()
+ def handle_signal(self):
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
+ signal.signal(signal.SIGHUP, self._handle_signal)
def _handle_signal(self, signo, frame):
self.sigcaught = signo
@@ -150,6 +178,7 @@ class ProcessLauncher(object):
# Allow the process to be killed again and die from natural causes
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
def _pipe_watcher(self):
# This will block until the write end is closed when the parent
@@ -160,16 +189,47 @@ class ProcessLauncher(object):
sys.exit(1)
- def _child_process(self, service):
+ def _child_process_handle_signal(self):
# Setup child signal handlers differently
def _sigterm(*args):
signal.signal(signal.SIGTERM, signal.SIG_DFL)
raise SignalExit(signal.SIGTERM)
+ def _sighup(*args):
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
+ raise SignalExit(signal.SIGHUP)
+
signal.signal(signal.SIGTERM, _sigterm)
+ signal.signal(signal.SIGHUP, _sighup)
# Block SIGINT and let the parent send us a SIGTERM
signal.signal(signal.SIGINT, signal.SIG_IGN)
+ def _child_wait_for_exit_or_signal(self, launcher):
+ status = None
+ signo = 0
+
+ try:
+ launcher.wait()
+ except SignalExit as exc:
+ signame = {signal.SIGTERM: 'SIGTERM',
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[exc.signo]
+ LOG.info(_('Caught %s, exiting'), signame)
+ status = exc.code
+ signo = exc.signo
+ except SystemExit as exc:
+ status = exc.code
+ except BaseException:
+ LOG.exception(_('Unhandled exception'))
+ status = 2
+ finally:
+ launcher.stop()
+
+ return status, signo
+
+ def _child_process(self, service):
+ self._child_process_handle_signal()
+
# Reopen the eventlet hub to make sure we don't share an epoll
# fd with parent and/or siblings, which would be bad
eventlet.hubs.use_hub()
@@ -184,7 +244,7 @@ class ProcessLauncher(object):
launcher = Launcher()
launcher.launch_service(service)
- launcher.wait()
+ return launcher
def _start_child(self, wrap):
if len(wrap.forktimes) > wrap.workers:
@@ -205,21 +265,13 @@ class ProcessLauncher(object):
# NOTE(johannes): All exceptions are caught to ensure this
# doesn't fallback into the loop spawning children. It would
# be bad for a child to spawn more children.
- status = 0
- try:
- self._child_process(wrap.service)
- except SignalExit as exc:
- signame = {signal.SIGTERM: 'SIGTERM',
- signal.SIGINT: 'SIGINT'}[exc.signo]
- LOG.info(_('Caught %s, exiting'), signame)
- status = exc.code
- except SystemExit as exc:
- status = exc.code
- except BaseException:
- LOG.exception(_('Unhandled exception'))
- status = 2
- finally:
- wrap.service.stop()
+ launcher = self._child_process(wrap.service)
+ while True:
+ self._child_process_handle_signal()
+ status, signo = self._child_wait_for_exit_or_signal(launcher)
+ if signo != signal.SIGHUP:
+ break
+ launcher.restart()
os._exit(status)
@@ -265,12 +317,7 @@ class ProcessLauncher(object):
wrap.children.remove(pid)
return wrap
- def wait(self):
- """Loop waiting on children to die and respawning as necessary."""
-
- LOG.debug(_('Full set of CONF:'))
- CONF.log_opt_values(LOG, std_logging.DEBUG)
-
+ def _respawn_children(self):
while self.running:
wrap = self._wait_child()
if not wrap:
@@ -279,14 +326,30 @@ class ProcessLauncher(object):
# (see bug #1095346)
eventlet.greenthread.sleep(.01)
continue
-
while self.running and len(wrap.children) < wrap.workers:
self._start_child(wrap)
- if self.sigcaught:
- signame = {signal.SIGTERM: 'SIGTERM',
- signal.SIGINT: 'SIGINT'}[self.sigcaught]
- LOG.info(_('Caught %s, stopping children'), signame)
+ def wait(self):
+ """Loop waiting on children to die and respawning as necessary."""
+
+ LOG.debug(_('Full set of CONF:'))
+ CONF.log_opt_values(LOG, std_logging.DEBUG)
+
+ while True:
+ self.handle_signal()
+ self._respawn_children()
+ if self.sigcaught:
+ signame = {signal.SIGTERM: 'SIGTERM',
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[self.sigcaught]
+ LOG.info(_('Caught %s, stopping children'), signame)
+ if self.sigcaught != signal.SIGHUP:
+ break
+
+ for pid in self.children:
+ os.kill(pid, signal.SIGHUP)
+ self.running = True
+ self.sigcaught = None
for pid in self.children:
try:
@@ -311,6 +374,10 @@ class Service(object):
# signal that the service is done shutting itself down:
self._done = event.Event()
+ def reset(self):
+ # NOTE(Fengqian): docs for Event.reset() recommend against using it
+ self._done = event.Event()
+
def start(self):
pass
@@ -353,6 +420,13 @@ class Services(object):
def wait(self):
self.tg.wait()
+ def restart(self):
+ self.stop()
+ self.done = event.Event()
+ for restart_service in self.services:
+ restart_service.reset()
+ self.tg.add_thread(self.run_service, restart_service, self.done)
+
@staticmethod
def run_service(service, done):
"""Service start wrapper.
diff --git a/openstack/common/test.py b/openstack/common/test.py
new file mode 100644
index 0000000..7f400e5
--- /dev/null
+++ b/openstack/common/test.py
@@ -0,0 +1,52 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010-2011 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Common utilities used in testing"""
+
+import os
+
+import fixtures
+import testtools
+
+
+class BaseTestCase(testtools.TestCase):
+
+ def setUp(self):
+ super(BaseTestCase, self).setUp()
+ self._set_timeout()
+ self._fake_output()
+ self.useFixture(fixtures.FakeLogger('openstack.common'))
+
+ def _set_timeout(self):
+ test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0)
+ try:
+ test_timeout = int(test_timeout)
+ except ValueError:
+ # If timeout value is invalid do not set a timeout.
+ test_timeout = 0
+ if test_timeout > 0:
+ self.useFixture(fixtures.Timeout(test_timeout, gentle=True))
+
+ def _fake_output(self):
+ if (os.environ.get('OS_STDOUT_CAPTURE') == 'True' or
+ os.environ.get('OS_STDOUT_CAPTURE') == '1'):
+ stdout = self.useFixture(fixtures.StringStream('stdout')).stream
+ self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout))
+ if (os.environ.get('OS_STDERR_CAPTURE') == 'True' or
+ os.environ.get('OS_STDERR_CAPTURE') == '1'):
+ stderr = self.useFixture(fixtures.StringStream('stderr')).stream
+ self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr))
diff --git a/openstack/common/timeutils.py b/openstack/common/timeutils.py
index bd60489..aa9f708 100644
--- a/openstack/common/timeutils.py
+++ b/openstack/common/timeutils.py
@@ -49,9 +49,9 @@ def parse_isotime(timestr):
try:
return iso8601.parse_date(timestr)
except iso8601.ParseError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
except TypeError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
def strtime(at=None, fmt=PERFECT_TIME_FORMAT):
diff --git a/pypi/setup.py b/pypi/setup.py
index feb38c8..ea9a5be 100644
--- a/pypi/setup.py
+++ b/pypi/setup.py
@@ -33,7 +33,7 @@ setuptools.setup(
],
keywords='openstack',
author='OpenStack',
- author_email='openstack@lists.launchpad.net',
+ author_email='openstack@lists.openstack.org',
url='http://www.openstack.org/',
license='Apache Software License',
zip_safe=True,
diff --git a/requirements.txt b/requirements.txt
index 2c512f1..656241b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ WebOb==1.2.3
eventlet>=0.12.0
greenlet>=0.3.2
lxml
+requests>=1.1,<1.2.3
routes==1.12.3
iso8601>=0.1.4
anyjson>=0.3.3
@@ -17,6 +18,7 @@ qpid-python
six
netaddr
pycrypto>=2.6
+Babel>=0.9.6
-f http://tarballs.openstack.org/oslo.config/oslo.config-1.2.0a3.tar.gz#egg=oslo.config-1.2.0a3
oslo.config>=1.2.0a3
diff --git a/run_tests.sh b/run_tests.sh
new file mode 100755
index 0000000..bd850af
--- /dev/null
+++ b/run_tests.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+# This script is a simple wrapper around the common tools/run_tests_common.sh
+# script. It passes project-specific variables to the common script.
+#
+# Options list (from tools/run_tests_common.sh).
+# Use `./run_tests.sh -h` `./run_tests.sh --help` to get help message
+#
+# -V, --virtual-env Always use virtualenv. Install automatically if not present
+# -N, --no-virtual-env Don't use virtualenv. Run tests in local environment
+# -s, --no-site-packages Isolate the virtualenv from the global Python environment
+# -r, --recreate-db Recreate the test database (deprecated, as this is now the default).
+# -n, --no-recreate-db Don't recreate the test database.
+# -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added.
+# -u, --update Update the virtual environment with any newer package versions
+# -p, --pep8 Just run PEP8 and HACKING compliance check
+# -P, --no-pep8 Don't run static code checks
+# -c, --coverage Generate coverage report
+# -d, --debug Run tests with testtools instead of testr. This allows you to use the debugger.
+# -h, --help Print this usage message
+# --hide-elapsed Don't print the elapsed time for each test along with slow test list
+# --virtual-env-path <path> Location of the virtualenv directory. Default: \$(pwd)
+# --virtual-env-name <name> Name of the virtualenv directory. Default: .venv
+# --tools-path <dir> Location of the tools directory. Default: \$(pwd)
+#
+# Note: with no options specified, the script will try to run the tests in a
+# virtual environment, if no virtualenv is found, the script will ask if
+# you would like to create one. If you prefer to run tests NOT in a
+# virtual environment, simply pass the -N option.
+
+
+# On Linux, testrepository will inspect /proc/cpuinfo to determine how many
+# CPUs are present in the machine, and run one worker per CPU.
+# Set workers_count=0 if you want to run one worker per CPU.
+# Make our paths available to run_tests_common.sh using `export` statement
+# export WORKERS_COUNT=0
+
+# it is not possible to run some oslo tests with concurrency > 1
+# or separately, due to dependencies between tests (see bug 1192207)
+export WORKERS_COUNT=1
+# option include {PROJECT_NAME}/* directory to coverage report if `-c` or
+# `--coverage` uses
+export PROJECT_NAME="openstack"
+# option exclude "${PROJECT_NAME}/openstack/common/*" from coverage report
+# if equals to 1
+export OMIT_OSLO_FROM_COVERAGE=0
+# path to directory with tests
+export TESTS_DIR="tests/"
+export EGG_INFO_FILE="openstack.common.egg-info/entry_points.txt"
+
+# run common test script
+tools/run_tests_common.sh $*
diff --git a/test-requirements.txt b/test-requirements.txt
index 3d88f90..f98d81a 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -2,7 +2,7 @@ coverage
discover
fixtures>=0.3.12
flake8==2.0
-hacking>=0.5.6,<0.7
+hacking>=0.5.6,<0.8
mock
mox==0.5.3
mysql-python
diff --git a/tests/unit/apiclient/test_auth.py b/tests/unit/apiclient/test_auth.py
new file mode 100644
index 0000000..15d9ef9
--- /dev/null
+++ b/tests/unit/apiclient/test_auth.py
@@ -0,0 +1,181 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import argparse
+
+import fixtures
+import mock
+import requests
+
+from stevedore import extension
+
+try:
+ import json
+except ImportError:
+ import simplejson as json
+
+from openstack.common.apiclient import auth
+from openstack.common.apiclient import client
+from openstack.common.apiclient import fake_client
+from openstack.common import test
+
+
+TEST_REQUEST_BASE = {
+ 'verify': True,
+}
+
+
+def mock_http_request(resp=None):
+ """Mock an HTTP Request."""
+ if not resp:
+ resp = {
+ "access": {
+ "token": {
+ "expires": "12345",
+ "id": "FAKE_ID",
+ "tenant": {
+ "id": "FAKE_TENANT_ID",
+ }
+ },
+ "serviceCatalog": [
+ {
+ "type": "compute",
+ "endpoints": [
+ {
+ "region": "RegionOne",
+ "adminURL": "http://localhost:8774/v1.1",
+ "internalURL": "http://localhost:8774/v1.1",
+ "publicURL": "http://localhost:8774/v1.1/",
+ },
+ ],
+ },
+ ],
+ },
+ }
+
+ auth_response = fake_client.TestResponse({
+ "status_code": 200,
+ "text": json.dumps(resp),
+ })
+ return mock.Mock(return_value=(auth_response))
+
+
+def requested_headers(cs):
+    """Return the headers expected to be passed with the request."""
+ return {
+ 'User-Agent': cs.user_agent,
+ 'Content-Type': 'application/json',
+ }
+
+
+class BaseFakePlugin(auth.BaseAuthPlugin):
+ def _do_authenticate(self, http_client):
+ pass
+
+ def token_and_endpoint(self, endpoint_type, service_type):
+ pass
+
+
+class GlobalFunctionsTest(test.BaseTestCase):
+
+ def test_load_auth_system_opts(self):
+ self.useFixture(fixtures.MonkeyPatch(
+ "os.environ",
+ {"OS_TENANT_NAME": "fake-project",
+ "OS_USERNAME": "fake-username"}))
+ parser = argparse.ArgumentParser()
+ auth.load_auth_system_opts(parser)
+ options = parser.parse_args(
+ ["--os-auth-url=fake-url", "--os_auth_system=fake-system"])
+ self.assertTrue(options.os_tenant_name, "fake-project")
+ self.assertTrue(options.os_username, "fake-username")
+ self.assertTrue(options.os_auth_url, "fake-url")
+ self.assertTrue(options.os_auth_system, "fake-system")
+
+
+class MockEntrypoint(object):
+ def __init__(self, name, plugin):
+ self.name = name
+ self.plugin = plugin
+
+
+class AuthPluginTest(test.BaseTestCase):
+ @mock.patch.object(requests.Session, "request")
+ @mock.patch.object(extension.ExtensionManager, "map")
+ def test_auth_system_success(self, mock_mgr_map, mock_request):
+ """Test that we can authenticate using the auth system."""
+ class FakePlugin(BaseFakePlugin):
+ def authenticate(self, cls):
+ cls.request(
+ "POST", "http://auth/tokens",
+ json={"fake": "me"}, allow_redirects=True)
+
+ mock_mgr_map.side_effect = (
+ lambda func: func(MockEntrypoint("fake", FakePlugin)))
+
+ mock_request.side_effect = mock_http_request()
+
+ auth.discover_auth_systems()
+ plugin = auth.load_plugin("fake")
+ cs = client.HTTPClient(auth_plugin=plugin)
+ cs.authenticate()
+
+ headers = requested_headers(cs)
+
+ mock_request.assert_called_with(
+ "POST",
+ "http://auth/tokens",
+ headers=headers,
+ data='{"fake": "me"}',
+ allow_redirects=True,
+ **TEST_REQUEST_BASE)
+
+ @mock.patch.object(extension.ExtensionManager, "map")
+ def test_discover_auth_system_options(self, mock_mgr_map):
+ """Test that we can load the auth system options."""
+ class FakePlugin(BaseFakePlugin):
+ @classmethod
+ def add_opts(cls, parser):
+ parser.add_argument('--auth_system_opt',
+ default=False,
+ action='store_true',
+ help="Fake option")
+
+ mock_mgr_map.side_effect = (
+ lambda func: func(MockEntrypoint("fake", FakePlugin)))
+
+ parser = argparse.ArgumentParser()
+ auth.discover_auth_systems()
+ auth.load_auth_system_opts(parser)
+ opts, _args = parser.parse_known_args(['--auth_system_opt'])
+
+ self.assertTrue(opts.auth_system_opt)
+
+ @mock.patch.object(extension.ExtensionManager, "map")
+ def test_parse_auth_system_options(self, mock_mgr_map):
+ """Test that we can parse the auth system options."""
+ class FakePlugin(BaseFakePlugin):
+ opt_names = ["fake_argument"]
+
+ mock_mgr_map.side_effect = (
+ lambda func: func(MockEntrypoint("fake", FakePlugin)))
+
+ auth.discover_auth_systems()
+ plugin = auth.load_plugin("fake")
+
+ plugin.parse_opts([])
+ self.assertIn("fake_argument", plugin.opts)
diff --git a/tests/unit/apiclient/test_base.py b/tests/unit/apiclient/test_base.py
new file mode 100644
index 0000000..86f2941
--- /dev/null
+++ b/tests/unit/apiclient/test_base.py
@@ -0,0 +1,239 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from openstack.common.apiclient import base
+from openstack.common.apiclient import client
+from openstack.common.apiclient import exceptions
+from openstack.common.apiclient import fake_client
+from openstack.common import test
+
+
+class HumanResource(base.Resource):
+ HUMAN_ID = True
+
+
+class HumanResourceManager(base.ManagerWithFind):
+ resource_class = HumanResource
+
+ def list(self):
+ return self._list("/human_resources", "human_resources")
+
+ def get(self, human_resource):
+ return self._get(
+ "/human_resources/%s" % base.getid(human_resource),
+ "human_resource")
+
+ def update(self, human_resource, name):
+ body = {
+ "human_resource": {
+ "name": name,
+ },
+ }
+ return self._put(
+ "/human_resources/%s" % base.getid(human_resource),
+ body,
+ "human_resource")
+
+
+class CrudResource(base.Resource):
+ pass
+
+
+class CrudResourceManager(base.CrudManager):
+ """Manager class for manipulating Identity crud_resources."""
+ resource_class = CrudResource
+ collection_key = 'crud_resources'
+ key = 'crud_resource'
+
+ def get(self, crud_resource):
+ return super(CrudResourceManager, self).get(
+ crud_resource_id=base.getid(crud_resource))
+
+
+class FakeHTTPClient(fake_client.FakeHTTPClient):
+ crud_resource_json = {"id": "1", "domain_id": "my-domain"}
+
+ def get_human_resources(self, **kw):
+ return (200, {}, {'human_resources': [
+ {'id': 1, 'name': '256 MB Server'},
+ {'id': 2, 'name': '512 MB Server'},
+ {'id': 'aa1', 'name': '128 MB Server'}
+ ]})
+
+ def get_human_resources_1(self, **kw):
+ res = self.get_human_resources()[2]['human_resources'][0]
+ return (200, {}, {'human_resource': res})
+
+ def put_human_resources_1(self, **kw):
+ kw = kw["json"]["human_resource"].copy()
+ kw["id"] = "1"
+ return (200, {}, {'human_resource': kw})
+
+ def post_crud_resources(self, **kw):
+ return (200, {}, {"crud_resource": {"id": "1"}})
+
+ def get_crud_resources(self, **kw):
+ crud_resources = []
+ if kw.get("domain_id") == self.crud_resource_json["domain_id"]:
+ crud_resources = [self.crud_resource_json]
+ else:
+ crud_resources = []
+ return (200, {}, {"crud_resources": crud_resources})
+
+ def get_crud_resources_1(self, **kw):
+ return (200, {}, {"crud_resource": self.crud_resource_json})
+
+ def head_crud_resources_1(self, **kw):
+ return (204, {}, None)
+
+ def patch_crud_resources_1(self, **kw):
+ self.crud_resource_json.update(kw)
+ return (200, {}, {"crud_resource": self.crud_resource_json})
+
+ def delete_crud_resources_1(self, **kw):
+ return (202, {}, None)
+
+
+class TestClient(client.BaseClient):
+
+ service_type = "test"
+
+ def __init__(self, http_client, extensions=None):
+ super(TestClient, self).__init__(
+ http_client, extensions=extensions)
+
+ self.human_resources = HumanResourceManager(self)
+ self.crud_resources = CrudResourceManager(self)
+
+
+class ResourceTest(test.BaseTestCase):
+
+ def test_resource_repr(self):
+ r = base.Resource(None, dict(foo="bar", baz="spam"))
+ self.assertEqual(repr(r), "<Resource baz=spam, foo=bar>")
+
+ def test_getid(self):
+ class TmpObject(base.Resource):
+ id = "4"
+ self.assertEqual(base.getid(TmpObject(None, {})), "4")
+
+ def test_human_id(self):
+ r = base.Resource(None, {"name": "1"})
+ self.assertEqual(r.human_id, None)
+ r = HumanResource(None, {"name": "1"})
+ self.assertEqual(r.human_id, "1")
+
+
+class BaseManagerTest(test.BaseTestCase):
+
+ def setUp(self):
+ super(BaseManagerTest, self).setUp()
+ self.http_client = FakeHTTPClient()
+ self.tc = TestClient(self.http_client)
+
+ def test_resource_lazy_getattr(self):
+ f = HumanResource(self.tc.human_resources, {'id': 1})
+ self.assertEqual(f.name, '256 MB Server')
+ self.http_client.assert_called('GET', '/human_resources/1')
+
+ # Missing stuff still fails after a second get
+ self.assertRaises(AttributeError, getattr, f, 'blahblah')
+
+ def test_eq(self):
+ # Two resources of the same type with the same id: equal
+ r1 = base.Resource(None, {'id': 1, 'name': 'hi'})
+ r2 = base.Resource(None, {'id': 1, 'name': 'hello'})
+ self.assertEqual(r1, r2)
+
+ # Two resources of different types: never equal
+ r1 = base.Resource(None, {'id': 1})
+ r2 = HumanResource(None, {'id': 1})
+ self.assertNotEqual(r1, r2)
+
+ # Two resources with no ID: equal if their info is equal
+ r1 = base.Resource(None, {'name': 'joe', 'age': 12})
+ r2 = base.Resource(None, {'name': 'joe', 'age': 12})
+ self.assertEqual(r1, r2)
+
+ def test_findall_invalid_attribute(self):
+ # Make sure findall with an invalid attribute doesn't cause errors.
+ # The following should not raise an exception.
+ self.tc.human_resources.findall(vegetable='carrot')
+
+ # However, find() should raise an error
+ self.assertRaises(exceptions.NotFound,
+ self.tc.human_resources.find,
+ vegetable='carrot')
+
+ def test_update(self):
+ name = "new-name"
+ human_resource = self.tc.human_resources.update("1", name)
+ self.assertEqual(human_resource.id, "1")
+ self.assertEqual(human_resource.name, name)
+
+
+class CrudManagerTest(test.BaseTestCase):
+
+ domain_id = "my-domain"
+ crud_resource_id = "1"
+
+ def setUp(self):
+ super(CrudManagerTest, self).setUp()
+ self.http_client = FakeHTTPClient()
+ self.tc = TestClient(self.http_client)
+
+ def test_create(self):
+ crud_resource = self.tc.crud_resources.create()
+ self.assertEqual(crud_resource.id, self.crud_resource_id)
+
+ def test_list(self, domain=None, user=None):
+ crud_resources = self.tc.crud_resources.list(
+ base_url=None,
+ domain_id=self.domain_id)
+ self.assertEqual(len(crud_resources), 1)
+ self.assertEqual(crud_resources[0].id, self.crud_resource_id)
+ self.assertEqual(crud_resources[0].domain_id, self.domain_id)
+ crud_resources = self.tc.crud_resources.list(
+ base_url=None,
+ domain_id="another-domain",
+ another_attr=None)
+ self.assertEqual(len(crud_resources), 0)
+
+ def test_get(self):
+ crud_resource = self.tc.crud_resources.get(self.crud_resource_id)
+ self.assertEqual(crud_resource.id, self.crud_resource_id)
+ fake_client.assert_has_keys(
+ crud_resource._info,
+ required=["id", "domain_id"],
+ optional=["missing-attr"])
+
+ def test_update(self):
+ crud_resource = self.tc.crud_resources.update(
+ crud_resource_id=self.crud_resource_id,
+ domain_id=self.domain_id)
+ self.assertEqual(crud_resource.id, self.crud_resource_id)
+ self.assertEqual(crud_resource.domain_id, self.domain_id)
+
+ def test_delete(self):
+ resp = self.tc.crud_resources.delete(
+ crud_resource_id=self.crud_resource_id)
+ self.assertEqual(resp.status_code, 202)
+
+ def test_head(self):
+ ret = self.tc.crud_resources.head(
+ crud_resource_id=self.crud_resource_id)
+ self.assertTrue(ret)
diff --git a/tests/unit/apiclient/test_client.py b/tests/unit/apiclient/test_client.py
new file mode 100644
index 0000000..b48ba12
--- /dev/null
+++ b/tests/unit/apiclient/test_client.py
@@ -0,0 +1,137 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import mock
+import requests
+
+from openstack.common.apiclient import auth
+from openstack.common.apiclient import client
+from openstack.common.apiclient import exceptions
+from openstack.common import test
+
+
+class TestClient(client.BaseClient):
+ service_type = "test"
+
+
+class FakeAuthPlugin(auth.BaseAuthPlugin):
+ auth_system = "fake"
+ attempt = -1
+
+ def _do_authenticate(self, http_client):
+ pass
+
+ def token_and_endpoint(self, endpoint_type, service_type):
+ self.attempt = self.attempt + 1
+ return ("token-%s" % self.attempt, "/endpoint-%s" % self.attempt)
+
+
+class ClientTest(test.BaseTestCase):
+
+ def test_client_with_timeout(self):
+ http_client = client.HTTPClient(None, timeout=2)
+ self.assertEqual(http_client.timeout, 2)
+ mock_request = mock.Mock()
+ mock_request.return_value = requests.Response()
+ mock_request.return_value.status_code = 200
+ with mock.patch("requests.Session.request", mock_request):
+ http_client.request("GET", "/", json={"1": "2"})
+ requests.Session.request.assert_called_with(
+ "GET",
+ "/",
+ timeout=2,
+ headers=mock.ANY,
+ verify=mock.ANY,
+ data=mock.ANY)
+
+ def test_concat_url(self):
+ self.assertEqual(client.HTTPClient.concat_url("/a", "/b"), "/a/b")
+ self.assertEqual(client.HTTPClient.concat_url("/a", "b"), "/a/b")
+ self.assertEqual(client.HTTPClient.concat_url("/a/", "/b"), "/a/b")
+
+ def test_client_request(self):
+ http_client = client.HTTPClient(FakeAuthPlugin())
+ mock_request = mock.Mock()
+ mock_request.return_value = requests.Response()
+ mock_request.return_value.status_code = 200
+ with mock.patch("requests.Session.request", mock_request):
+ http_client.client_request(
+ TestClient(http_client), "GET", "/resource", json={"1": "2"})
+ requests.Session.request.assert_called_with(
+ "GET",
+ "/endpoint-0/resource",
+ headers={
+ "User-Agent": http_client.user_agent,
+ "Content-Type": "application/json",
+ "X-Auth-Token": "token-0"
+ },
+ data='{"1": "2"}',
+ verify=True)
+
+ def test_client_request_reissue(self):
+ reject_token = None
+
+ def fake_request(method, url, **kwargs):
+ if kwargs["headers"]["X-Auth-Token"] == reject_token:
+ raise exceptions.Unauthorized(method=method, url=url)
+ return "%s %s" % (method, url)
+
+ http_client = client.HTTPClient(FakeAuthPlugin())
+ test_client = TestClient(http_client)
+ http_client.request = fake_request
+
+ self.assertEqual(
+ http_client.client_request(
+ test_client, "GET", "/resource"),
+ "GET /endpoint-0/resource")
+ reject_token = "token-0"
+ self.assertEqual(
+ http_client.client_request(
+ test_client, "GET", "/resource"),
+ "GET /endpoint-1/resource")
+
+
+class FakeClient1(object):
+ pass
+
+
+class FakeClient21(object):
+ pass
+
+
+class GetClientClassTestCase(test.BaseTestCase):
+ version_map = {
+ "1": "%s.FakeClient1" % __name__,
+ "2.1": "%s.FakeClient21" % __name__,
+ }
+
+ def test_get_int(self):
+ self.assertEqual(
+ client.BaseClient.get_class("fake", 1, self.version_map),
+ FakeClient1)
+
+ def test_get_str(self):
+ self.assertEqual(
+ client.BaseClient.get_class("fake", "2.1", self.version_map),
+ FakeClient21)
+
+ def test_unsupported_version(self):
+ self.assertRaises(
+ exceptions.UnsupportedVersion,
+ client.BaseClient.get_class,
+ "fake", "7", self.version_map)
diff --git a/tests/unit/apiclient/test_exceptions.py b/tests/unit/apiclient/test_exceptions.py
index 34cae73..39753ec 100644
--- a/tests/unit/apiclient/test_exceptions.py
+++ b/tests/unit/apiclient/test_exceptions.py
@@ -13,9 +13,8 @@
# License for the specific language governing permissions and limitations
# under the License.
-from tests import utils
-
from openstack.common.apiclient import exceptions
+from openstack.common import test
class FakeResponse(object):
@@ -29,7 +28,7 @@ class FakeResponse(object):
return self.json_data
-class ExceptionsArgsTest(utils.BaseTestCase):
+class ExceptionsArgsTest(test.BaseTestCase):
def assert_exception(self, ex_cls, method, url, status_code, json_data):
ex = exceptions.from_response(
@@ -61,7 +60,7 @@ class ExceptionsArgsTest(utils.BaseTestCase):
json_data = {"error": {"message": "fake unknown message",
"details": "fake unknown details"}}
self.assert_exception(
- exceptions.HttpClientError, method, url, status_code, json_data)
+ exceptions.HTTPClientError, method, url, status_code, json_data)
status_code = 600
self.assert_exception(
exceptions.HttpError, method, url, status_code, json_data)
diff --git a/tests/unit/crypto/test_utils.py b/tests/unit/crypto/test_utils.py
index 3a39100..a6cb6a2 100644
--- a/tests/unit/crypto/test_utils.py
+++ b/tests/unit/crypto/test_utils.py
@@ -18,10 +18,10 @@ Unit Tests for crypto utils.
"""
from openstack.common.crypto import utils as cryptoutils
-from tests import utils as test_utils
+from openstack.common import test
-class CryptoUtilsTestCase(test_utils.BaseTestCase):
+class CryptoUtilsTestCase(test.BaseTestCase):
# Uses Tests from RFC5869
def _test_HKDF(self, ikm, prk, okm, length,
diff --git a/tests/unit/db/sqlalchemy/test_migrate.py b/tests/unit/db/sqlalchemy/test_migrate.py
index 6724b5c..3e74a88 100644
--- a/tests/unit/db/sqlalchemy/test_migrate.py
+++ b/tests/unit/db/sqlalchemy/test_migrate.py
@@ -28,7 +28,7 @@ def uniques(*constraints):
Convert a sequence of UniqueConstraint instances into a set of
tuples of form (constraint_name, (constraint_columns)) so that
- assertEquals() will be able to compare sets of unique constraints
+ assertEqual() will be able to compare sets of unique constraints
"""
@@ -70,7 +70,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase):
sa.UniqueConstraint(table.c.a, table.c.b, name='unique_a_b'),
sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'),
)
- self.assertEquals(should_be, existing)
+ self.assertEqual(should_be, existing)
def test_add_unique_constraint(self):
table = self.reflected_table
@@ -82,7 +82,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase):
sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'),
sa.UniqueConstraint(table.c.a, table.c.c, name='unique_a_c'),
)
- self.assertEquals(should_be, existing)
+ self.assertEqual(should_be, existing)
def test_drop_unique_constraint(self):
table = self.reflected_table
@@ -92,4 +92,4 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase):
should_be = uniques(
sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'),
)
- self.assertEquals(should_be, existing)
+ self.assertEqual(should_be, existing)
diff --git a/tests/unit/db/sqlalchemy/test_migration_common.py b/tests/unit/db/sqlalchemy/test_migration_common.py
new file mode 100644
index 0000000..bb27212
--- /dev/null
+++ b/tests/unit/db/sqlalchemy/test_migration_common.py
@@ -0,0 +1,154 @@
+# Copyright 2013 Mirantis Inc.
+# All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+
+import contextlib
+import os
+import tempfile
+
+from migrate import exceptions as migrate_exception
+from migrate.versioning import api as versioning_api
+import mock
+import sqlalchemy
+
+from openstack.common.db import exception as db_exception
+from openstack.common.db.sqlalchemy import migration
+from openstack.common.db.sqlalchemy import session as db_session
+from tests.unit.db.sqlalchemy import base
+
+
+class TestMigrationCommon(base.DbTestCase):
+ def setUp(self):
+ super(TestMigrationCommon, self).setUp()
+
+ migration._REPOSITORY = None
+ self.path = tempfile.mkdtemp('test_migration')
+ self.return_value = '/home/openstack/migrations'
+ self.init_version = 1
+ self.test_version = 123
+
+ self.patcher_repo = mock.patch.object(migration, 'Repository')
+ self.repository = self.patcher_repo.start()
+ self.repository.return_value = self.return_value
+
+ self.mock_api_db = mock.patch.object(versioning_api, 'db_version')
+ self.mock_api_db_version = self.mock_api_db.start()
+ self.mock_api_db_version.return_value = self.test_version
+
+ def tearDown(self):
+ os.rmdir(self.path)
+ self.mock_api_db.stop()
+ self.patcher_repo.stop()
+ super(TestMigrationCommon, self).tearDown()
+
+ def test_find_migrate_repo_path_not_found(self):
+ self.assertRaises(
+ db_exception.DbMigrationError,
+ migration._find_migrate_repo,
+ "/foo/bar/",
+ )
+ self.assertIsNone(migration._REPOSITORY)
+
+ def test_find_migrate_repo_called_once(self):
+ my_repository = migration._find_migrate_repo(self.path)
+
+ self.repository.assert_called_once_with(self.path)
+ self.assertEqual(migration._REPOSITORY, self.return_value)
+ self.assertEqual(my_repository, self.return_value)
+
+ def test_find_migrate_repo_called_few_times(self):
+ repository1 = migration._find_migrate_repo(self.path)
+ repository2 = migration._find_migrate_repo(self.path)
+
+ self.repository.assert_called_once_with(self.path)
+ self.assertEqual(migration._REPOSITORY, self.return_value)
+ self.assertEqual(repository1, self.return_value)
+ self.assertEqual(repository2, self.return_value)
+
+ def test_db_version_control(self):
+ with contextlib.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'version_control'),
+ ) as (mock_find_repo, mock_version_control):
+ mock_find_repo.return_value = self.return_value
+
+ version = migration.db_version_control(
+ self.path, self.test_version)
+
+ self.assertEqual(version, self.test_version)
+ mock_version_control.assert_called_once_with(
+ db_session.get_engine(), self.return_value, self.test_version)
+
+ def test_db_version_return(self):
+ ret_val = migration.db_version(self.path, self.init_version)
+ self.assertEqual(ret_val, self.test_version)
+
+ def test_db_version_raise_not_controlled_error_first(self):
+ with mock.patch.object(migration, 'db_version_control') as mock_ver:
+
+ self.mock_api_db_version.side_effect = [
+ migrate_exception.DatabaseNotControlledError('oups'),
+ self.test_version]
+
+ ret_val = migration.db_version(self.path, self.init_version)
+ self.assertEqual(ret_val, self.test_version)
+ mock_ver.assert_called_once_with(self.path, self.init_version)
+
+ def test_db_version_raise_not_controlled_error_no_tables(self):
+ with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta:
+ self.mock_api_db_version.side_effect = \
+ migrate_exception.DatabaseNotControlledError('oups')
+ my_meta = mock.MagicMock()
+ my_meta.tables = {'a': 1, 'b': 2}
+ mock_meta.return_value = my_meta
+
+ self.assertRaises(
+ db_exception.DbMigrationError, migration.db_version,
+ self.path, self.init_version)
+
+ def test_db_sync_wrong_version(self):
+ self.assertRaises(
+ db_exception.DbMigrationError, migration.db_sync, self.path, 'foo')
+
+ def test_db_sync_upgrade(self):
+ init_ver = 55
+ with contextlib.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'upgrade')
+ ) as (mock_find_repo, mock_upgrade):
+
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version - 1
+
+ migration.db_sync(self.path, self.test_version, init_ver)
+
+ mock_upgrade.assert_called_once_with(
+ db_session.get_engine(), self.return_value, self.test_version)
+
+ def test_db_sync_downgrade(self):
+ with contextlib.nested(
+ mock.patch.object(migration, '_find_migrate_repo'),
+ mock.patch.object(versioning_api, 'downgrade')
+ ) as (mock_find_repo, mock_downgrade):
+
+ mock_find_repo.return_value = self.return_value
+ self.mock_api_db_version.return_value = self.test_version + 1
+
+ migration.db_sync(self.path, self.test_version)
+
+ mock_downgrade.assert_called_once_with(
+ db_session.get_engine(), self.return_value, self.test_version)
diff --git a/tests/unit/db/sqlalchemy/test_migrations.py b/tests/unit/db/sqlalchemy/test_migrations.py
index 8428b1c..ee2929c 100644
--- a/tests/unit/db/sqlalchemy/test_migrations.py
+++ b/tests/unit/db/sqlalchemy/test_migrations.py
@@ -355,10 +355,10 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin):
versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
upgraded = [mock.call(None, v, with_data=True) for v in versions]
- self.assertEquals(self._migrate_up.call_args_list, upgraded)
+ self.assertEqual(self._migrate_up.call_args_list, upgraded)
downgraded = [mock.call(None, v - 1) for v in reversed(versions)]
- self.assertEquals(self._migrate_down.call_args_list, downgraded)
+ self.assertEqual(self._migrate_down.call_args_list, downgraded)
@mock.patch.object(WalkVersionsMixin, '_migrate_up')
@mock.patch.object(WalkVersionsMixin, '_migrate_down')
@@ -376,7 +376,7 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin):
upgraded.extend(
[mock.call(self.engine, v) for v in reversed(versions)]
)
- self.assertEquals(upgraded, self._migrate_up.call_args_list)
+ self.assertEqual(upgraded, self._migrate_up.call_args_list)
downgraded_1 = [
mock.call(self.engine, v - 1, with_data=True) for v in versions
@@ -386,7 +386,7 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin):
downgraded_2.append(mock.call(self.engine, v - 1))
downgraded_2.append(mock.call(self.engine, v - 1))
downgraded = downgraded_1 + downgraded_2
- self.assertEquals(self._migrate_down.call_args_list, downgraded)
+ self.assertEqual(self._migrate_down.call_args_list, downgraded)
@mock.patch.object(WalkVersionsMixin, '_migrate_up')
@mock.patch.object(WalkVersionsMixin, '_migrate_down')
@@ -402,12 +402,12 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin):
for v in versions:
upgraded.append(mock.call(self.engine, v, with_data=True))
upgraded.append(mock.call(self.engine, v))
- self.assertEquals(upgraded, self._migrate_up.call_args_list)
+ self.assertEqual(upgraded, self._migrate_up.call_args_list)
downgraded = [
mock.call(self.engine, v - 1, with_data=True) for v in versions
]
- self.assertEquals(self._migrate_down.call_args_list, downgraded)
+ self.assertEqual(self._migrate_down.call_args_list, downgraded)
@mock.patch.object(WalkVersionsMixin, '_migrate_up')
@mock.patch.object(WalkVersionsMixin, '_migrate_down')
@@ -422,4 +422,4 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin):
upgraded = [
mock.call(self.engine, v, with_data=True) for v in versions
]
- self.assertEquals(upgraded, self._migrate_up.call_args_list)
+ self.assertEqual(upgraded, self._migrate_up.call_args_list)
diff --git a/tests/unit/db/sqlalchemy/test_models.py b/tests/unit/db/sqlalchemy/test_models.py
index 04905a6..89bf83f 100644
--- a/tests/unit/db/sqlalchemy/test_models.py
+++ b/tests/unit/db/sqlalchemy/test_models.py
@@ -16,10 +16,10 @@
# under the License.
from openstack.common.db.sqlalchemy import models
-from tests import utils as test_utils
+from openstack.common import test
-class ModelBaseTest(test_utils.BaseTestCase):
+class ModelBaseTest(test.BaseTestCase):
def test_modelbase_has_dict_methods(self):
dict_methods = ('__getitem__',
@@ -73,7 +73,7 @@ class ModelBaseTest(test_utils.BaseTestCase):
self.assertEqual(min_items, found_items)
-class TimestampMixinTest(test_utils.BaseTestCase):
+class TimestampMixinTest(test.BaseTestCase):
def test_timestampmixin_attr(self):
diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
index 48d6cf7..01141e3 100644
--- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py
+++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
@@ -53,14 +53,14 @@ sql_connection_debug=60
sql_connection_trace=True
""")])
self.conf(['--config-file', paths[0]])
- self.assertEquals(self.conf.database.connection, 'x://y.z')
- self.assertEquals(self.conf.database.min_pool_size, 10)
- self.assertEquals(self.conf.database.max_pool_size, 20)
- self.assertEquals(self.conf.database.max_retries, 30)
- self.assertEquals(self.conf.database.retry_interval, 40)
- self.assertEquals(self.conf.database.max_overflow, 50)
- self.assertEquals(self.conf.database.connection_debug, 60)
- self.assertEquals(self.conf.database.connection_trace, True)
+ self.assertEqual(self.conf.database.connection, 'x://y.z')
+ self.assertEqual(self.conf.database.min_pool_size, 10)
+ self.assertEqual(self.conf.database.max_pool_size, 20)
+ self.assertEqual(self.conf.database.max_retries, 30)
+ self.assertEqual(self.conf.database.retry_interval, 40)
+ self.assertEqual(self.conf.database.max_overflow, 50)
+ self.assertEqual(self.conf.database.connection_debug, 60)
+ self.assertEqual(self.conf.database.connection_trace, True)
def test_session_parameters(self):
paths = self.create_tempfiles([('test', """[database]
@@ -75,15 +75,15 @@ connection_trace=True
pool_timeout=7
""")])
self.conf(['--config-file', paths[0]])
- self.assertEquals(self.conf.database.connection, 'x://y.z')
- self.assertEquals(self.conf.database.min_pool_size, 10)
- self.assertEquals(self.conf.database.max_pool_size, 20)
- self.assertEquals(self.conf.database.max_retries, 30)
- self.assertEquals(self.conf.database.retry_interval, 40)
- self.assertEquals(self.conf.database.max_overflow, 50)
- self.assertEquals(self.conf.database.connection_debug, 60)
- self.assertEquals(self.conf.database.connection_trace, True)
- self.assertEquals(self.conf.database.pool_timeout, 7)
+ self.assertEqual(self.conf.database.connection, 'x://y.z')
+ self.assertEqual(self.conf.database.min_pool_size, 10)
+ self.assertEqual(self.conf.database.max_pool_size, 20)
+ self.assertEqual(self.conf.database.max_retries, 30)
+ self.assertEqual(self.conf.database.retry_interval, 40)
+ self.assertEqual(self.conf.database.max_overflow, 50)
+ self.assertEqual(self.conf.database.connection_debug, 60)
+ self.assertEqual(self.conf.database.connection_trace, True)
+ self.assertEqual(self.conf.database.pool_timeout, 7)
def test_dbapi_database_deprecated_parameters(self):
paths = self.create_tempfiles([('test',
@@ -98,15 +98,14 @@ pool_timeout=7
'sqlalchemy_pool_timeout=5\n'
)])
self.conf(['--config-file', paths[0]])
- self.assertEquals(self.conf.database.connection,
- 'fake_connection')
- self.assertEquals(self.conf.database.idle_timeout, 100)
- self.assertEquals(self.conf.database.min_pool_size, 99)
- self.assertEquals(self.conf.database.max_pool_size, 199)
- self.assertEquals(self.conf.database.max_retries, 22)
- self.assertEquals(self.conf.database.retry_interval, 17)
- self.assertEquals(self.conf.database.max_overflow, 101)
- self.assertEquals(self.conf.database.pool_timeout, 5)
+ self.assertEqual(self.conf.database.connection, 'fake_connection')
+ self.assertEqual(self.conf.database.idle_timeout, 100)
+ self.assertEqual(self.conf.database.min_pool_size, 99)
+ self.assertEqual(self.conf.database.max_pool_size, 199)
+ self.assertEqual(self.conf.database.max_retries, 22)
+ self.assertEqual(self.conf.database.retry_interval, 17)
+ self.assertEqual(self.conf.database.max_overflow, 101)
+ self.assertEqual(self.conf.database.pool_timeout, 5)
class SessionErrorWrapperTestCase(test_base.DbTestCase):
diff --git a/tests/unit/db/sqlalchemy/test_utils.py b/tests/unit/db/sqlalchemy/test_utils.py
index 15a1e25..0d8d87c 100644
--- a/tests/unit/db/sqlalchemy/test_utils.py
+++ b/tests/unit/db/sqlalchemy/test_utils.py
@@ -15,20 +15,37 @@
# License for the specific language governing permissions and limitations
# under the License.
+import warnings
+
+from migrate.changeset import UniqueConstraint
import sqlalchemy
from sqlalchemy.dialects import mysql
from sqlalchemy import Boolean, Index, Integer, DateTime, String
-from sqlalchemy import MetaData, Table, Column
+from sqlalchemy import MetaData, Table, Column, ForeignKey
from sqlalchemy.engine import reflection
+from sqlalchemy.exc import SAWarning
from sqlalchemy.sql import select
from sqlalchemy.types import UserDefinedType, NullType
from openstack.common.db.sqlalchemy import utils
-from openstack.common import exception
from tests.unit.db.sqlalchemy import test_migrations
from tests import utils as testutils
+class TestSanitizeDbUrl(testutils.BaseTestCase):
+
+ def test_url_with_cred(self):
+ db_url = 'myproto://johndoe:secret@localhost/myschema'
+ expected = 'myproto://****:****@localhost/myschema'
+ actual = utils.sanitize_db_url(db_url)
+ self.assertEqual(expected, actual)
+
+ def test_url_with_no_cred(self):
+ db_url = 'sqlite:///mysqlitefile'
+ actual = utils.sanitize_db_url(db_url)
+ self.assertEqual(db_url, actual)
+
+
class CustomType(UserDefinedType):
"""Dummy column type for testing unsupported types."""
def get_col_spec(self):
@@ -317,7 +334,7 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase):
Column('deleted', Boolean))
table.create()
- self.assertRaises(exception.OpenstackException,
+ self.assertRaises(utils.ColumnError,
utils.change_deleted_column_type_to_id_type,
engine, table_name)
@@ -358,7 +375,7 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase):
Column('deleted', Integer))
table.create()
- self.assertRaises(exception.OpenstackException,
+ self.assertRaises(utils.ColumnError,
utils.change_deleted_column_type_to_boolean,
engine, table_name)
@@ -371,3 +388,131 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase):
# but sqlalchemy will set it to NullType.
self.assertTrue(isinstance(table.c.foo.type, NullType))
self.assertTrue(isinstance(table.c.deleted.type, Boolean))
+
+ def test_utils_drop_unique_constraint(self):
+ table_name = "__test_tmp_table__"
+ uc_name = 'uniq_foo'
+ values = [
+ {'id': 1, 'a': 3, 'foo': 10},
+ {'id': 2, 'a': 2, 'foo': 20},
+ {'id': 3, 'a': 1, 'foo': 30},
+ ]
+ for key, engine in self.engines.items():
+ meta = MetaData()
+ meta.bind = engine
+ test_table = Table(
+ table_name, meta,
+ Column('id', Integer, primary_key=True, nullable=False),
+ Column('a', Integer),
+ Column('foo', Integer),
+ UniqueConstraint('a', name='uniq_a'),
+ UniqueConstraint('foo', name=uc_name),
+ )
+ test_table.create()
+
+ engine.execute(test_table.insert(), values)
+ # NOTE(boris-42): This method is generic UC dropper.
+ utils.drop_unique_constraint(engine, table_name, uc_name, 'foo')
+
+ s = test_table.select().order_by(test_table.c.id)
+ rows = engine.execute(s).fetchall()
+
+ for i in xrange(0, len(values)):
+ v = values[i]
+ self.assertEqual((v['id'], v['a'], v['foo']), rows[i])
+
+ # NOTE(boris-42): Update data about Table from DB.
+ meta = MetaData()
+ meta.bind = engine
+ test_table = Table(table_name, meta, autoload=True)
+ constraints = filter(
+ lambda c: c.name == uc_name, test_table.constraints)
+ self.assertEqual(len(constraints), 0)
+ self.assertEqual(len(test_table.constraints), 1)
+
+ test_table.drop()
+
+ def test_util_drop_unique_constraint_with_not_supported_sqlite_type(self):
+ table_name = "__test_tmp_table__"
+ uc_name = 'uniq_foo'
+ values = [
+ {'id': 1, 'a': 3, 'foo': 10},
+ {'id': 2, 'a': 2, 'foo': 20},
+ {'id': 3, 'a': 1, 'foo': 30}
+ ]
+
+ engine = self.engines['sqlite']
+ meta = MetaData(bind=engine)
+
+ test_table = Table(
+ table_name, meta,
+ Column('id', Integer, primary_key=True, nullable=False),
+ Column('a', Integer),
+ Column('foo', CustomType, default=0),
+ UniqueConstraint('a', name='uniq_a'),
+ UniqueConstraint('foo', name=uc_name),
+ )
+ test_table.create()
+
+ engine.execute(test_table.insert(), values)
+ warnings.simplefilter("ignore", SAWarning)
+ # NOTE(boris-42): Missing info about column `foo` that has
+ # unsupported type CustomType.
+ self.assertRaises(utils.ColumnError,
+ utils.drop_unique_constraint,
+ engine, table_name, uc_name, 'foo')
+
+ # NOTE(boris-42): Wrong type of foo instance. It should be
+ # instance of sqlalchemy.Column.
+ self.assertRaises(utils.ColumnError,
+ utils.drop_unique_constraint,
+ engine, table_name, uc_name, 'foo', foo=Integer())
+
+ foo = Column('foo', CustomType, default=0)
+ utils.drop_unique_constraint(
+ engine, table_name, uc_name, 'foo', foo=foo)
+
+ s = test_table.select().order_by(test_table.c.id)
+ rows = engine.execute(s).fetchall()
+
+ for i in xrange(0, len(values)):
+ v = values[i]
+ self.assertEqual((v['id'], v['a'], v['foo']), rows[i])
+
+ # NOTE(boris-42): Update data about Table from DB.
+ meta = MetaData(bind=engine)
+ test_table = Table(table_name, meta, autoload=True)
+ constraints = filter(
+ lambda c: c.name == uc_name, test_table.constraints)
+ self.assertEqual(len(constraints), 0)
+ self.assertEqual(len(test_table.constraints), 1)
+ test_table.drop()
+
+ def test_drop_unique_constraint_in_sqlite_fk_recreate(self):
+ engine = self.engines['sqlite']
+ meta = MetaData()
+ meta.bind = engine
+ parent_table = Table(
+ 'table0', meta,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer),
+ )
+ parent_table.create()
+ table_name = 'table1'
+ table = Table(
+ table_name, meta,
+ Column('id', Integer, primary_key=True),
+ Column('baz', Integer),
+ Column('bar', Integer, ForeignKey("table0.id")),
+ UniqueConstraint('baz', name='constr1')
+ )
+ table.create()
+ utils.drop_unique_constraint(engine, table_name, 'constr1', 'baz')
+
+ insp = reflection.Inspector.from_engine(engine)
+ f_keys = insp.get_foreign_keys(table_name)
+ self.assertEqual(len(f_keys), 1)
+ f_key = f_keys[0]
+ self.assertEqual(f_key['referred_table'], 'table0')
+ self.assertEqual(f_key['referred_columns'], ['id'])
+ self.assertEqual(f_key['constrained_columns'], ['bar'])
diff --git a/tests/unit/db/test_api.py b/tests/unit/db/test_api.py
index 2a8db3b..dab2198 100644
--- a/tests/unit/db/test_api.py
+++ b/tests/unit/db/test_api.py
@@ -41,8 +41,8 @@ class DBAPITestCase(test_utils.BaseTestCase):
)])
self.conf(['--config-file', paths[0]])
- self.assertEquals(self.conf.database.backend, 'test_123')
- self.assertEquals(self.conf.database.use_tpool, True)
+ self.assertEqual(self.conf.database.backend, 'test_123')
+ self.assertEqual(self.conf.database.use_tpool, True)
def test_dbapi_parameters(self):
paths = self.create_tempfiles([('test',
@@ -52,8 +52,8 @@ class DBAPITestCase(test_utils.BaseTestCase):
)])
self.conf(['--config-file', paths[0]])
- self.assertEquals(self.conf.database.backend, 'test_123')
- self.assertEquals(self.conf.database.use_tpool, True)
+ self.assertEqual(self.conf.database.backend, 'test_123')
+ self.assertEqual(self.conf.database.use_tpool, True)
def test_dbapi_api_class_method_and_tpool_false(self):
backend_mapping = {'test_known': 'tests.unit.db.test_api'}
diff --git a/tests/unit/deprecated/test_wsgi.py b/tests/unit/deprecated/test_wsgi.py
index 72aeae7..7e71837 100644
--- a/tests/unit/deprecated/test_wsgi.py
+++ b/tests/unit/deprecated/test_wsgi.py
@@ -25,7 +25,6 @@ import six
import webob
from openstack.common.deprecated import wsgi
-from openstack.common import exception
from tests import utils
@@ -44,7 +43,7 @@ class RequestTest(utils.BaseTestCase):
request = wsgi.Request.blank('/tests/123', method='POST')
request.headers["Content-Type"] = "text/html"
request.body = "asdf<br />"
- self.assertRaises(exception.InvalidContentType,
+ self.assertRaises(wsgi.InvalidContentType,
request.get_content_type)
def test_content_type_with_charset(self):
@@ -311,7 +310,7 @@ class ResponseSerializerTest(utils.BaseTestCase):
self.body_serializers[ctype])
def test_get_serializer_unknown_content_type(self):
- self.assertRaises(exception.InvalidContentType,
+ self.assertRaises(wsgi.InvalidContentType,
self.serializer.get_body_serializer,
'application/unknown')
@@ -335,7 +334,7 @@ class ResponseSerializerTest(utils.BaseTestCase):
self.assertEqual(response.status_int, 404)
def test_serialize_response_dict_to_unknown_content_type(self):
- self.assertRaises(exception.InvalidContentType,
+ self.assertRaises(wsgi.InvalidContentType,
self.serializer.serialize,
{}, 'application/unknown')
@@ -371,7 +370,7 @@ class RequestDeserializerTest(utils.BaseTestCase):
self.body_deserializers['application/xml'])
def test_get_deserializer_unknown_content_type(self):
- self.assertRaises(exception.InvalidContentType,
+ self.assertRaises(wsgi.InvalidContentType,
self.deserializer.get_body_deserializer,
'application/unknown')
@@ -523,7 +522,7 @@ class WSGIServerTest(utils.BaseTestCase):
server.start()
response = urllib2.urlopen('http://127.0.0.1:%d/' % server.port)
- self.assertEquals(greetings, response.read())
+ self.assertEqual(greetings, response.read())
server.stop()
@@ -541,7 +540,7 @@ class WSGIServerTest(utils.BaseTestCase):
server.start()
response = urllib2.urlopen('http://127.0.0.1:%d/v1.0/' % server.port)
- self.assertEquals(greetings, response.read())
+ self.assertEqual(greetings, response.read())
server.stop()
@@ -595,7 +594,7 @@ class WSGIServerWithSSLTest(utils.BaseTestCase):
server.start()
response = urllib2.urlopen('https://127.0.0.1:%d/v1.0/' % server.port)
- self.assertEquals(greetings, response.read())
+ self.assertEqual(greetings, response.read())
server.stop()
@@ -618,6 +617,6 @@ class WSGIServerWithSSLTest(utils.BaseTestCase):
server.start()
response = urllib2.urlopen('https://[::1]:%d/v1.0/' % server.port)
- self.assertEquals(greetings, response.read())
+ self.assertEqual(greetings, response.read())
server.stop()
diff --git a/tests/unit/fixture/__init__.py b/tests/unit/fixture/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/unit/fixture/__init__.py
diff --git a/tests/unit/fixture/test_config.py b/tests/unit/fixture/test_config.py
new file mode 100644
index 0000000..3368272
--- /dev/null
+++ b/tests/unit/fixture/test_config.py
@@ -0,0 +1,45 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2013 Mirantis, Inc.
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+from oslo.config import cfg
+
+from openstack.common.fixture import config
+from tests.utils import BaseTestCase
+
+conf = cfg.CONF
+
+
+class ConfigTestCase(BaseTestCase):
+ def setUp(self):
+ super(ConfigTestCase, self).setUp()
+ self.config = self.useFixture(config.Config(conf)).config
+ self.config_fixture = config.Config(conf)
+ conf.register_opt(cfg.StrOpt(
+ 'testing_option', default='initial_value'))
+
+ def test_overriden_value(self):
+ self.assertEqual(conf.get('testing_option'), 'initial_value')
+ self.config(testing_option='changed_value')
+ self.assertEqual(conf.get('testing_option'),
+ self.config_fixture.conf.get('testing_option'))
+
+ def test_cleanup(self):
+ self.config(testing_option='changed_value')
+ self.assertEqual(self.config_fixture.conf.get('testing_option'),
+ 'changed_value')
+ self.config_fixture.conf.reset()
+ self.assertEqual(conf.get('testing_option'), 'initial_value')
diff --git a/tests/unit/middleware/test_context.py b/tests/unit/middleware/test_context.py
index 5bf5ae4..daffa60 100644
--- a/tests/unit/middleware/test_context.py
+++ b/tests/unit/middleware/test_context.py
@@ -20,10 +20,10 @@ import mock
import openstack.common.context
from openstack.common.middleware import context
-from tests import utils
+from openstack.common import test
-class ContextMiddlewareTest(utils.BaseTestCase):
+class ContextMiddlewareTest(test.BaseTestCase):
def test_process_request(self):
req = mock.Mock()
@@ -59,7 +59,7 @@ class ContextMiddlewareTest(utils.BaseTestCase):
import_class.assert_called_with(mock.sentinel.arg)
-class FilterFactoryTest(utils.BaseTestCase):
+class FilterFactoryTest(test.BaseTestCase):
def test_filter_factory(self):
global_conf = dict(sentinel=mock.sentinel.global_conf)
diff --git a/tests/unit/middleware/test_correlation_id.py b/tests/unit/middleware/test_correlation_id.py
index dc83cc7..11939e9 100644
--- a/tests/unit/middleware/test_correlation_id.py
+++ b/tests/unit/middleware/test_correlation_id.py
@@ -36,7 +36,7 @@ class CorrelationIdMiddlewareTest(utils.BaseTestCase):
middleware = correlation_id.CorrelationIdMiddleware(app)
middleware(req)
- self.assertEquals(req.headers.get("X_CORRELATION_ID"), "fake_uuid")
+ self.assertEqual(req.headers.get("X_CORRELATION_ID"), "fake_uuid")
def test_process_request_should_not_regenerate_correlation_id(self):
app = mock.Mock()
@@ -46,5 +46,4 @@ class CorrelationIdMiddlewareTest(utils.BaseTestCase):
middleware = correlation_id.CorrelationIdMiddleware(app)
middleware(req)
- self.assertEquals(req.headers.get("X_CORRELATION_ID"),
- "correlation_id")
+ self.assertEqual(req.headers.get("X_CORRELATION_ID"), "correlation_id")
diff --git a/tests/unit/middleware/test_sizelimit.py b/tests/unit/middleware/test_sizelimit.py
index 7579659..3666b54 100644
--- a/tests/unit/middleware/test_sizelimit.py
+++ b/tests/unit/middleware/test_sizelimit.py
@@ -32,7 +32,7 @@ class TestLimitingReader(utils.BaseTestCase):
for chunk in sizelimit.LimitingReader(data, BYTES):
bytes_read += len(chunk)
- self.assertEquals(bytes_read, BYTES)
+ self.assertEqual(bytes_read, BYTES)
bytes_read = 0
data = six.StringIO("*" * BYTES)
@@ -42,7 +42,7 @@ class TestLimitingReader(utils.BaseTestCase):
bytes_read += 1
byte = reader.read(1)
- self.assertEquals(bytes_read, BYTES)
+ self.assertEqual(bytes_read, BYTES)
def test_limiting_reader_fails(self):
BYTES = 1024
diff --git a/tests/unit/rpc/amqp.py b/tests/unit/rpc/amqp.py
index 83713c7..76d6946 100644
--- a/tests/unit/rpc/amqp.py
+++ b/tests/unit/rpc/amqp.py
@@ -22,6 +22,7 @@ Unit Tests for AMQP-based remote procedure calls
import logging
from eventlet import greenthread
+import mock
from oslo.config import cfg
from openstack.common import jsonutils
@@ -177,3 +178,13 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase):
conn.close()
self.assertTrue(self.exc_raised)
+
+ def test_context_dict_type_check(self):
+ """Test that context is handled properly depending on the type."""
+ fake_context = {'fake': 'context'}
+ mock_msg = mock.MagicMock()
+ rpc_amqp.pack_context(mock_msg, fake_context)
+
+ # assert first arg in args was a dict type
+ args = mock_msg.update.call_args[0]
+ self.assertIsInstance(args[0], dict)
diff --git a/tests/unit/rpc/common.py b/tests/unit/rpc/common.py
index 2c343fe..4cc279c 100644
--- a/tests/unit/rpc/common.py
+++ b/tests/unit/rpc/common.py
@@ -26,7 +26,6 @@ import time
import eventlet
from oslo.config import cfg
-from openstack.common import exception
from openstack.common.gettextutils import _ # noqa
from openstack.common.rpc import common as rpc_common
from openstack.common.rpc import dispatcher as rpc_dispatcher
@@ -37,6 +36,13 @@ FLAGS = cfg.CONF
LOG = logging.getLogger(__name__)
+class ApiError(Exception):
+ def __init__(self, message='Unknown', code='Unknown'):
+ self.api_message = message
+ self.code = code
+ super(ApiError, self).__init__('%s: %s' % (code, message))
+
+
class BaseRpcTestCase(test_utils.BaseTestCase):
def setUp(self, supports_timeouts=True, topic='test',
@@ -424,7 +430,7 @@ class TestReceiver(object):
@staticmethod
def fail_converted(context, value):
"""Raises an exception with the value sent in."""
- raise exception.ApiError(message=value, code='500')
+ raise ApiError(message=value, code='500')
@staticmethod
def block(context, value):
diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py
index 291f823..f37f4b0 100644
--- a/tests/unit/rpc/test_common.py
+++ b/tests/unit/rpc/test_common.py
@@ -23,7 +23,6 @@ import sys
from oslo.config import cfg
import six
-from openstack.common import exception
from openstack.common import importutils
from openstack.common import jsonutils
from openstack.common import rpc
@@ -35,10 +34,6 @@ FLAGS = cfg.CONF
LOG = logging.getLogger(__name__)
-def raise_exception():
- raise Exception("test")
-
-
class FakeUserDefinedException(Exception):
def __init__(self, *args, **kwargs):
super(FakeUserDefinedException, self).__init__(*args)
@@ -54,7 +49,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
}
try:
- raise_exception()
+ raise Exception("test")
except Exception:
failure = rpc_common.serialize_remote_exception(sys.exc_info())
@@ -64,17 +59,14 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
self.assertEqual(expected['message'], failure['message'])
def test_serialize_remote_custom_exception(self):
- def raise_custom_exception():
- raise exception.MalformedRequestBody(reason='test')
-
expected = {
- 'class': 'MalformedRequestBody',
- 'module': 'openstack.common.exception',
- 'message': str(exception.MalformedRequestBody(reason='test')),
+ 'class': 'FakeUserDefinedException',
+ 'module': self.__class__.__module__,
+ 'message': 'test',
}
try:
- raise_custom_exception()
+ raise FakeUserDefinedException('test')
except Exception:
failure = rpc_common.serialize_remote_exception(sys.exc_info())
@@ -88,14 +80,14 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
# module, when being re-serialized, so that through any amount of cell
# hops up, it can pop out with the right type
expected = {
- 'class': 'OpenstackException',
- 'module': 'openstack.common.exception',
- 'message': exception.OpenstackException.msg_fmt,
+ 'class': 'FakeUserDefinedException',
+ 'module': self.__class__.__module__,
+ 'message': 'foobar',
}
def raise_remote_exception():
try:
- raise exception.OpenstackException()
+ raise FakeUserDefinedException('foobar')
except Exception as e:
ex_type = type(e)
message = str(e)
@@ -132,26 +124,6 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
self.assertTrue('raise NotImplementedError' in
six.text_type(after_exc))
- def test_deserialize_remote_custom_exception(self):
- failure = {
- 'class': 'OpenstackException',
- 'module': 'openstack.common.exception',
- 'message': exception.OpenstackException.msg_fmt,
- 'tb': ['raise OpenstackException'],
- }
- serialized = jsonutils.dumps(failure)
-
- after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
- self.assertTrue(isinstance(after_exc, exception.OpenstackException))
- self.assertTrue('An unknown' in six.text_type(after_exc))
- #assure the traceback was added
- self.assertTrue('raise OpenstackException' in six.text_type(after_exc))
- self.assertEqual('OpenstackException_Remote',
- after_exc.__class__.__name__)
- self.assertEqual('openstack.common.exception_Remote',
- after_exc.__class__.__module__)
- self.assertTrue(isinstance(after_exc, exception.OpenstackException))
-
def test_deserialize_remote_exception_bad_module(self):
failure = {
'class': 'popen2',
@@ -169,6 +141,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
self.config(allowed_rpc_exception_modules=[self.__class__.__module__])
failure = {
'class': 'FakeUserDefinedException',
+ 'message': 'foobar',
'module': self.__class__.__module__,
'tb': ['raise FakeUserDefinedException'],
}
@@ -176,6 +149,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized)
self.assertTrue(isinstance(after_exc, FakeUserDefinedException))
+ self.assertTrue('foobar' in six.text_type(after_exc))
#assure the traceback was added
self.assertTrue('raise FakeUserDefinedException' in
six.text_type(after_exc))
@@ -327,8 +301,8 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
def test_safe_log_sanitizes_globals(self):
def logger_method(msg, data):
- self.assertEquals('<SANITIZED>', data['_context_auth_token'])
- self.assertEquals('<SANITIZED>', data['auth_token'])
+ self.assertEqual('<SANITIZED>', data['_context_auth_token'])
+ self.assertEqual('<SANITIZED>', data['auth_token'])
data = {'_context_auth_token': 'banana',
'auth_token': 'cheese'}
@@ -336,7 +310,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
def test_safe_log_sanitizes_set_admin_password(self):
def logger_method(msg, data):
- self.assertEquals('<SANITIZED>', data['args']['new_pass'])
+ self.assertEqual('<SANITIZED>', data['args']['new_pass'])
data = {'_context_auth_token': 'banana',
'auth_token': 'cheese',
@@ -346,7 +320,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
def test_safe_log_sanitizes_run_instance(self):
def logger_method(msg, data):
- self.assertEquals('<SANITIZED>', data['args']['admin_password'])
+ self.assertEqual('<SANITIZED>', data['args']['admin_password'])
data = {'_context_auth_token': 'banana',
'auth_token': 'cheese',
@@ -356,8 +330,8 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
def test_safe_log_sanitizes_any_password_in_context(self):
def logger_method(msg, data):
- self.assertEquals('<SANITIZED>', data['_context_password'])
- self.assertEquals('<SANITIZED>', data['password'])
+ self.assertEqual('<SANITIZED>', data['_context_password'])
+ self.assertEqual('<SANITIZED>', data['password'])
data = {'_context_auth_token': 'banana',
'auth_token': 'cheese',
@@ -369,7 +343,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase):
def test_safe_log_sanitizes_cells_route_message(self):
def logger_method(msg, data):
vals = data['args']['message']['args']['method_info']
- self.assertEquals('<SANITIZED>', vals['method_kwargs']['password'])
+ self.assertEqual('<SANITIZED>', vals['method_kwargs']['password'])
meth_info = {'method_args': ['aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'],
'method': 'set_admin_password',
diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py
index cbe948d..695e701 100644
--- a/tests/unit/rpc/test_kombu.py
+++ b/tests/unit/rpc/test_kombu.py
@@ -33,10 +33,10 @@ from oslo.config import cfg
import six
import time
-from openstack.common import exception
from openstack.common.rpc import amqp as rpc_amqp
from openstack.common.rpc import common as rpc_common
from tests.unit.rpc import amqp
+from tests.unit.rpc import common
from tests import utils
try:
@@ -596,7 +596,8 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
"""
value = "This is the exception message"
# The use of ApiError is an arbitrary choice here ...
- self.assertRaises(exception.ApiError,
+ self.config(allowed_rpc_exception_modules=[common.__name__])
+ self.assertRaises(common.ApiError,
self.rpc.call,
FLAGS,
self.context,
@@ -609,10 +610,10 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
{"method": "fail_converted",
"args": {"value": value}})
self.fail("should have thrown Exception")
- except exception.ApiError as exc:
+ except common.ApiError as exc:
self.assertTrue(value in six.text_type(exc))
#Traceback should be included in exception message
- self.assertTrue('exception.ApiError' in six.text_type(exc))
+ self.assertTrue('ApiError' in six.text_type(exc))
def test_create_worker(self):
meth = 'declare_topic_consumer'
diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py
index a4bce1c..910e292 100644
--- a/tests/unit/rpc/test_qpid.py
+++ b/tests/unit/rpc/test_qpid.py
@@ -438,9 +438,9 @@ class RpcQpidTestCase(utils.BaseTestCase):
{"method": "test_method", "args": {}})
if multi:
- self.assertEquals(list(res), ["foo", "bar", "baz"])
+ self.assertEqual(list(res), ["foo", "bar", "baz"])
else:
- self.assertEquals(res, "foo")
+ self.assertEqual(res, "foo")
finally:
impl_qpid.cleanup()
self.uuid4 = uuid.uuid4()
@@ -488,7 +488,7 @@ class RpcQpidTestCase(utils.BaseTestCase):
else:
res = method(FLAGS, ctx, "impl_qpid_test",
{"method": "test_method", "args": {}}, timeout)
- self.assertEquals(res, "foo")
+ self.assertEqual(res, "foo")
finally:
impl_qpid.cleanup()
self.uuid4 = uuid.uuid4()
diff --git a/tests/unit/rpc/test_securemessage.py b/tests/unit/rpc/test_securemessage.py
new file mode 100644
index 0000000..8c07df1
--- /dev/null
+++ b/tests/unit/rpc/test_securemessage.py
@@ -0,0 +1,134 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Unit Tests for rpc 'securemessage' functions.
+"""
+
+import logging
+
+from oslo.config import cfg
+
+from openstack.common import jsonutils
+from openstack.common.rpc import common as rpc_common
+from openstack.common.rpc import securemessage as rpc_secmsg
+from tests import utils as test_utils
+
+
+CONF = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class RpcCryptoTestCase(test_utils.BaseTestCase):
+
+ def test_KeyStore(self):
+ store = rpc_secmsg.KeyStore()
+
+        # check empty cache returns nothing
+ keys = store.get_ticket('foo', 'bar')
+ self.assertIsNone(keys)
+
+ ticket = rpc_secmsg.Ticket('skey', 'ekey', 'esek')
+
+ #add entry in the cache
+ store.put_ticket('foo', 'bar', 'skey', 'ekey', 'esek', 2000000000)
+
+        #check it returns the object
+ keys = store.get_ticket('foo', 'bar')
+ self.assertEqual(keys, ticket)
+
+ #check inverted source/target returns nothing
+ keys = store.get_ticket('bar', 'foo')
+ self.assertIsNone(keys)
+
+ #add expired entry in the cache
+ store.put_ticket('foo', 'bar', 'skey', 'ekey', 'skey', 1000000000)
+
+ #check expired entries are not returned
+ keys = store.get_ticket('foo', 'bar')
+ self.assertIsNone(keys)
+
+ def _test_secure_message(self, data, encrypt):
+ msg = {'message': 'body'}
+
+ # Use a fresh store for each test
+ store = rpc_secmsg.KeyStore()
+
+ send = rpc_secmsg.SecureMessage(data['source'][0], data['source'][1],
+ CONF, data['send_key'],
+ store, encrypt, enctype=data['cipher'],
+ hashtype=data['hash'])
+ recv = rpc_secmsg.SecureMessage(data['target'][0], data['target'][1],
+ CONF, data['recv_key'],
+ store, encrypt, enctype=data['cipher'],
+ hashtype=data['hash'])
+
+ source = '%s.%s' % data['source']
+ target = '%s.%s' % data['target']
+        # Adds test keys in the cache; we do it twice, once for client side use,
+ # then for server side use as we run both in the same process
+ store.put_ticket(source, target,
+ data['skey'], data['ekey'], data['esek'], 2000000000)
+
+ pkt = send.encode(rpc_common._RPC_ENVELOPE_VERSION,
+ target, jsonutils.dumps(msg))
+
+ out = recv.decode(rpc_common._RPC_ENVELOPE_VERSION,
+ pkt[0], pkt[1], pkt[2])
+ rmsg = jsonutils.loads(out[1])
+
+ self.assertEqual(len(msg),
+ len(set(msg.items()) & set(rmsg.items())))
+
+ def test_secure_message_sha256_aes(self):
+ foo_to_bar_sha256_aes = {
+ 'source': ('foo', 'host.example.com'),
+ 'target': ('bar', 'host.example.com'),
+ 'send_key': '\x0b' * 16,
+ 'recv_key': '\x0b' * 16,
+ 'hash': 'SHA256',
+ 'cipher': 'AES',
+ 'skey': "\xaf\xab\x81\x14'\xdd\x1ck\xd1\xb4[\x84MZ\xf5\r",
+ 'ekey': '\x98\x06\x1bW\x1e\xc1z\xdd\xe2\xb1h\xa5\xb7;\x14\n',
+ 'esek': ('IehVCF684xJVN0sHc/zngsCAZWQkKSueK4I+ycRhxDGYsqYaAw+nECnZ'
+ 'mgA3R+DM8halM5TEwwI/uuPqExu8p+fW4CqSMh8oEtLGGqrx85GromaH'
+ '/YVqK1GpIfUSIQSZrXhAzITN9MeYfeLhD0w2ENUG6AyAk3D56W6l9zJw'
+ 'ZsI=')
+ }
+ # Test signing only first
+ self._test_secure_message(foo_to_bar_sha256_aes, False)
+ # Test encryption too
+ self._test_secure_message(foo_to_bar_sha256_aes, True)
+
+ def test_secure_message_md5_des(self):
+ foo_to_baz_md5_des = {
+ 'source': ('foo', 'host.example.com'),
+ 'target': ('bar', 'host.example.com'),
+ 'send_key': '????????',
+ 'recv_key': '????????',
+ 'hash': 'MD5',
+ 'cipher': 'DES',
+ 'skey': 'N<\xeb\x98\x9f$\xa9\xa8',
+ 'ekey': '\x8c\xd2\x02\x89\xbb6\xd0\xdd',
+ 'esek': ('CyVMteHe5LiYWFcRnodPv4t8UJ14QztJCC0p/olib9vq50/wua0LY6sk'
+ 'WWe0GGcvEdzaoZAuH6eBh00CdAVT2LqlK0nBE3Szj93jmVIJxMM+ydxZ'
+ '2VCvEZohhKeenMiI')
+ }
+ # Test signing only first
+ self._test_secure_message(foo_to_baz_md5_des, False)
+ # Test encryption too
+ self._test_secure_message(foo_to_baz_md5_des, True)
+
+ #TODO(simo): test fetching key from file
diff --git a/tests/unit/rpc/test_zmq.py b/tests/unit/rpc/test_zmq.py
index 55e3eb1..888dbb0 100644
--- a/tests/unit/rpc/test_zmq.py
+++ b/tests/unit/rpc/test_zmq.py
@@ -106,9 +106,7 @@ class _RpcZmqBaseTestCase(common.BaseRpcTestCase):
for char in badchars:
self.topic_nested = char.join(('hello', 'world'))
try:
- # TODO(ewindisch): Determine which exception is raised.
- # pending bug #1121348
- self.assertRaises(Exception, self._test_cast,
+ self.assertRaises(AssertionError, self._test_cast,
common.TestReceiver.echo, 42, {"value": 42},
fanout=False)
finally:
diff --git a/tests/unit/scheduler/test_weights.py b/tests/unit/scheduler/test_weights.py
index 21d3f3e..fcb25c6 100644
--- a/tests/unit/scheduler/test_weights.py
+++ b/tests/unit/scheduler/test_weights.py
@@ -17,11 +17,11 @@ Tests For Scheduler weights.
"""
from openstack.common.scheduler import base_weight
+from openstack.common import test
from tests.unit import fakes
-from tests import utils
-class TestWeightHandler(utils.BaseTestCase):
+class TestWeightHandler(test.BaseTestCase):
def test_get_all_classes(self):
namespace = "openstack.common.tests.fakes.weights"
handler = base_weight.BaseWeightHandler(
diff --git a/tests/unit/test_authutils.py b/tests/unit/test_authutils.py
index 3596df9..82433ad 100644
--- a/tests/unit/test_authutils.py
+++ b/tests/unit/test_authutils.py
@@ -16,10 +16,10 @@
# under the License.
from openstack.common import authutils
-from tests import utils
+from openstack.common import test
-class AuthUtilsTest(utils.BaseTestCase):
+class AuthUtilsTest(test.BaseTestCase):
def test_auth_str_equal(self):
self.assertTrue(authutils.auth_str_equal('abc123', 'abc123'))
diff --git a/tests/unit/test_cfgfilter.py b/tests/unit/test_cfgfilter.py
index 58b4e2d..1270758 100644
--- a/tests/unit/test_cfgfilter.py
+++ b/tests/unit/test_cfgfilter.py
@@ -30,11 +30,11 @@ class ConfigFilterTestCase(utils.BaseTestCase):
def test_register_opt_default(self):
self.fconf.register_opt(cfg.StrOpt('foo', default='bar'))
- self.assertEquals(self.fconf.foo, 'bar')
- self.assertEquals(self.fconf['foo'], 'bar')
+ self.assertEqual(self.fconf.foo, 'bar')
+ self.assertEqual(self.fconf['foo'], 'bar')
self.assertTrue('foo' in self.fconf)
- self.assertEquals(list(self.fconf), ['foo'])
- self.assertEquals(len(self.fconf), 1)
+ self.assertEqual(list(self.fconf), ['foo'])
+ self.assertEqual(len(self.fconf), 1)
def test_register_opt_none_default(self):
self.fconf.register_opt(cfg.StrOpt('foo'))
@@ -42,21 +42,21 @@ class ConfigFilterTestCase(utils.BaseTestCase):
self.assertTrue(self.fconf.foo is None)
self.assertTrue(self.fconf['foo'] is None)
self.assertTrue('foo' in self.fconf)
- self.assertEquals(list(self.fconf), ['foo'])
- self.assertEquals(len(self.fconf), 1)
+ self.assertEqual(list(self.fconf), ['foo'])
+ self.assertEqual(len(self.fconf), 1)
def test_register_grouped_opt_default(self):
self.fconf.register_opt(cfg.StrOpt('foo', default='bar'),
group='blaa')
- self.assertEquals(self.fconf.blaa.foo, 'bar')
- self.assertEquals(self.fconf['blaa']['foo'], 'bar')
+ self.assertEqual(self.fconf.blaa.foo, 'bar')
+ self.assertEqual(self.fconf['blaa']['foo'], 'bar')
self.assertTrue('blaa' in self.fconf)
self.assertTrue('foo' in self.fconf.blaa)
- self.assertEquals(list(self.fconf), ['blaa'])
- self.assertEquals(list(self.fconf.blaa), ['foo'])
- self.assertEquals(len(self.fconf), 1)
- self.assertEquals(len(self.fconf.blaa), 1)
+ self.assertEqual(list(self.fconf), ['blaa'])
+ self.assertEqual(list(self.fconf.blaa), ['foo'])
+ self.assertEqual(len(self.fconf), 1)
+ self.assertEqual(len(self.fconf.blaa), 1)
def test_register_grouped_opt_none_default(self):
self.fconf.register_opt(cfg.StrOpt('foo'), group='blaa')
@@ -65,10 +65,10 @@ class ConfigFilterTestCase(utils.BaseTestCase):
self.assertTrue(self.fconf['blaa']['foo'] is None)
self.assertTrue('blaa' in self.fconf)
self.assertTrue('foo' in self.fconf.blaa)
- self.assertEquals(list(self.fconf), ['blaa'])
- self.assertEquals(list(self.fconf.blaa), ['foo'])
- self.assertEquals(len(self.fconf), 1)
- self.assertEquals(len(self.fconf.blaa), 1)
+ self.assertEqual(list(self.fconf), ['blaa'])
+ self.assertEqual(list(self.fconf.blaa), ['foo'])
+ self.assertEqual(len(self.fconf), 1)
+ self.assertEqual(len(self.fconf.blaa), 1)
def test_register_group(self):
group = cfg.OptGroup('blaa')
@@ -79,24 +79,24 @@ class ConfigFilterTestCase(utils.BaseTestCase):
self.assertTrue(self.fconf['blaa']['foo'] is None)
self.assertTrue('blaa' in self.fconf)
self.assertTrue('foo' in self.fconf.blaa)
- self.assertEquals(list(self.fconf), ['blaa'])
- self.assertEquals(list(self.fconf.blaa), ['foo'])
- self.assertEquals(len(self.fconf), 1)
- self.assertEquals(len(self.fconf.blaa), 1)
+ self.assertEqual(list(self.fconf), ['blaa'])
+ self.assertEqual(list(self.fconf.blaa), ['foo'])
+ self.assertEqual(len(self.fconf), 1)
+ self.assertEqual(len(self.fconf.blaa), 1)
def test_unknown_opt(self):
self.assertFalse('foo' in self.fconf)
- self.assertEquals(len(self.fconf), 0)
+ self.assertEqual(len(self.fconf), 0)
self.assertRaises(cfg.NoSuchOptError, getattr, self.fconf, 'foo')
def test_blocked_opt(self):
self.conf.register_opt(cfg.StrOpt('foo'))
self.assertTrue('foo' in self.conf)
- self.assertEquals(len(self.conf), 1)
+ self.assertEqual(len(self.conf), 1)
self.assertTrue(self.conf.foo is None)
self.assertFalse('foo' in self.fconf)
- self.assertEquals(len(self.fconf), 0)
+ self.assertEqual(len(self.fconf), 0)
self.assertRaises(cfg.NoSuchOptError, getattr, self.fconf, 'foo')
def test_import_opt(self):
diff --git a/tests/unit/test_cliutils.py b/tests/unit/test_cliutils.py
index 13f954c..fada65f 100644
--- a/tests/unit/test_cliutils.py
+++ b/tests/unit/test_cliutils.py
@@ -15,10 +15,10 @@
# under the License.
from openstack.common import cliutils
-from tests import utils
+from openstack.common import test
-class ValidateArgsTest(utils.BaseTestCase):
+class ValidateArgsTest(test.BaseTestCase):
def test_lambda_no_args(self):
cliutils.validate_args(lambda: None)
diff --git a/tests/unit/test_compat.py b/tests/unit/test_compat.py
new file mode 100644
index 0000000..f71e130
--- /dev/null
+++ b/tests/unit/test_compat.py
@@ -0,0 +1,53 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+#
+# Copyright 2013 Canonical Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+
+from openstack.common.py3kcompat import urlutils
+from tests import utils
+
+
+class CompatTestCase(utils.BaseTestCase):
+ def test_urlencode(self):
+ fake = 'fake'
+ result = urlutils.urlencode({'Fake': fake})
+ self.assertEqual(result, 'Fake=fake')
+
+ def test_urljoin(self):
+ root_url = "http://yahoo.com/"
+ url2 = "faq.html"
+ result = urlutils.urljoin(root_url, url2)
+ self.assertEqual(result, "http://yahoo.com/faq.html")
+
+ def test_urlquote(self):
+ url = "/~fake"
+ result = urlutils.quote(url)
+ self.assertEqual(result, '/%7Efake')
+
+ def test_urlparse(self):
+ url = 'http://www.yahoo.com'
+ result = urlutils.urlparse(url)
+ self.assertEqual(result.scheme, 'http')
+
+ def test_urlsplit(self):
+ url = 'http://www.yahoo.com'
+ result = urlutils.urlsplit(url)
+ self.assertEqual(result.scheme, 'http')
+
+ def test_urlunsplit(self):
+ url = "http://www.yahoo.com"
+ result = urlutils.urlunsplit(urlutils.urlsplit(url))
+ self.assertEqual(result, 'http://www.yahoo.com')
diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py
index 2f9a3de..ae79b31 100644
--- a/tests/unit/test_context.py
+++ b/tests/unit/test_context.py
@@ -16,10 +16,10 @@
# under the License.
from openstack.common import context
-from tests import utils
+from openstack.common import test
-class ContextTest(utils.BaseTestCase):
+class ContextTest(test.BaseTestCase):
def test_context(self):
ctx = context.RequestContext()
diff --git a/tests/unit/test_exception.py b/tests/unit/test_exception.py
deleted file mode 100644
index 407e68d..0000000
--- a/tests/unit/test_exception.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2011 OpenStack Foundation.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-from openstack.common import exception
-from tests import utils
-
-
-def good_function():
- return "Is Bueno!"
-
-
-def bad_function_error():
- raise exception.Error()
-
-
-def bad_function_exception():
- raise Exception()
-
-
-class WrapExceptionTest(utils.BaseTestCase):
-
- def test_wrap_exception_good_return(self):
- wrapped = exception.wrap_exception
- self.assertEquals(good_function(), wrapped(good_function)())
-
- def test_wrap_exception_throws_error(self):
- wrapped = exception.wrap_exception
- self.assertRaises(exception.Error, wrapped(bad_function_error))
-
- def test_wrap_exception_throws_exception(self):
- wrapped = exception.wrap_exception
- self.assertRaises(Exception, wrapped(bad_function_exception))
-
-
-class ApiErrorTest(utils.BaseTestCase):
-
- def test_without_code(self):
- err = exception.ApiError('fake error')
- self.assertEqual(str(err), 'Unknown: fake error')
- self.assertEqual(err.code, 'Unknown')
- self.assertEqual(err.api_message, 'fake error')
-
- def test_with_code(self):
- err = exception.ApiError('fake error', 'blah code')
- self.assertEqual(str(err), 'blah code: fake error')
- self.assertEqual(err.code, 'blah code')
- self.assertEqual(err.api_message, 'fake error')
-
-
-class BadStoreUriTest(utils.BaseTestCase):
-
- def test(self):
- uri = 'http:///etc/passwd'
- reason = 'Permission DENIED!'
- err = exception.BadStoreUri(uri, reason)
- self.assertTrue(uri in str(err))
- self.assertTrue(reason in str(err))
-
-
-class UnknownSchemeTest(utils.BaseTestCase):
-
- def test(self):
- scheme = 'http'
- err = exception.UnknownScheme(scheme)
- self.assertTrue(scheme in str(err))
-
-
-class OpenstackExceptionTest(utils.BaseTestCase):
- class TestException(exception.OpenstackException):
- msg_fmt = '%(test)s'
-
- def test_format_error_string(self):
- test_message = 'Know Your Meme'
- err = self.TestException(test=test_message)
- self.assertEqual(err._error_string, test_message)
-
- def test_error_formating_error_string(self):
- self.stubs.Set(exception, '_FATAL_EXCEPTION_FORMAT_ERRORS', False)
- err = self.TestException(lol='U mad brah')
- self.assertEqual(err._error_string, self.TestException.msg_fmt)
-
- def test_str(self):
- message = 'Y u no fail'
- err = self.TestException(test=message)
- self.assertEqual(str(err), message)
diff --git a/tests/unit/test_fileutils.py b/tests/unit/test_fileutils.py
index 4214e83..139d478 100644
--- a/tests/unit/test_fileutils.py
+++ b/tests/unit/test_fileutils.py
@@ -25,10 +25,10 @@ import mock
import mox
from openstack.common import fileutils
-from tests import utils
+from openstack.common import test
-class EnsureTree(utils.BaseTestCase):
+class EnsureTree(test.BaseTestCase):
def test_ensure_tree(self):
tmpdir = tempfile.mkdtemp()
try:
@@ -41,7 +41,7 @@ class EnsureTree(utils.BaseTestCase):
shutil.rmtree(tmpdir)
-class TestCachedFile(utils.BaseTestCase):
+class TestCachedFile(test.BaseTestCase):
def setUp(self):
super(TestCachedFile, self).setUp()
@@ -86,7 +86,7 @@ class TestCachedFile(utils.BaseTestCase):
self.assertTrue(fresh)
-class DeleteIfExists(utils.BaseTestCase):
+class DeleteIfExists(test.BaseTestCase):
def test_file_present(self):
tmpfile = tempfile.mktemp()
@@ -113,7 +113,7 @@ class DeleteIfExists(utils.BaseTestCase):
self.assertRaises(OSError, fileutils.delete_if_exists, tmpfile)
-class RemovePathOnError(utils.BaseTestCase):
+class RemovePathOnError(test.BaseTestCase):
def test_error(self):
tmpfile = tempfile.mktemp()
open(tmpfile, 'w')
@@ -134,7 +134,7 @@ class RemovePathOnError(utils.BaseTestCase):
os.unlink(tmpfile)
-class UtilsTestCase(utils.BaseTestCase):
+class UtilsTestCase(test.BaseTestCase):
def test_file_open(self):
dst_fd, dst_path = tempfile.mkstemp()
try:
@@ -142,6 +142,6 @@ class UtilsTestCase(utils.BaseTestCase):
with open(dst_path, 'w') as f:
f.write('hello')
with fileutils.file_open(dst_path, 'r') as fp:
- self.assertEquals(fp.read(), 'hello')
+ self.assertEqual(fp.read(), 'hello')
finally:
os.unlink(dst_path)
diff --git a/tests/unit/test_funcutils.py b/tests/unit/test_funcutils.py
index 439d825..1cf665b 100644
--- a/tests/unit/test_funcutils.py
+++ b/tests/unit/test_funcutils.py
@@ -20,11 +20,10 @@
import functools
from openstack.common import funcutils
+from openstack.common import test
-from tests import utils
-
-class FuncutilsTestCase(utils.BaseTestCase):
+class FuncutilsTestCase(test.BaseTestCase):
def _test_func(self, instance, red=None, blue=None):
pass
diff --git a/tests/unit/test_gettext.py b/tests/unit/test_gettext.py
index e21297f..ddd8b84 100644
--- a/tests/unit/test_gettext.py
+++ b/tests/unit/test_gettext.py
@@ -1,8 +1,8 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
-# All Rights Reserved.
# Copyright 2013 IBM Corp.
+# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -16,6 +16,7 @@
# License for the specific language governing permissions and limitations
# under the License.
+from babel import localedata
import copy
import gettext
import logging.handlers
@@ -32,9 +33,53 @@ LOG = logging.getLogger(__name__)
class GettextTest(utils.BaseTestCase):
+ def setUp(self):
+ super(GettextTest, self).setUp()
+ # remember so we can reset to it later
+ self._USE_LAZY = gettextutils.USE_LAZY
+
+ def tearDown(self):
+ # reset to value before test
+ gettextutils.USE_LAZY = self._USE_LAZY
+ super(GettextTest, self).tearDown()
+
+ def test_enable_lazy(self):
+ gettextutils.USE_LAZY = False
+
+ gettextutils.enable_lazy()
+ # assert now enabled
+ self.assertTrue(gettextutils.USE_LAZY)
+
+ def test_underscore_non_lazy(self):
+ # set lazy off
+ gettextutils.USE_LAZY = False
+
+ self.mox.StubOutWithMock(gettextutils._t, 'ugettext')
+ gettextutils._t.ugettext('blah').AndReturn('translated blah')
+ self.mox.ReplayAll()
+
+ result = gettextutils._('blah')
+ self.assertEqual('translated blah', result)
+
+ def test_underscore_lazy(self):
+ # set lazy off
+ gettextutils.USE_LAZY = False
+
+ gettextutils.enable_lazy()
+ result = gettextutils._('blah')
+ self.assertIsInstance(result, gettextutils.Message)
+
def test_gettext_does_not_blow_up(self):
LOG.info(gettextutils._('test'))
+ def test_gettextutils_install(self):
+ gettextutils.install('blaa')
+ self.assertTrue(isinstance(_('A String'), unicode)) # noqa
+
+ gettextutils.install('blaa', lazy=True)
+ self.assertTrue(isinstance(_('A Message'), # noqa
+ gettextutils.Message))
+
def test_gettext_install_looks_up_localedir(self):
with mock.patch('os.environ.get') as environ_get:
with mock.patch('gettext.install') as gettext_install:
@@ -47,13 +92,83 @@ class GettextTest(utils.BaseTestCase):
localedir='/foo/bar',
unicode=True)
+ def test_get_localized_message(self):
+ non_message = 'Non-translatable Message'
+ en_message = 'A message in the default locale'
+ es_translation = 'A message in Spanish'
+ zh_translation = 'A message in Chinese'
+ message = gettextutils.Message(en_message, 'test_domain')
+
+ # In the Message class the translation ultimately occurs when the
+ # message is turned into a string, and that is what we mock here
+ def _mock_translation_and_unicode(self):
+ if self.locale == 'es':
+ return es_translation
+ if self.locale == 'zh':
+ return zh_translation
+ return self.data
+
+ self.stubs.Set(gettextutils.Message,
+ '__unicode__', _mock_translation_and_unicode)
+
+ self.assertEqual(es_translation,
+ gettextutils.get_localized_message(message, 'es'))
+ self.assertEqual(zh_translation,
+ gettextutils.get_localized_message(message, 'zh'))
+ self.assertEqual(en_message,
+ gettextutils.get_localized_message(message, 'en'))
+ self.assertEqual(en_message,
+ gettextutils.get_localized_message(message, 'XX'))
+ self.assertEqual(en_message,
+ gettextutils.get_localized_message(message, None))
+ self.assertEqual(non_message,
+ gettextutils.get_localized_message(non_message, 'A'))
+
+ def test_get_available_languages(self):
+ # All the available languages for which locale data is available
+ def _mock_locale_identifiers():
+ return ['zh', 'es', 'nl', 'fr']
+
+ self.stubs.Set(localedata,
+ 'list' if hasattr(localedata, 'list')
+ else 'locale_identifiers',
+ _mock_locale_identifiers)
+
+ # Only the languages available for a specific translation domain
+ def _mock_gettext_find(domain, localedir=None, languages=[], all=0):
+ if domain == 'test_domain':
+ return 'translation-file' if any(x in ['zh', 'es']
+ for x in languages) else None
+ return None
+ self.stubs.Set(gettext, 'find', _mock_gettext_find)
+
+ domain_languages = gettextutils.get_available_languages('test_domain')
+ # en_US should always be available no matter the domain
+ # en_US should also always be the first element since order matters
+ # finally only the domain languages should be included after en_US
+ self.assertIn('en_US', domain_languages)
+ self.assertEqual(3, len(domain_languages))
+ self.assertEqual('en_US', domain_languages[0])
+ self.assertTrue('zh' in domain_languages)
+ self.assertTrue('es' in domain_languages)
+
+ # Clear languages to test an unknown domain
+ gettextutils._AVAILABLE_LANGUAGES = []
+ unknown_domain_languages = gettextutils.get_available_languages('huh')
+ self.assertEqual(1, len(unknown_domain_languages))
+ self.assertTrue('en_US' in unknown_domain_languages)
+
class MessageTestCase(utils.BaseTestCase):
"""Unit tests for locale Message class."""
def setUp(self):
super(MessageTestCase, self).setUp()
- self._lazy_gettext = gettextutils.get_lazy_gettext('oslo')
+
+ def _message_with_domain(msg):
+ return gettextutils.Message(msg, 'oslo')
+
+ self._lazy_gettext = _message_with_domain
def tearDown(self):
# need to clean up stubs early since they interfere
@@ -127,6 +242,19 @@ class MessageTestCase(utils.BaseTestCase):
self.assertEqual(result, msgid % params)
+ def test_regex_find_named_parameters_no_space(self):
+ msgid = ("Request: %(method)s http://%(server)s:"
+ "%(port)s%(url)s with headers %(headers)s")
+ params = {'method': 'POST',
+ 'server': 'test1',
+ 'port': 1234,
+ 'url': 'test2',
+ 'headers': {'h1': 'val1'}}
+
+ result = self._lazy_gettext(msgid) % params
+
+ self.assertEqual(result, msgid % params)
+
def test_regex_dict_is_parameter(self):
msgid = ("Test that we can inject a dictionary %s")
params = {'description': 'test1',
@@ -395,7 +523,11 @@ class LocaleHandlerTestCase(utils.BaseTestCase):
def setUp(self):
super(LocaleHandlerTestCase, self).setUp()
- self._lazy_gettext = gettextutils.get_lazy_gettext('oslo')
+
+ def _message_with_domain(msg):
+ return gettextutils.Message(msg, 'oslo')
+
+ self._lazy_gettext = _message_with_domain
self.buffer_handler = logging.handlers.BufferingHandler(40)
self.locale_handler = gettextutils.LocaleHandler(
'zh_CN', self.buffer_handler)
diff --git a/tests/unit/test_importutils.py b/tests/unit/test_importutils.py
index 372bb6d..d716929 100644
--- a/tests/unit/test_importutils.py
+++ b/tests/unit/test_importutils.py
@@ -19,10 +19,10 @@ import datetime
import sys
from openstack.common import importutils
-from tests import utils
+from openstack.common import test
-class ImportUtilsTest(utils.BaseTestCase):
+class ImportUtilsTest(test.BaseTestCase):
# NOTE(jkoelker) There has GOT to be a way to test this. But mocking
# __import__ is the devil. Right now we just make
diff --git a/tests/unit/test_jsonutils.py b/tests/unit/test_jsonutils.py
index 5dc23f7..0d22e7e 100644
--- a/tests/unit/test_jsonutils.py
+++ b/tests/unit/test_jsonutils.py
@@ -22,10 +22,10 @@ import netaddr
import six
from openstack.common import jsonutils
-from tests import utils
+from openstack.common import test
-class JSONUtilsTestCase(utils.BaseTestCase):
+class JSONUtilsTestCase(test.BaseTestCase):
def test_dumps(self):
self.assertEqual(jsonutils.dumps({'a': 'b'}), '{"a": "b"}')
@@ -38,37 +38,37 @@ class JSONUtilsTestCase(utils.BaseTestCase):
self.assertEqual(jsonutils.load(x), {'a': 'b'})
-class ToPrimitiveTestCase(utils.BaseTestCase):
+class ToPrimitiveTestCase(test.BaseTestCase):
def test_list(self):
- self.assertEquals(jsonutils.to_primitive([1, 2, 3]), [1, 2, 3])
+ self.assertEqual(jsonutils.to_primitive([1, 2, 3]), [1, 2, 3])
def test_empty_list(self):
- self.assertEquals(jsonutils.to_primitive([]), [])
+ self.assertEqual(jsonutils.to_primitive([]), [])
def test_tuple(self):
- self.assertEquals(jsonutils.to_primitive((1, 2, 3)), [1, 2, 3])
+ self.assertEqual(jsonutils.to_primitive((1, 2, 3)), [1, 2, 3])
def test_dict(self):
- self.assertEquals(jsonutils.to_primitive(dict(a=1, b=2, c=3)),
- dict(a=1, b=2, c=3))
+ self.assertEqual(jsonutils.to_primitive(dict(a=1, b=2, c=3)),
+ dict(a=1, b=2, c=3))
def test_empty_dict(self):
- self.assertEquals(jsonutils.to_primitive({}), {})
+ self.assertEqual(jsonutils.to_primitive({}), {})
def test_datetime(self):
x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7)
- self.assertEquals(jsonutils.to_primitive(x),
- '1920-02-03T04:05:06.000007')
+ self.assertEqual(jsonutils.to_primitive(x),
+ '1920-02-03T04:05:06.000007')
def test_datetime_preserve(self):
x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7)
- self.assertEquals(jsonutils.to_primitive(x, convert_datetime=False), x)
+ self.assertEqual(jsonutils.to_primitive(x, convert_datetime=False), x)
def test_DateTime(self):
x = xmlrpclib.DateTime()
x.decode("19710203T04:05:06")
- self.assertEquals(jsonutils.to_primitive(x),
- '1971-02-03T04:05:06.000000')
+ self.assertEqual(jsonutils.to_primitive(x),
+ '1971-02-03T04:05:06.000000')
def test_iter(self):
class IterClass(object):
@@ -86,7 +86,7 @@ class ToPrimitiveTestCase(utils.BaseTestCase):
return self.data[self.index - 1]
x = IterClass()
- self.assertEquals(jsonutils.to_primitive(x), [1, 2, 3, 4, 5])
+ self.assertEqual(jsonutils.to_primitive(x), [1, 2, 3, 4, 5])
def test_iteritems(self):
class IterItemsClass(object):
@@ -99,7 +99,7 @@ class ToPrimitiveTestCase(utils.BaseTestCase):
x = IterItemsClass()
p = jsonutils.to_primitive(x)
- self.assertEquals(p, {'a': 1, 'b': 2, 'c': 3})
+ self.assertEqual(p, {'a': 1, 'b': 2, 'c': 3})
def test_iteritems_with_cycle(self):
class IterItemsClass(object):
@@ -127,24 +127,24 @@ class ToPrimitiveTestCase(utils.BaseTestCase):
self.b = 1
x = MysteryClass()
- self.assertEquals(jsonutils.to_primitive(x, convert_instances=True),
- dict(b=1))
+ self.assertEqual(jsonutils.to_primitive(x, convert_instances=True),
+ dict(b=1))
- self.assertEquals(jsonutils.to_primitive(x), x)
+ self.assertEqual(jsonutils.to_primitive(x), x)
def test_typeerror(self):
x = bytearray # Class, not instance
- self.assertEquals(jsonutils.to_primitive(x), u"<type 'bytearray'>")
+ self.assertEqual(jsonutils.to_primitive(x), u"<type 'bytearray'>")
def test_nasties(self):
def foo():
pass
x = [datetime, foo, dir]
ret = jsonutils.to_primitive(x)
- self.assertEquals(len(ret), 3)
+ self.assertEqual(len(ret), 3)
self.assertTrue(ret[0].startswith(u"<module 'datetime' from "))
self.assertTrue(ret[1].startswith('<function foo at 0x'))
- self.assertEquals(ret[2], '<built-in function dir>')
+ self.assertEqual(ret[2], '<built-in function dir>')
def test_depth(self):
class LevelsGenerator(object):
@@ -164,15 +164,15 @@ class ToPrimitiveTestCase(utils.BaseTestCase):
json_l4 = {0: {0: {0: {0: '?'}}}}
ret = jsonutils.to_primitive(l4_obj, max_depth=2)
- self.assertEquals(ret, json_l2)
+ self.assertEqual(ret, json_l2)
ret = jsonutils.to_primitive(l4_obj, max_depth=3)
- self.assertEquals(ret, json_l3)
+ self.assertEqual(ret, json_l3)
ret = jsonutils.to_primitive(l4_obj, max_depth=4)
- self.assertEquals(ret, json_l4)
+ self.assertEqual(ret, json_l4)
def test_ipaddr(self):
thing = {'ip_addr': netaddr.IPAddress('1.2.3.4')}
ret = jsonutils.to_primitive(thing)
- self.assertEquals({'ip_addr': '1.2.3.4'}, ret)
+ self.assertEqual({'ip_addr': '1.2.3.4'}, ret)
diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py
index 37e5798..a8c9ab6 100644
--- a/tests/unit/test_local.py
+++ b/tests/unit/test_local.py
@@ -15,10 +15,10 @@
# License for the specific language governing permissions and limitations
# under the License.
-import eventlet
+import threading
from openstack.common import local
-from tests import utils
+from openstack.common import test
class Dict(dict):
@@ -26,11 +26,20 @@ class Dict(dict):
pass
-class LocalStoreTestCase(utils.BaseTestCase):
+class LocalStoreTestCase(test.BaseTestCase):
v1 = Dict(a='1')
v2 = Dict(a='2')
v3 = Dict(a='3')
+ def setUp(self):
+ super(LocalStoreTestCase, self).setUp()
+ # NOTE(mrodden): we need to make sure that local store
+ # gets imported in the current python context we are
+ # testing in (eventlet vs normal python threading) so
+ # we test the correct type of local store for the current
+ # threading model
+ reload(local)
+
def test_thread_unique_storage(self):
"""Make sure local store holds thread specific values."""
expected_set = []
@@ -44,8 +53,13 @@ class LocalStoreTestCase(utils.BaseTestCase):
local.store.a = self.v3
expected_set.append(getattr(local.store, 'a'))
- eventlet.spawn(do_something).wait()
- eventlet.spawn(do_something2).wait()
+ t1 = threading.Thread(target=do_something)
+ t2 = threading.Thread(target=do_something2)
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+
expected_set.append(getattr(local.store, 'a'))
self.assertTrue(self.v1 in expected_set)
diff --git a/tests/unit/test_lockutils.py b/tests/unit/test_lockutils.py
index b5783b8..670d289 100644
--- a/tests/unit/test_lockutils.py
+++ b/tests/unit/test_lockutils.py
@@ -73,10 +73,10 @@ class LockTestCase(utils.BaseTestCase):
"""Bar"""
pass
- self.assertEquals(foo.__doc__, 'Bar', "Wrapped function's docstring "
- "got lost")
- self.assertEquals(foo.__name__, 'foo', "Wrapped function's name "
- "got mangled")
+ self.assertEqual(foo.__doc__, 'Bar', "Wrapped function's docstring "
+ "got lost")
+ self.assertEqual(foo.__name__, 'foo', "Wrapped function's name "
+ "got mangled")
def test_lock_internally(self):
"""We can lock across multiple green threads."""
@@ -97,13 +97,13 @@ class LockTestCase(utils.BaseTestCase):
for thread in threads:
thread.wait()
- self.assertEquals(len(seen_threads), 100)
+ self.assertEqual(len(seen_threads), 100)
# Looking at the seen threads, split it into chunks of 10, and verify
# that the last 9 match the first in each chunk.
for i in range(10):
for j in range(9):
- self.assertEquals(seen_threads[i * 10],
- seen_threads[i * 10 + 1 + j])
+ self.assertEqual(seen_threads[i * 10],
+ seen_threads[i * 10 + 1 + j])
self.assertEqual(saved_sem_num, len(lockutils._semaphores),
"Semaphore leak detected")
diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py
index 8f40f75..39a0cce 100644
--- a/tests/unit/test_log.py
+++ b/tests/unit/test_log.py
@@ -108,13 +108,13 @@ class LazyLoggerTestCase(CommonLoggerTestsMixIn, test_utils.BaseTestCase):
class LogHandlerTestCase(test_utils.BaseTestCase):
def test_log_path_logdir(self):
self.config(log_dir='/some/path', log_file=None)
- self.assertEquals(log._get_log_file_path(binary='foo-bar'),
- '/some/path/foo-bar.log')
+ self.assertEqual(log._get_log_file_path(binary='foo-bar'),
+ '/some/path/foo-bar.log')
def test_log_path_logfile(self):
self.config(log_file='/some/path/foo-bar.log')
- self.assertEquals(log._get_log_file_path(binary='foo-bar'),
- '/some/path/foo-bar.log')
+ self.assertEqual(log._get_log_file_path(binary='foo-bar'),
+ '/some/path/foo-bar.log')
def test_log_path_none(self):
self.config(log_dir=None, log_file=None)
@@ -123,8 +123,8 @@ class LogHandlerTestCase(test_utils.BaseTestCase):
def test_log_path_logfile_overrides_logdir(self):
self.config(log_dir='/some/other/path',
log_file='/some/path/foo-bar.log')
- self.assertEquals(log._get_log_file_path(binary='foo-bar'),
- '/some/path/foo-bar.log')
+ self.assertEqual(log._get_log_file_path(binary='foo-bar'),
+ '/some/path/foo-bar.log')
class PublishErrorsHandlerTestCase(test_utils.BaseTestCase):
@@ -354,7 +354,7 @@ class SetDefaultsTestCase(test_utils.BaseTestCase):
def test_default_to_none(self):
log.set_defaults(logging_context_format_string=None)
self.conf([])
- self.assertEquals(self.conf.logging_context_format_string, None)
+ self.assertEqual(self.conf.logging_context_format_string, None)
def test_change_default(self):
my_default = '%(asctime)s %(levelname)s %(name)s [%(request_id)s '\
@@ -362,7 +362,7 @@ class SetDefaultsTestCase(test_utils.BaseTestCase):
'%(message)s'
log.set_defaults(logging_context_format_string=my_default)
self.conf([])
- self.assertEquals(self.conf.logging_context_format_string, my_default)
+ self.assertEqual(self.conf.logging_context_format_string, my_default)
class LogConfigOptsTestCase(test_utils.BaseTestCase):
@@ -379,8 +379,8 @@ class LogConfigOptsTestCase(test_utils.BaseTestCase):
def test_debug_verbose(self):
CONF(['--debug', '--verbose'])
- self.assertEquals(CONF.debug, True)
- self.assertEquals(CONF.verbose, True)
+ self.assertEqual(CONF.debug, True)
+ self.assertEqual(CONF.verbose, True)
def test_logging_opts(self):
CONF([])
@@ -390,29 +390,29 @@ class LogConfigOptsTestCase(test_utils.BaseTestCase):
self.assertTrue(CONF.log_dir is None)
self.assertTrue(CONF.log_format is None)
- self.assertEquals(CONF.log_date_format, log._DEFAULT_LOG_DATE_FORMAT)
+ self.assertEqual(CONF.log_date_format, log._DEFAULT_LOG_DATE_FORMAT)
- self.assertEquals(CONF.use_syslog, False)
+ self.assertEqual(CONF.use_syslog, False)
def test_log_file(self):
log_file = '/some/path/foo-bar.log'
CONF(['--log-file', log_file])
- self.assertEquals(CONF.log_file, log_file)
+ self.assertEqual(CONF.log_file, log_file)
def test_logfile_deprecated(self):
logfile = '/some/other/path/foo-bar.log'
CONF(['--logfile', logfile])
- self.assertEquals(CONF.log_file, logfile)
+ self.assertEqual(CONF.log_file, logfile)
def test_log_dir(self):
log_dir = '/some/path/'
CONF(['--log-dir', log_dir])
- self.assertEquals(CONF.log_dir, log_dir)
+ self.assertEqual(CONF.log_dir, log_dir)
def test_logdir_deprecated(self):
logdir = '/some/other/path/'
CONF(['--logdir', logdir])
- self.assertEquals(CONF.log_dir, logdir)
+ self.assertEqual(CONF.log_dir, logdir)
def test_log_format_overrides_formatter(self):
CONF(['--log-format', '[Any format]'])
diff --git a/tests/unit/test_loopingcall.py b/tests/unit/test_loopingcall.py
index 89cf336..9746b88 100644
--- a/tests/unit/test_loopingcall.py
+++ b/tests/unit/test_loopingcall.py
@@ -20,11 +20,11 @@ from eventlet import greenthread
import mox
from openstack.common import loopingcall
+from openstack.common import test
from openstack.common import timeutils
-from tests import utils
-class LoopingCallTestCase(utils.BaseTestCase):
+class LoopingCallTestCase(test.BaseTestCase):
def setUp(self):
super(LoopingCallTestCase, self).setUp()
diff --git a/tests/unit/test_memorycache.py b/tests/unit/test_memorycache.py
index 48a36b5..0d176da 100644
--- a/tests/unit/test_memorycache.py
+++ b/tests/unit/test_memorycache.py
@@ -18,11 +18,11 @@
import datetime
from openstack.common import memorycache
+from openstack.common import test
from openstack.common import timeutils
-from tests import utils
-class MemorycacheTest(utils.BaseTestCase):
+class MemorycacheTest(test.BaseTestCase):
def setUp(self):
self.client = memorycache.get_client()
super(MemorycacheTest, self).setUp()
diff --git a/tests/unit/test_network_utils.py b/tests/unit/test_network_utils.py
index 4ac0222..a4b9042 100644
--- a/tests/unit/test_network_utils.py
+++ b/tests/unit/test_network_utils.py
@@ -16,10 +16,10 @@
# under the License.
from openstack.common import network_utils
-from tests import utils
+from openstack.common import test
-class NetworkUtilsTest(utils.BaseTestCase):
+class NetworkUtilsTest(test.BaseTestCase):
def test_parse_host_port(self):
self.assertEqual(('server01', 80),
diff --git a/tests/unit/test_pastedeploy.py b/tests/unit/test_pastedeploy.py
index 3dd02d7..f8a2c57 100644
--- a/tests/unit/test_pastedeploy.py
+++ b/tests/unit/test_pastedeploy.py
@@ -20,7 +20,7 @@ import tempfile
import fixtures
from openstack.common import pastedeploy
-from tests import utils
+from openstack.common import test
class App(object):
@@ -43,7 +43,7 @@ class Filter(object):
self.data = data
-class PasteTestCase(utils.BaseTestCase):
+class PasteTestCase(test.BaseTestCase):
def setUp(self):
super(PasteTestCase, self).setUp()
@@ -67,7 +67,7 @@ openstack.app_factory = tests.unit.test_pastedeploy:App
""")
app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data)
- self.assertEquals(app.data, data)
+ self.assertEqual(app.data, data)
def test_app_factory_with_local_conf(self):
data = 'test_app_factory_with_local_conf'
@@ -80,8 +80,8 @@ foo = bar
""")
app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data)
- self.assertEquals(app.data, data)
- self.assertEquals(app.foo, 'bar')
+ self.assertEqual(app.data, data)
+ self.assertEqual(app.foo, 'bar')
def test_filter_factory(self):
data = 'test_filter_factory'
@@ -100,5 +100,5 @@ openstack.app_factory = tests.unit.test_pastedeploy:App
""")
app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data)
- self.assertEquals(app.data, data)
- self.assertEquals(app.app.data, data)
+ self.assertEqual(app.data, data)
+ self.assertEqual(app.app.data, data)
diff --git a/tests/unit/test_periodic.py b/tests/unit/test_periodic.py
index d663f8b..500e8f4 100644
--- a/tests/unit/test_periodic.py
+++ b/tests/unit/test_periodic.py
@@ -28,6 +28,10 @@ from tests import utils
from testtools import matchers
+class AnException(Exception):
+ pass
+
+
class AService(periodic_task.PeriodicTasks):
def __init__(self):
@@ -40,7 +44,7 @@ class AService(periodic_task.PeriodicTasks):
@periodic_task.periodic_task
def crashit(self, context):
self.called['urg'] = True
- raise Exception('urg')
+ raise AnException('urg')
@periodic_task.periodic_task(spacing=10)
def doit_with_kwargs_odd(self, context):
@@ -66,7 +70,7 @@ class PeriodicTasksTestCase(utils.BaseTestCase):
def test_raises(self):
serv = AService()
- self.assertRaises(Exception,
+ self.assertRaises(AnException,
serv.run_periodic_tasks,
None, raise_on_error=True)
diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py
index b7d38a3..143f56c 100644
--- a/tests/unit/test_policy.py
+++ b/tests/unit/test_policy.py
@@ -170,6 +170,49 @@ class EnforcerTest(PolicyBaseTestCase):
creds = {'roles': ''}
self.assertEqual(self.enforcer.enforce(action, {}, creds), True)
+ def test_enforcer_with_default_rule(self):
+ rules_json = """{
+ "deny_stack_user": "not role:stack_user",
+ "cloudwatch:PutMetricData": ""
+ }"""
+ rules = policy.Rules.load_json(rules_json)
+ default_rule = policy.TrueCheck()
+ enforcer = policy.Enforcer(default_rule=default_rule)
+ enforcer.set_rules(rules)
+ action = "cloudwatch:PutMetricData"
+ creds = {'roles': ''}
+ self.assertEqual(enforcer.enforce(action, {}, creds), True)
+
+ def test_enforcer_force_reload_true(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.load_rules(force_reload=True)
+        self.assertNotIn('test', self.enforcer.rules)
+ self.assertIn('default', self.enforcer.rules)
+ self.assertIn('admin', self.enforcer.rules)
+
+ def test_enforcer_force_reload_false(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.load_rules(force_reload=False)
+ self.assertIn('test', self.enforcer.rules)
+ self.assertNotIn('default', self.enforcer.rules)
+ self.assertNotIn('admin', self.enforcer.rules)
+
+ def test_enforcer_overwrite_rules(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.set_rules({'test': 'test1'}, overwrite=True)
+ self.assertEqual(self.enforcer.rules, {'test': 'test1'})
+
+ def test_enforcer_update_rules(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.set_rules({'test1': 'test1'}, overwrite=False)
+ self.assertEqual(self.enforcer.rules, {'test': 'test',
+ 'test1': 'test1'})
+
+ def test_get_policy_path_raises_exc(self):
+ enforcer = policy.Enforcer(policy_file='raise_error.json')
+ self.assertRaises(cfg.ConfigFilesNotFoundError,
+ enforcer._get_policy_path)
+
class FakeCheck(policy.BaseCheck):
def __init__(self, result=None):
@@ -187,24 +230,15 @@ class FakeCheck(policy.BaseCheck):
class CheckFunctionTestCase(PolicyBaseTestCase):
def test_check_explicit(self):
- self.enforcer.load_rules()
- self.enforcer.rules = None
rule = FakeCheck()
result = self.enforcer.enforce(rule, "target", "creds")
-
self.assertEqual(result, ("target", "creds", self.enforcer))
- self.assertEqual(self.enforcer.rules, None)
def test_check_no_rules(self):
- self.enforcer.load_rules()
- self.enforcer.rules = None
result = self.enforcer.enforce('rule', "target", "creds")
-
self.assertEqual(result, False)
- self.assertEqual(self.enforcer.rules, None)
def test_check_missing_rule(self):
- self.enforcer.rules = {}
result = self.enforcer.enforce('rule', 'target', 'creds')
self.assertEqual(result, False)
@@ -298,6 +332,48 @@ class NotCheckTestCase(utils.BaseTestCase):
rule.assert_called_once_with('target', 'cred', None)
+class AndCheckTestCase(utils.BaseTestCase):
+ def test_init(self):
+ check = policy.AndCheck(['rule1', 'rule2'])
+
+ self.assertEqual(check.rules, ['rule1', 'rule2'])
+
+ def test_add_check(self):
+ check = policy.AndCheck(['rule1', 'rule2'])
+ check.add_check('rule3')
+
+ self.assertEqual(check.rules, ['rule1', 'rule2', 'rule3'])
+
+ def test_str(self):
+ check = policy.AndCheck(['rule1', 'rule2'])
+
+ self.assertEqual(str(check), '(rule1 and rule2)')
+
+ def test_call_all_false(self):
+ rules = [mock.Mock(return_value=False), mock.Mock(return_value=False)]
+ check = policy.AndCheck(rules)
+
+ self.assertEqual(check('target', 'cred', None), False)
+ rules[0].assert_called_once_with('target', 'cred', None)
+ self.assertFalse(rules[1].called)
+
+ def test_call_first_true(self):
+ rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)]
+ check = policy.AndCheck(rules)
+
+ self.assertFalse(check('target', 'cred', None))
+ rules[0].assert_called_once_with('target', 'cred', None)
+ rules[1].assert_called_once_with('target', 'cred', None)
+
+ def test_call_second_true(self):
+ rules = [mock.Mock(return_value=False), mock.Mock(return_value=True)]
+ check = policy.AndCheck(rules)
+
+ self.assertFalse(check('target', 'cred', None))
+ rules[0].assert_called_once_with('target', 'cred', None)
+ self.assertFalse(rules[1].called)
+
+
class OrCheckTestCase(utils.BaseTestCase):
def test_init(self):
check = policy.OrCheck(['rule1', 'rule2'])
@@ -320,15 +396,15 @@ class OrCheckTestCase(utils.BaseTestCase):
check = policy.OrCheck(rules)
self.assertEqual(check('target', 'cred', None), False)
- rules[0].assert_called_once_with('target', 'cred')
- rules[1].assert_called_once_with('target', 'cred')
+ rules[0].assert_called_once_with('target', 'cred', None)
+ rules[1].assert_called_once_with('target', 'cred', None)
def test_call_first_true(self):
rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)]
check = policy.OrCheck(rules)
self.assertEqual(check('target', 'cred', None), True)
- rules[0].assert_called_once_with('target', 'cred')
+ rules[0].assert_called_once_with('target', 'cred', None)
self.assertFalse(rules[1].called)
def test_call_second_true(self):
@@ -336,8 +412,8 @@ class OrCheckTestCase(utils.BaseTestCase):
check = policy.OrCheck(rules)
self.assertEqual(check('target', 'cred', None), True)
- rules[0].assert_called_once_with('target', 'cred')
- rules[1].assert_called_once_with('target', 'cred')
+ rules[0].assert_called_once_with('target', 'cred', None)
+ rules[1].assert_called_once_with('target', 'cred', None)
class ParseCheckTestCase(utils.BaseTestCase):
diff --git a/tests/unit/test_processutils.py b/tests/unit/test_processutils.py
index 7c6e11c..0f9289c 100644
--- a/tests/unit/test_processutils.py
+++ b/tests/unit/test_processutils.py
@@ -24,10 +24,10 @@ import tempfile
import six
from openstack.common import processutils
-from tests import utils
+from openstack.common import test
-class UtilsTest(utils.BaseTestCase):
+class UtilsTest(test.BaseTestCase):
# NOTE(jkoelker) Moar tests from nova need to be ported. But they
# need to be mock'd out. Currently they requre actually
# running code.
@@ -37,27 +37,27 @@ class UtilsTest(utils.BaseTestCase):
hozer=True)
-class ProcessExecutionErrorTest(utils.BaseTestCase):
+class ProcessExecutionErrorTest(test.BaseTestCase):
def test_defaults(self):
err = processutils.ProcessExecutionError()
- self.assertTrue('None\n' in err.message)
- self.assertTrue('code: -\n' in err.message)
+ self.assertTrue('None\n' in unicode(err))
+ self.assertTrue('code: -\n' in unicode(err))
def test_with_description(self):
description = 'The Narwhal Bacons at Midnight'
err = processutils.ProcessExecutionError(description=description)
- self.assertTrue(description in err.message)
+ self.assertTrue(description in unicode(err))
def test_with_exit_code(self):
exit_code = 0
err = processutils.ProcessExecutionError(exit_code=exit_code)
- self.assertTrue(str(exit_code) in err.message)
+ self.assertTrue(str(exit_code) in unicode(err))
def test_with_cmd(self):
cmd = 'telinit'
err = processutils.ProcessExecutionError(cmd=cmd)
- self.assertTrue(cmd in err.message)
+ self.assertTrue(cmd in unicode(err))
def test_with_stdout(self):
stdout = """
@@ -80,13 +80,13 @@ class ProcessExecutionErrorTest(utils.BaseTestCase):
the Wielder of Wonder, with world's renown.
""".strip()
err = processutils.ProcessExecutionError(stdout=stdout)
- print(err.message)
- self.assertTrue('people-kings' in err.message)
+ print(unicode(err))
+ self.assertTrue('people-kings' in unicode(err))
def test_with_stderr(self):
stderr = 'Cottonian library'
err = processutils.ProcessExecutionError(stderr=stderr)
- self.assertTrue(stderr in str(err.message))
+ self.assertTrue(stderr in unicode(err))
def test_retry_on_failure(self):
fd, tmpfilename = tempfile.mkstemp()
@@ -127,8 +127,7 @@ exit 1
'always get passed '
'correctly')
runs = int(runs.strip())
- self.assertEquals(runs, 10,
- 'Ran %d times instead of 10.' % (runs,))
+ self.assertEqual(runs, 10, 'Ran %d times instead of 10.' % (runs,))
finally:
os.unlink(tmpfilename)
os.unlink(tmpfilename2)
@@ -181,7 +180,7 @@ def fake_execute_raises(*cmd, **kwargs):
'command'])
-class TryCmdTestCase(utils.BaseTestCase):
+class TryCmdTestCase(test.BaseTestCase):
def test_keep_warnings(self):
self.useFixture(fixtures.MonkeyPatch(
'openstack.common.processutils.execute', fake_execute))
@@ -231,7 +230,7 @@ class FakeSshConnection(object):
six.StringIO('stderr'))
-class SshExecuteTestCase(utils.BaseTestCase):
+class SshExecuteTestCase(test.BaseTestCase):
def test_invalid_addl_env(self):
self.assertRaises(processutils.InvalidArgumentError,
processutils.ssh_execute,
diff --git a/tests/unit/test_quota.py b/tests/unit/test_quota.py
new file mode 100644
index 0000000..8c26369
--- /dev/null
+++ b/tests/unit/test_quota.py
@@ -0,0 +1,441 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import datetime
+import mock
+from openstack.common import quota
+from oslo.config import cfg
+from tests import utils
+
+CONF = cfg.CONF
+
+
+class FakeContext(object):
+ project_id = 'p1'
+ user_id = 'u1'
+ quota_class = 'QuotaClass_'
+
+ def elevated(self):
+ return self
+
+
+class ExceptionTestCase(utils.BaseTestCase):
+
+ def _get_raised_exception(self, exception, *args, **kwargs):
+ try:
+ raise exception(*args, **kwargs)
+ except Exception as e:
+ return e
+
+ def test_quota_exception_format(self):
+
+ class TestException(quota.QuotaException):
+ msg_fmt = "Test format %(string)s"
+
+ e = self._get_raised_exception(TestException)
+ self.assertEqual(e.message, e.msg_fmt)
+
+ e = self._get_raised_exception(TestException, number=42)
+ self.assertEqual(e.message, e.msg_fmt)
+
+ e = self._get_raised_exception(TestException, string="test")
+ self.assertEqual(e.message, e.msg_fmt % {"string": "test"})
+
+
+class DbQuotaDriverTestCase(utils.BaseTestCase):
+
+ def setUp(self):
+ self.sample_resources = {'r1': quota.BaseResource('r1'),
+ 'r2': quota.BaseResource('r2')}
+
+ dbapi = mock.Mock()
+ dbapi.quota_usage_get_all_by_project_and_user = mock.Mock(
+ return_value={'project_id': 'p1', 'user_id': 'u1',
+ 'r1': {'reserved': 1, 'in_use': 2},
+ 'r2': {'reserved': 2, 'in_use': 3}})
+ dbapi.quota_get_all_by_project_and_user = mock.Mock(
+ return_value={'project_id': 'p1', 'user_id': 'u1',
+ 'r1': 5, 'r2': 6})
+ dbapi.quota_get = mock.Mock(return_value='quota_get')
+ dbapi.quota_reserve = mock.Mock(return_value='quota_reserve')
+ dbapi.quota_class_get = mock.Mock(return_value='quota_class_get')
+ dbapi.quota_class_reserve = mock.Mock(
+ return_value='quota_class_reserve')
+ dbapi.quota_class_get_default = mock.Mock(
+ return_value={'r1': 1, 'r2': 2})
+ dbapi.quota_class_get_all_by_name = mock.Mock(return_value={'r1': 9})
+ dbapi.quota_get_all_by_project = mock.Mock(
+ return_value=dict([('r%d' % i, i) for i in range(3)]))
+ dbapi.quota_get_all = mock.Mock(
+ return_value=[{'resource': 'r1', 'hard_limit': 3},
+ {'resource': 'r2', 'hard_limit': 4}])
+ dbapi.quota_usage_get_all_by_project = mock.Mock(
+ return_value=dict([('r%d' % i, {'in_use': i, 'reserved': i + 1})
+ for i in range(3)]))
+ self.driver = quota.DbQuotaDriver(dbapi)
+ self.ctxt = FakeContext()
+ return super(DbQuotaDriverTestCase, self).setUp()
+
+ def test_get_by_project(self):
+ args = ['p1', 'resource']
+ self.assertEqual('quota_get',
+ self.driver.get_by_project(self.ctxt, *args))
+ self.driver.db.quota_get.assert_called_once_with(self.ctxt, *args)
+
+ def test_get_by_project_and_user(self):
+ args = ['p1', 'u1', 'resource']
+ self.assertEqual('quota_get',
+ self.driver.get_by_project_and_user(self.ctxt, *args))
+ self.driver.db.quota_get.assert_called_once_with(self.ctxt, *args)
+
+ def test_get_by_class(self):
+ args = ['class', 'resource']
+ self.assertEqual('quota_class_get',
+ self.driver.get_by_class(self.ctxt, *args))
+ self.driver.db.quota_class_get.assert_called_once_with(self.ctxt,
+ *args)
+
+ def test_get_defaults(self):
+ defaults = self.driver.get_defaults(self.ctxt, self.sample_resources)
+ self.assertEqual(defaults, {'r1': 1, 'r2': 2})
+ self.sample_resources.pop('r1')
+ defaults = self.driver.get_defaults(self.ctxt, self.sample_resources)
+ self.assertEqual(defaults, {'r2': 2})
+
+ def test_get_class_quotas(self):
+ quotas = self.driver.get_class_quotas(self.ctxt,
+ self.sample_resources,
+ 'ClassName')
+ self.assertEqual(quotas, {'r1': 9, 'r2': 2})
+
+ def test_get_user_quotas(self):
+ actual = self.driver.get_user_quotas(
+ self.ctxt, self.sample_resources.copy(), 'p1', 'u1')
+ expected = {'r1': {'in_use': 2, 'limit': 5, 'reserved': 1},
+ 'r2': {'in_use': 3, 'limit': 6, 'reserved': 2}}
+ self.assertEqual(actual, expected)
+
+ def test_get_settable_quotas(self):
+ actual = self.driver.get_settable_quotas(self.ctxt,
+ self.sample_resources, 'p1')
+ expected = {'r1': {'maximum': -1, 'minimum': 3},
+ 'r2': {'maximum': -1, 'minimum': 5}}
+ self.assertEqual(actual, expected)
+
+ def test_get_settable_quotas_with_user_id(self):
+ actual = self.driver.get_settable_quotas(
+ self.ctxt, self.sample_resources, 'p1', user_id='u1')
+ expected = {'r1': {'maximum': 3, 'minimum': 3},
+ 'r2': {'maximum': 4, 'minimum': 5}}
+ self.assertEqual(actual, expected)
+
+ def test_get_project_quotas(self):
+ self.ctxt.quota_class = 'ClassName'
+ expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2},
+ 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}}
+ quotas = self.driver.get_project_quotas(self.ctxt,
+ self.sample_resources, 'p1')
+ self.assertEqual(quotas, expected)
+
+ def test_get_project_quotas_project_id_differs(self):
+ self.ctxt.project_id = 'p2'
+ expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2},
+ 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}}
+ quotas = self.driver.get_project_quotas(self.ctxt,
+ self.sample_resources, 'p1')
+ self.assertEqual(quotas, expected)
+
+ def test_get_project_quotas_omit_default_quota_class(self):
+ self.sample_resources['r3'] = quota.BaseResource('r3')
+ quotas = self.driver.get_project_quotas(
+ self.ctxt, self.sample_resources, 'p1', defaults=False)
+ expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2},
+ 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}}
+ self.assertEqual(quotas, expected)
+
+ def test_limit_check_invalid_quota_value(self):
+ self.assertRaises(quota.InvalidQuotaValue,
+ self.driver.limit_check, self.ctxt, [], {'r1': -1})
+
+ def test_limit_check_quota_resource_unknown(self):
+ self.assertRaises(quota.QuotaResourceUnknown,
+ self.driver.limit_check,
+ self.ctxt,
+ {'r1': quota.ReservableResource('r1', 'r1')},
+ {'r1': 42})
+
+ def test_limit_check_over_quota(self):
+ self.assertRaises(quota.OverQuota,
+ self.driver.limit_check,
+ self.ctxt,
+ {'r1': quota.BaseResource('r1')},
+ {'r1': 2})
+
+ def test_limit_check(self):
+ self.assertIsNone(self.driver.limit_check(
+ self.ctxt, {'r1': quota.BaseResource('r1')}, {'r1': 1}))
+
+ def test_quota_reserve(self):
+ now = datetime.datetime.utcnow()
+
+ class FakeTimeutils(object):
+ @staticmethod
+ def utcnow():
+ return now
+
+ self.stubs.Set(quota, "timeutils", FakeTimeutils)
+
+ expected = [self.ctxt, self.sample_resources, {}, {}, {}, None,
+ CONF.until_refresh, CONF.max_age]
+
+ # expire as None
+ self.assertEqual('quota_reserve', self.driver.reserve(
+ self.ctxt, self.sample_resources, {}, None, 'p1'))
+ expected[5] = now + datetime.timedelta(
+ seconds=CONF.reservation_expire)
+ self.driver.db.quota_reserve.assert_called_once_with(*expected,
+ project_id='p1',
+ user_id='u1')
+ self.driver.db.reset_mock()
+ # expire as seconds
+ self.assertEqual('quota_reserve', self.driver.reserve(
+ self.ctxt, self.sample_resources, {}, 42, 'p1'))
+ expected[5] = now + datetime.timedelta(seconds=42)
+ self.driver.db.quota_reserve.assert_called_once_with(*expected,
+ project_id='p1',
+ user_id='u1')
+ self.driver.db.reset_mock()
+ # expire as absolute
+ expected[5] = now + datetime.timedelta(hours=1)
+ self.assertEqual('quota_reserve', self.driver.reserve(
+ self.ctxt, self.sample_resources, {},
+ now + datetime.timedelta(hours=1), 'p1'))
+ self.driver.db.quota_reserve.assert_called_once_with(*expected,
+ project_id='p1',
+ user_id='u1')
+ self.driver.db.reset_mock()
+ # InvalidReservationExpiration
+ self.assertRaises(quota.InvalidReservationExpiration,
+ self.driver.reserve, self.ctxt,
+ self.sample_resources, {}, (), 'p1')
+ self.driver.db.reset_mock()
+ # project_id is None
+ self.assertEqual('quota_reserve', self.driver.reserve(
+ self.ctxt, self.sample_resources, {},
+ now + datetime.timedelta(hours=1)))
+ self.driver.db.quota_reserve.assert_called_once_with(*expected,
+ project_id='p1',
+ user_id='u1')
+
+ def test_commit(self):
+ self.assertIsNone(self.driver.commit(self.ctxt, 'reservations',
+ project_id='p1'))
+ self.driver.db.reservation_commit.assert_called_once_with(
+ self.ctxt, 'reservations', project_id='p1', user_id='u1')
+
+ def test_commit_project_id_none(self):
+ self.assertIsNone(self.driver.commit(self.ctxt, 'reservations'))
+ self.driver.db.reservation_commit.assert_called_once_with(
+ self.ctxt, 'reservations', project_id='p1', user_id='u1')
+
+ def test_rollback(self):
+ self.assertIsNone(self.driver.rollback(self.ctxt, 'reservations',
+ project_id='p1'))
+ self.driver.db.reservation_rollback.assert_called_once_with(
+ self.ctxt, 'reservations', project_id='p1', user_id='u1')
+
+ def test_rollback_project_id_none(self):
+ self.assertIsNone(self.driver.rollback(self.ctxt, 'reservations'))
+ self.driver.db.reservation_rollback.assert_called_once_with(
+ self.ctxt, 'reservations', project_id='p1', user_id='u1')
+
+ def test_usage_reset(self):
+ resource = self.sample_resources['r1']
+ self.assertIsNone(self.driver.usage_reset(self.ctxt, [resource]))
+ self.driver.db.quota_usage_update.assert_called_once_with(
+ self.ctxt, 'p1', 'u1', resource, in_use=-1)
+
+ def test_usage_reset_quota_usage_not_found(self):
+ resource = self.sample_resources['r1']
+ self.driver.db.quota_usage_update = mock.Mock(
+ side_effect=quota.QuotaUsageNotFound)
+ self.assertIsNone(self.driver.usage_reset(self.ctxt, [resource]))
+ self.driver.db.quota_usage_update.assert_called_once_with(
+ self.ctxt, 'p1', 'u1', resource, in_use=-1)
+
+ def test_destroy_all_by_project_and_user(self):
+ self.assertIsNone(self.driver.destroy_all_by_project_and_user(
+ self.ctxt, 'p1', 'u1'))
+ method = self.driver.db.quota_destroy_all_by_project_and_user
+ method.assert_called_once_with(self.ctxt, 'p1', 'u1')
+
+ def test_destroy_all_by_project(self):
+ self.assertIsNone(self.driver.destroy_all_by_project(self.ctxt, 'p1'))
+ self.driver.db.quota_destroy_all_by_project.assert_called_once_with(
+ self.ctxt, 'p1')
+
+ def test_expire(self):
+ self.assertIsNone(self.driver.expire(self.ctxt))
+ self.driver.db.reservation_expire.assert_called_once_with(self.ctxt)
+
+
+class BaseResourceTestCase(utils.BaseTestCase):
+
+ def setUp(self):
+ self.ctxt = FakeContext()
+ self.dbapi = mock.Mock()
+ self.dbapi.quota_get = mock.Mock(return_value='quota_get')
+ self.dbapi.quota_class_get = mock.Mock(
+ return_value='quota_class_get')
+ self.dbapi.quota_class_get_default = mock.Mock(
+ return_value={'r1': 1})
+ self.driver = quota.DbQuotaDriver(self.dbapi)
+ super(BaseResourceTestCase, self).setUp()
+
+ def test_quota(self):
+ resource = quota.BaseResource('r1')
+ self.assertEqual('quota_get', resource.quota(self.driver, self.ctxt))
+
+ def test_quota_no_project_id(self):
+ self.ctxt.project_id = None
+ resource = quota.BaseResource('r1')
+ self.assertEqual('quota_class_get',
+ resource.quota(self.driver, self.ctxt))
+
+ def test_quota_project_quota_not_found(self):
+ self.dbapi.quota_get = mock.Mock(
+ side_effect=quota.ProjectQuotaNotFound())
+ resource = quota.BaseResource('r1')
+ self.assertEqual('quota_class_get',
+ resource.quota(self.driver, self.ctxt))
+
+ def test_quota_quota_class_not_found(self):
+ self.dbapi.quota_get = mock.Mock(
+ side_effect=quota.ProjectQuotaNotFound(project_id='p1'))
+ self.dbapi.quota_class_get = mock.Mock(
+ side_effect=quota.QuotaClassNotFound(class_name='ClassName'))
+ resource = quota.BaseResource('r1')
+ self.assertEqual(1, resource.quota(self.driver, self.ctxt))
+
+
+class CountableResourceTestCase(utils.BaseTestCase):
+
+ def test_init(self):
+ resource = quota.CountableResource('r1', 42)
+ self.assertEqual('r1', resource.name)
+ self.assertEqual(42, resource.count)
+
+
+class QuotaEngineTestCase(utils.BaseTestCase):
+
+ def setUp(self):
+ self.ctxt = FakeContext()
+ self.dbapi = mock.Mock()
+ self.quota_driver = mock.Mock()
+ self.engine = quota.QuotaEngine(self.dbapi, self.quota_driver)
+ self.r1 = quota.BaseResource('r1')
+ self.r2 = quota.BaseResource('r2')
+ self.engine.register_resources([self.r1, self.r2])
+ super(QuotaEngineTestCase, self).setUp()
+
+ def assertProxyMethod(self, method, *args, **kwargs):
+ if 'retval' in kwargs:
+ retval = kwargs.pop('retval')
+ else:
+ retval = method
+ setattr(self.quota_driver, method, mock.Mock(return_value=method))
+ actual = getattr(self.engine, method)(self.ctxt, *args, **kwargs)
+ getattr(self.quota_driver, method).assert_called_once_with(self.ctxt,
+ *args,
+ **kwargs)
+ self.assertEqual(actual, retval)
+
+ def assertMethod(self, method, args, kwargs, called_args,
+ called_kwargs, retval):
+ setattr(self.quota_driver, method, mock.Mock(return_value=method))
+ actual = getattr(self.engine, method)(self.ctxt, *args, **kwargs)
+ getattr(self.quota_driver, method).assert_called_once_with(
+ self.ctxt, *called_args, **called_kwargs)
+ self.assertEqual(actual, retval)
+
+ def test_proxy_methods(self):
+ self.assertProxyMethod('get_by_project', 'p1', 'resname')
+ self.assertProxyMethod('get_by_project_and_user', 'p1', 'u1', 'res')
+ self.assertProxyMethod('get_by_class', 'quota_class', 'resname')
+ self.assertProxyMethod('get_default', 'resource')
+ self.assertProxyMethod('expire', retval=None)
+ self.assertProxyMethod('usage_reset', 'resources', retval=None)
+ self.assertProxyMethod('destroy_all_by_project', 'p1', retval=None)
+ self.assertProxyMethod('destroy_all_by_project_and_user', 'p1',
+ 'u1', retval=None)
+ self.assertProxyMethod('commit', 'reservations', project_id='p1',
+ user_id='u1', retval=None)
+ self.assertProxyMethod('rollback', 'reservations', project_id='p1',
+ user_id=None, retval=None)
+
+ self.assertMethod('get_settable_quotas', ['p1'], {'user_id': 'u1'},
+ [self.engine.resources, 'p1'], {'user_id': 'u1'},
+ 'get_settable_quotas')
+ self.assertMethod('get_defaults', [], {},
+ [self.engine.resources], {}, 'get_defaults')
+ self.assertMethod('get_project_quotas', ['p1', 'quotaclass'],
+ {'defaults': 'defaults', 'usages': 'usages'},
+ [self.engine.resources, 'p1'],
+ {'quota_class': 'quotaclass', 'defaults': 'defaults',
+ 'usages': 'usages', 'remains': False},
+ 'get_project_quotas')
+ self.assertMethod('reserve', [],
+ {'expire': 'expire', 'project_id': 'p1',
+ 'user_id': 'u1', 'deltas': 'd1'},
+ [self.engine.resources, {'deltas': 'd1'}],
+ {'expire': 'expire',
+ 'project_id': 'p1', 'user_id': 'u1'}, 'reserve')
+ self.assertMethod('get_class_quotas',
+ ['quota_class'], {'defaults': 'defaults'},
+ [self.engine.resources, 'quota_class'],
+ {'defaults': 'defaults'}, 'get_class_quotas')
+ self.assertMethod('get_user_quotas', ['project_id', 'user_id'],
+ {'quota_class': 'qc', 'defaults': 'de',
+ 'usages': 'us'},
+ [self.engine.resources, 'project_id', 'user_id'],
+ {'quota_class': 'qc', 'defaults': 'de',
+ 'usages': 'us'},
+ 'get_user_quotas')
+ self.assertMethod('limit_check',
+ [], {'project_id': 'p1', 'user_id': 'u1',
+ 'val1': 'val1'},
+ [self.engine.resources, {'val1': 'val1'}],
+ {'project_id': 'p1', 'user_id': 'u1'}, 'limit_check')
+
+ def test_resource_names(self):
+ self.assertEqual(['r1', 'r2'], self.engine.resource_names)
+
+ def test_contains(self):
+ self.assertTrue(self.r1.name in self.engine)
+ self.assertTrue(self.r2.name in self.engine)
+ self.assertFalse('r3' in self.engine)
+
+ def test_count(self):
+ count = mock.Mock(return_value=42)
+ r = quota.CountableResource('r1', count)
+ self.engine.register_resource(r)
+ actual = self.engine.count(self.ctxt, 'r1')
+ self.assertEqual(42, actual)
+ count.assert_called_once_with(self.ctxt)
+ self.assertRaises(quota.QuotaResourceUnknown,
+ self.engine.count, self.ctxt, 'r2')
+
+ def test_init(self):
+ engine = quota.QuotaEngine(self.dbapi)
+ self.assertIsInstance(engine._driver, quota.DbQuotaDriver)
diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py
index 4f742ce..c7d18f6 100644
--- a/tests/unit/test_service.py
+++ b/tests/unit/test_service.py
@@ -67,7 +67,41 @@ class ServiceWithTimer(service.Service):
self.timer_fired = self.timer_fired + 1
-class ServiceLauncherTest(utils.BaseTestCase):
+class ServiceTestBase(utils.BaseTestCase):
+ """A base class for ServiceLauncherTest and ServiceRestartTest."""
+
+ def _wait(self, cond, timeout):
+ start = time.time()
+ while not cond():
+ if time.time() - start > timeout:
+ break
+ time.sleep(.1)
+
+ def setUp(self):
+ super(ServiceTestBase, self).setUp()
+ # FIXME(markmc): Ugly hack to workaround bug #1073732
+ CONF.unregister_opts(notifier_api.notifier_opts)
+ # NOTE(markmc): ConfigOpts.log_opt_values() uses CONF.config-file
+ CONF(args=[], default_config_files=[])
+ self.addCleanup(CONF.reset)
+ self.addCleanup(CONF.register_opts, notifier_api.notifier_opts)
+ self.addCleanup(self._reap_pid)
+
+ def _reap_pid(self):
+ if self.pid:
+ # Make sure all processes are stopped
+ os.kill(self.pid, signal.SIGTERM)
+
+ # Make sure we reap our test process
+ self._reap_test()
+
+ def _reap_test(self):
+ pid, status = os.waitpid(self.pid, 0)
+ self.pid = None
+ return status
+
+
+class ServiceLauncherTest(ServiceTestBase):
"""Originally from nova/tests/integrated/test_multiprocess_api.py."""
def _spawn(self):
@@ -111,38 +145,6 @@ class ServiceLauncherTest(utils.BaseTestCase):
self.assertEqual(len(workers), self.workers)
return workers
- def _wait(self, cond, timeout):
- start = time.time()
- while True:
- if cond():
- break
- if time.time() - start > timeout:
- break
- time.sleep(.1)
-
- def setUp(self):
- super(ServiceLauncherTest, self).setUp()
- # FIXME(markmc): Ugly hack to workaround bug #1073732
- CONF.unregister_opts(notifier_api.notifier_opts)
- # NOTE(markmc): ConfigOpts.log_opt_values() uses CONF.config-file
- CONF(args=[], default_config_files=[])
- self.addCleanup(CONF.reset)
- self.addCleanup(CONF.register_opts, notifier_api.notifier_opts)
- self.addCleanup(self._reap_pid)
-
- def _reap_pid(self):
- if self.pid:
- # Make sure all processes are stopped
- os.kill(self.pid, signal.SIGTERM)
-
- # Make sure we reap our test process
- self._reap_test()
-
- def _reap_test(self):
- pid, status = os.waitpid(self.pid, 0)
- self.pid = None
- return status
-
def _get_workers(self):
f = os.popen('ps ax -o pid,ppid,command')
# Skip ps header
@@ -195,6 +197,89 @@ class ServiceLauncherTest(utils.BaseTestCase):
self.assertTrue(os.WIFEXITED(status))
self.assertEqual(os.WEXITSTATUS(status), 0)
+ def test_child_signal_sighup(self):
+ start_workers = self._spawn()
+
+ os.kill(start_workers[0], signal.SIGHUP)
+ # Wait at most 5 seconds to respawn a worker
+ cond = lambda: start_workers == self._get_workers()
+ timeout = 5
+ self._wait(cond, timeout)
+
+ # Make sure worker pids match
+ end_workers = self._get_workers()
+ LOG.info('workers: %r' % end_workers)
+ self.assertEqual(start_workers, end_workers)
+
+ def test_parent_signal_sighup(self):
+ start_workers = self._spawn()
+
+ os.kill(self.pid, signal.SIGHUP)
+ # Wait at most 5 seconds to respawn a worker
+ cond = lambda: start_workers == self._get_workers()
+ timeout = 5
+ self._wait(cond, timeout)
+
+ # Make sure worker pids match
+ end_workers = self._get_workers()
+ LOG.info('workers: %r' % end_workers)
+ self.assertEqual(start_workers, end_workers)
+
+
+class ServiceRestartTest(ServiceTestBase):
+
+ def _check_process_alive(self):
+ f = os.popen('ps ax -o pid,stat,cmd')
+ f.readline()
+ pid_stat = [tuple(p for p in line.strip().split()[:2])
+ for line in f.readlines()]
+ for p, stat in pid_stat:
+ if int(p) == self.pid:
+ return stat not in ['Z', 'T', 'Z+']
+ return False
+
+ def _spawn_service(self):
+ pid = os.fork()
+ status = 0
+ if pid == 0:
+ try:
+ serv = ServiceWithTimer()
+ launcher = service.ServiceLauncher()
+ launcher.launch_service(serv)
+ launcher.wait()
+ except SystemExit as exc:
+ status = exc.code
+ os._exit(status)
+ self.pid = pid
+
+ def test_service_restart(self):
+ self._spawn_service()
+
+ cond = self._check_process_alive
+ timeout = 5
+ self._wait(cond, timeout)
+
+ ret = self._check_process_alive()
+ self.assertTrue(ret)
+
+ os.kill(self.pid, signal.SIGHUP)
+ self._wait(cond, timeout)
+
+ ret_restart = self._check_process_alive()
+ self.assertTrue(ret_restart)
+
+ def test_terminate_sigterm(self):
+ self._spawn_service()
+ cond = self._check_process_alive
+ timeout = 5
+ self._wait(cond, timeout)
+
+ os.kill(self.pid, signal.SIGTERM)
+
+ status = self._reap_test()
+ self.assertTrue(os.WIFEXITED(status))
+ self.assertEqual(os.WEXITSTATUS(status), 0)
+
class _Service(service.Service):
def __init__(self):
diff --git a/tests/unit/test_sslutils.py b/tests/unit/test_sslutils.py
index 4c0646e..2095d3c 100644
--- a/tests/unit/test_sslutils.py
+++ b/tests/unit/test_sslutils.py
@@ -17,39 +17,38 @@
import ssl
from openstack.common import sslutils
-from tests import utils
+from openstack.common import test
-class SSLUtilsTest(utils.BaseTestCase):
+class SSLUtilsTest(test.BaseTestCase):
def test_valid_versions(self):
- self.assertEquals(sslutils.validate_ssl_version("SSLv3"),
- ssl.PROTOCOL_SSLv3)
- self.assertEquals(sslutils.validate_ssl_version("SSLv23"),
- ssl.PROTOCOL_SSLv23)
- self.assertEquals(sslutils.validate_ssl_version("TLSv1"),
- ssl.PROTOCOL_TLSv1)
+ self.assertEqual(sslutils.validate_ssl_version("SSLv3"),
+ ssl.PROTOCOL_SSLv3)
+ self.assertEqual(sslutils.validate_ssl_version("SSLv23"),
+ ssl.PROTOCOL_SSLv23)
+ self.assertEqual(sslutils.validate_ssl_version("TLSv1"),
+ ssl.PROTOCOL_TLSv1)
try:
protocol = ssl.PROTOCOL_SSLv2
except AttributeError:
pass
else:
- self.assertEquals(sslutils.validate_ssl_version("SSLv2"),
- protocol)
+ self.assertEqual(sslutils.validate_ssl_version("SSLv2"), protocol)
def test_lowercase_valid_versions(self):
- self.assertEquals(sslutils.validate_ssl_version("sslv3"),
- ssl.PROTOCOL_SSLv3)
- self.assertEquals(sslutils.validate_ssl_version("sslv23"),
- ssl.PROTOCOL_SSLv23)
- self.assertEquals(sslutils.validate_ssl_version("tlsv1"),
- ssl.PROTOCOL_TLSv1)
+ self.assertEqual(sslutils.validate_ssl_version("sslv3"),
+ ssl.PROTOCOL_SSLv3)
+ self.assertEqual(sslutils.validate_ssl_version("sslv23"),
+ ssl.PROTOCOL_SSLv23)
+ self.assertEqual(sslutils.validate_ssl_version("tlsv1"),
+ ssl.PROTOCOL_TLSv1)
try:
protocol = ssl.PROTOCOL_SSLv2
except AttributeError:
pass
else:
- self.assertEquals(sslutils.validate_ssl_version("sslv2"),
- protocol)
+ self.assertEqual(sslutils.validate_ssl_version("sslv2"),
+ protocol)
def test_invalid_version(self):
self.assertRaises(RuntimeError,
diff --git a/tests/unit/test_strutils.py b/tests/unit/test_strutils.py
index a8d8462..22aafd9 100644
--- a/tests/unit/test_strutils.py
+++ b/tests/unit/test_strutils.py
@@ -20,10 +20,10 @@ import mock
import six
from openstack.common import strutils
-from tests import utils
+from openstack.common import test
-class StrUtilsTest(utils.BaseTestCase):
+class StrUtilsTest(test.BaseTestCase):
def test_bool_bool_from_string(self):
self.assertTrue(strutils.bool_from_string(True))
@@ -188,11 +188,11 @@ class StrUtilsTest(utils.BaseTestCase):
}
for (in_value, expected_value) in working_examples.items():
b_value = strutils.to_bytes(in_value)
- self.assertEquals(expected_value, b_value)
+ self.assertEqual(expected_value, b_value)
if in_value:
in_value = "-" + in_value
b_value = strutils.to_bytes(in_value)
- self.assertEquals(expected_value * -1, b_value)
+ self.assertEqual(expected_value * -1, b_value)
breaking_examples = [
'junk1KB', '1023BBBB',
]
diff --git a/tests/unit/test_threadgroup.py b/tests/unit/test_threadgroup.py
index 5af6653..2273800 100644
--- a/tests/unit/test_threadgroup.py
+++ b/tests/unit/test_threadgroup.py
@@ -20,13 +20,13 @@ Unit Tests for thread groups
"""
from openstack.common import log as logging
+from openstack.common import test
from openstack.common import threadgroup
-from tests import utils
LOG = logging.getLogger(__name__)
-class ThreadGroupTestCase(utils.BaseTestCase):
+class ThreadGroupTestCase(test.BaseTestCase):
"""Test cases for thread group."""
def setUp(self):
super(ThreadGroupTestCase, self).setUp()
diff --git a/tests/unit/test_timeutils.py b/tests/unit/test_timeutils.py
index bfab278..792c0aa 100644
--- a/tests/unit/test_timeutils.py
+++ b/tests/unit/test_timeutils.py
@@ -21,11 +21,11 @@ import datetime
import iso8601
import mock
+from openstack.common import test
from openstack.common import timeutils
-from tests import utils
-class TimeUtilsTest(utils.BaseTestCase):
+class TimeUtilsTest(test.BaseTestCase):
def setUp(self):
super(TimeUtilsTest, self).setUp()
@@ -181,16 +181,16 @@ class TimeUtilsTest(utils.BaseTestCase):
self.assertTrue(timeutils.is_soon(expires, 0))
-class TestIso8601Time(utils.BaseTestCase):
+class TestIso8601Time(test.BaseTestCase):
def _instaneous(self, timestamp, yr, mon, day, hr, min, sec, micro):
- self.assertEquals(timestamp.year, yr)
- self.assertEquals(timestamp.month, mon)
- self.assertEquals(timestamp.day, day)
- self.assertEquals(timestamp.hour, hr)
- self.assertEquals(timestamp.minute, min)
- self.assertEquals(timestamp.second, sec)
- self.assertEquals(timestamp.microsecond, micro)
+ self.assertEqual(timestamp.year, yr)
+ self.assertEqual(timestamp.month, mon)
+ self.assertEqual(timestamp.day, day)
+ self.assertEqual(timestamp.hour, hr)
+ self.assertEqual(timestamp.minute, min)
+ self.assertEqual(timestamp.second, sec)
+ self.assertEqual(timestamp.microsecond, micro)
def _do_test(self, str, yr, mon, day, hr, min, sec, micro, shift):
DAY_SECONDS = 24 * 60 * 60
@@ -246,26 +246,26 @@ class TestIso8601Time(utils.BaseTestCase):
def test_zulu_roundtrip(self):
str = '2012-02-14T20:53:07Z'
zulu = timeutils.parse_isotime(str)
- self.assertEquals(zulu.tzinfo, iso8601.iso8601.UTC)
- self.assertEquals(timeutils.isotime(zulu), str)
+ self.assertEqual(zulu.tzinfo, iso8601.iso8601.UTC)
+ self.assertEqual(timeutils.isotime(zulu), str)
def test_east_roundtrip(self):
str = '2012-02-14T20:53:07-07:00'
east = timeutils.parse_isotime(str)
- self.assertEquals(east.tzinfo.tzname(None), '-07:00')
- self.assertEquals(timeutils.isotime(east), str)
+ self.assertEqual(east.tzinfo.tzname(None), '-07:00')
+ self.assertEqual(timeutils.isotime(east), str)
def test_west_roundtrip(self):
str = '2012-02-14T20:53:07+11:30'
west = timeutils.parse_isotime(str)
- self.assertEquals(west.tzinfo.tzname(None), '+11:30')
- self.assertEquals(timeutils.isotime(west), str)
+ self.assertEqual(west.tzinfo.tzname(None), '+11:30')
+ self.assertEqual(timeutils.isotime(west), str)
def test_now_roundtrip(self):
str = timeutils.isotime()
now = timeutils.parse_isotime(str)
- self.assertEquals(now.tzinfo, iso8601.iso8601.UTC)
- self.assertEquals(timeutils.isotime(now), str)
+ self.assertEqual(now.tzinfo, iso8601.iso8601.UTC)
+ self.assertEqual(timeutils.isotime(now), str)
def test_zulu_normalize(self):
str = '2012-02-14T20:53:07Z'
diff --git a/tests/unit/test_uuidutils.py b/tests/unit/test_uuidutils.py
index e9348e2..4af5a8e 100644
--- a/tests/unit/test_uuidutils.py
+++ b/tests/unit/test_uuidutils.py
@@ -17,11 +17,11 @@
import uuid
+from openstack.common import test
from openstack.common import uuidutils
-from tests import utils
-class UUIDUtilsTest(utils.BaseTestCase):
+class UUIDUtilsTest(test.BaseTestCase):
def test_generate_uuid(self):
uuid_string = uuidutils.generate_uuid()
diff --git a/tests/unit/test_xmlutils.py b/tests/unit/test_xmlutils.py
index 5d2bd05..a5bd29f 100644
--- a/tests/unit/test_xmlutils.py
+++ b/tests/unit/test_xmlutils.py
@@ -16,11 +16,11 @@
from xml.dom import minidom
+from openstack.common import test
from openstack.common import xmlutils
-from tests import utils
-class XMLUtilsTestCase(utils.BaseTestCase):
+class XMLUtilsTestCase(test.BaseTestCase):
def test_safe_parse_xml(self):
normal_body = ("""
@@ -55,7 +55,7 @@ class XMLUtilsTestCase(utils.BaseTestCase):
killer_body())
-class SafeParserTestCase(utils.BaseTestCase):
+class SafeParserTestCase(test.BaseTestCase):
def test_external_dtd(self):
xml_string = ("""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
diff --git a/tests/utils.py b/tests/utils.py
index e93c278..b0770a9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,7 +24,6 @@ import fixtures
from oslo.config import cfg
import testtools
-from openstack.common import exception
from openstack.common.fixture import moxstubout
@@ -38,8 +37,24 @@ class BaseTestCase(testtools.TestCase):
self.conf = conf
self.addCleanup(self.conf.reset)
self.useFixture(fixtures.FakeLogger('openstack.common'))
- self.useFixture(fixtures.Timeout(30, True))
- self.stubs.Set(exception, '_FATAL_EXCEPTION_FORMAT_ERRORS', True)
+
+ test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0)
+ try:
+ test_timeout = int(test_timeout)
+ except ValueError:
+ # If timeout value is invalid do not set a timeout.
+ test_timeout = 0
+ if test_timeout > 0:
+ self.useFixture(fixtures.Timeout(test_timeout, gentle=True))
+ if (os.environ.get('OS_STDOUT_CAPTURE') == 'True' or
+ os.environ.get('OS_STDOUT_CAPTURE') == '1'):
+ stdout = self.useFixture(fixtures.StringStream('stdout')).stream
+ self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout))
+ if (os.environ.get('OS_STDERR_CAPTURE') == 'True' or
+ os.environ.get('OS_STDERR_CAPTURE') == '1'):
+ stderr = self.useFixture(fixtures.StringStream('stderr')).stream
+ self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr))
+
self.useFixture(fixtures.NestedTempfile())
self.tempdirs = []
diff --git a/tools/colorizer.py b/tools/colorizer.py
new file mode 100755
index 0000000..13364ba
--- /dev/null
+++ b/tools/colorizer.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2013, Nebula, Inc.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# Colorizer Code is borrowed from Twisted:
+# Copyright (c) 2001-2010 Twisted Matrix Laboratories.
+#
+# Permission is hereby granted, free of charge, to any person obtaining
+# a copy of this software and associated documentation files (the
+# "Software"), to deal in the Software without restriction, including
+# without limitation the rights to use, copy, modify, merge, publish,
+# distribute, sublicense, and/or sell copies of the Software, and to
+# permit persons to whom the Software is furnished to do so, subject to
+# the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+"""Display a subunit stream through a colorized unittest test runner."""
+
+import heapq
+import subunit
+import sys
+import unittest
+
+import testtools
+
+
+class _AnsiColorizer(object):
+ """Colorizer allows callers to write text in a particular color.
+
+ A colorizer is an object that loosely wraps around a stream, allowing
+ callers to write text to the stream in a particular color.
+
+ Colorizer classes must implement C{supported()} and C{write(text, color)}.
+ """
+ _colors = dict(black=30, red=31, green=32, yellow=33,
+ blue=34, magenta=35, cyan=36, white=37)
+
+ def __init__(self, stream):
+ self.stream = stream
+
+ def supported(cls, stream=sys.stdout):
+ """Check is the current platform supports coloring terminal output.
+
+ A class method that returns True if the current platform supports
+ coloring terminal output using this method. Returns False otherwise.
+ """
+ if not stream.isatty():
+ return False # auto color only on TTYs
+ try:
+ import curses
+ except ImportError:
+ return False
+ else:
+ try:
+ try:
+ return curses.tigetnum("colors") > 2
+ except curses.error:
+ curses.setupterm()
+ return curses.tigetnum("colors") > 2
+ except Exception:
+ # guess false in case of error
+ return False
+ supported = classmethod(supported)
+
+ def write(self, text, color):
+ """Write the given text to the stream in the given color.
+
+ @param text: Text to be written to the stream.
+
+ @param color: A string label for a color. e.g. 'red', 'white'.
+ """
+ color = self._colors[color]
+ self.stream.write('\x1b[%s;1m%s\x1b[0m' % (color, text))
+
+
+class _Win32Colorizer(object):
+ """See _AnsiColorizer docstring."""
+ def __init__(self, stream):
+ import win32console
+ red, green, blue, bold = (win32console.FOREGROUND_RED,
+ win32console.FOREGROUND_GREEN,
+ win32console.FOREGROUND_BLUE,
+ win32console.FOREGROUND_INTENSITY)
+ self.stream = stream
+ self.screenBuffer = win32console.GetStdHandle(
+ win32console.STD_OUT_HANDLE)
+ self._colors = {
+ 'normal': red | green | blue,
+ 'red': red | bold,
+ 'green': green | bold,
+ 'blue': blue | bold,
+ 'yellow': red | green | bold,
+ 'magenta': red | blue | bold,
+ 'cyan': green | blue | bold,
+ 'white': red | green | blue | bold,
+ }
+
+ def supported(cls, stream=sys.stdout):
+ try:
+ import win32console
+ screenBuffer = win32console.GetStdHandle(
+ win32console.STD_OUT_HANDLE)
+ except ImportError:
+ return False
+ import pywintypes
+ try:
+ screenBuffer.SetConsoleTextAttribute(
+ win32console.FOREGROUND_RED |
+ win32console.FOREGROUND_GREEN |
+ win32console.FOREGROUND_BLUE)
+ except pywintypes.error:
+ return False
+ else:
+ return True
+ supported = classmethod(supported)
+
+ def write(self, text, color):
+ color = self._colors[color]
+ self.screenBuffer.SetConsoleTextAttribute(color)
+ self.stream.write(text)
+ self.screenBuffer.SetConsoleTextAttribute(self._colors['normal'])
+
+
+class _NullColorizer(object):
+ """See _AnsiColorizer docstring."""
+ def __init__(self, stream):
+ self.stream = stream
+
+ def supported(cls, stream=sys.stdout):
+ return True
+ supported = classmethod(supported)
+
+ def write(self, text, color):
+ self.stream.write(text)
+
+
+def get_elapsed_time_color(elapsed_time):
+ if elapsed_time > 1.0:
+ return 'red'
+ elif elapsed_time > 0.25:
+ return 'yellow'
+ else:
+ return 'green'
+
+
+class OpenStackTestResult(testtools.TestResult):
+ def __init__(self, stream, descriptions, verbosity):
+ super(OpenStackTestResult, self).__init__()
+ self.stream = stream
+ self.showAll = verbosity > 1
+ self.num_slow_tests = 10
+ self.slow_tests = [] # this is a fixed-sized heap
+ self.colorizer = None
+ # NOTE(vish): reset stdout for the terminal check
+ stdout = sys.stdout
+ sys.stdout = sys.__stdout__
+ for colorizer in [_Win32Colorizer, _AnsiColorizer, _NullColorizer]:
+ if colorizer.supported():
+ self.colorizer = colorizer(self.stream)
+ break
+ sys.stdout = stdout
+ self.start_time = None
+ self.last_time = {}
+ self.results = {}
+ self.last_written = None
+
+ def _writeElapsedTime(self, elapsed):
+ color = get_elapsed_time_color(elapsed)
+ self.colorizer.write(" %.2f" % elapsed, color)
+
+ def _addResult(self, test, *args):
+ try:
+ name = test.id()
+ except AttributeError:
+ name = 'Unknown.unknown'
+ test_class, test_name = name.rsplit('.', 1)
+
+ elapsed = (self._now() - self.start_time).total_seconds()
+ item = (elapsed, test_class, test_name)
+ if len(self.slow_tests) >= self.num_slow_tests:
+ heapq.heappushpop(self.slow_tests, item)
+ else:
+ heapq.heappush(self.slow_tests, item)
+
+ self.results.setdefault(test_class, [])
+ self.results[test_class].append((test_name, elapsed) + args)
+ self.last_time[test_class] = self._now()
+ self.writeTests()
+
+ def _writeResult(self, test_name, elapsed, long_result, color,
+ short_result, success):
+ if self.showAll:
+ self.stream.write(' %s' % str(test_name).ljust(66))
+ self.colorizer.write(long_result, color)
+ if success:
+ self._writeElapsedTime(elapsed)
+ self.stream.writeln()
+ else:
+ self.colorizer.write(short_result, color)
+
+ def addSuccess(self, test):
+ super(OpenStackTestResult, self).addSuccess(test)
+ self._addResult(test, 'OK', 'green', '.', True)
+
+ def addFailure(self, test, err):
+ if test.id() == 'process-returncode':
+ return
+ super(OpenStackTestResult, self).addFailure(test, err)
+ self._addResult(test, 'FAIL', 'red', 'F', False)
+
+ def addError(self, test, err):
+ super(OpenStackTestResult, self).addError(test, err)
+ self._addResult(test, 'ERROR', 'red', 'E', False)
+
+ def addSkip(self, test, reason=None, details=None):
+ super(OpenStackTestResult, self).addSkip(test, reason, details)
+ self._addResult(test, 'SKIP', 'blue', 'S', True)
+
+ def startTest(self, test):
+ self.start_time = self._now()
+ super(OpenStackTestResult, self).startTest(test)
+
+ def writeTestCase(self, cls):
+ if not self.results.get(cls):
+ return
+ if cls != self.last_written:
+ self.colorizer.write(cls, 'white')
+ self.stream.writeln()
+ for result in self.results[cls]:
+ self._writeResult(*result)
+ del self.results[cls]
+ self.stream.flush()
+ self.last_written = cls
+
+ def writeTests(self):
+ time = self.last_time.get(self.last_written, self._now())
+ if not self.last_written or (self._now() - time).total_seconds() > 2.0:
+ diff = 3.0
+ while diff > 2.0:
+ classes = self.results.keys()
+ oldest = min(classes, key=lambda x: self.last_time[x])
+ diff = (self._now() - self.last_time[oldest]).total_seconds()
+ self.writeTestCase(oldest)
+ else:
+ self.writeTestCase(self.last_written)
+
+ def done(self):
+ self.stopTestRun()
+
+ def stopTestRun(self):
+ for cls in list(self.results.iterkeys()):
+ self.writeTestCase(cls)
+ self.stream.writeln()
+ self.writeSlowTests()
+
+ def writeSlowTests(self):
+ # Pare out 'fast' tests
+ slow_tests = [item for item in self.slow_tests
+ if get_elapsed_time_color(item[0]) != 'green']
+ if slow_tests:
+ slow_total_time = sum(item[0] for item in slow_tests)
+ slow = ("Slowest %i tests took %.2f secs:"
+ % (len(slow_tests), slow_total_time))
+ self.colorizer.write(slow, 'yellow')
+ self.stream.writeln()
+ last_cls = None
+ # sort by name
+ for elapsed, cls, name in sorted(slow_tests,
+ key=lambda x: x[1] + x[2]):
+ if cls != last_cls:
+ self.colorizer.write(cls, 'white')
+ self.stream.writeln()
+ last_cls = cls
+ self.stream.write(' %s' % str(name).ljust(68))
+ self._writeElapsedTime(elapsed)
+ self.stream.writeln()
+
+ def printErrors(self):
+ if self.showAll:
+ self.stream.writeln()
+ self.printErrorList('ERROR', self.errors)
+ self.printErrorList('FAIL', self.failures)
+
+ def printErrorList(self, flavor, errors):
+ for test, err in errors:
+ self.colorizer.write("=" * 70, 'red')
+ self.stream.writeln()
+ self.colorizer.write(flavor, 'red')
+ self.stream.writeln(": %s" % test.id())
+ self.colorizer.write("-" * 70, 'red')
+ self.stream.writeln()
+ self.stream.writeln("%s" % err)
+
+
+test = subunit.ProtocolTestCase(sys.stdin, passthrough=None)
+
+if sys.version_info[0:2] <= (2, 6):
+ runner = unittest.TextTestRunner(verbosity=2)
+else:
+ runner = unittest.TextTestRunner(verbosity=2,
+ resultclass=OpenStackTestResult)
+
+if runner.run(test).wasSuccessful():
+ exit_code = 0
+else:
+ exit_code = 1
+sys.exit(exit_code)
diff --git a/tools/config/generate_sample.sh b/tools/config/generate_sample.sh
index 26f02dd..f5e5a67 100755
--- a/tools/config/generate_sample.sh
+++ b/tools/config/generate_sample.sh
@@ -64,6 +64,9 @@ FILES=$(find $BASEDIR/$PACKAGENAME -type f -name "*.py" ! -path "*/tests/*" \
export EVENTLET_NO_GREENDNS=yes
+OS_VARS=$(set | sed -n '/^OS_/s/=[^=]*$//gp' | xargs)
+[ "$OS_VARS" ] && eval "unset \$OS_VARS"
+
MODULEPATH=openstack.common.config.generator
OUTPUTFILE=$OUTPUTDIR/$PACKAGENAME.conf.sample
python -m $MODULEPATH $FILES > $OUTPUTFILE
diff --git a/tools/install_venv.py b/tools/install_venv.py
new file mode 100644
index 0000000..1d4e7e0
--- /dev/null
+++ b/tools/install_venv.py
@@ -0,0 +1,74 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Copyright 2010 OpenStack Foundation
+# Copyright 2013 IBM Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+
+import install_venv_common as install_venv # noqa
+
+
+def print_help(venv, root):
+ help = """
+ Openstack development environment setup is complete.
+
+ Openstack development uses virtualenv to track and manage Python
+ dependencies while in development and testing.
+
+ To activate the Openstack virtualenv for the extent of your current shell
+ session you can run:
+
+ $ source %s/bin/activate
+
+ Or, if you prefer, you can run commands in the virtualenv on a case by case
+ basis by running:
+
+ $ %s/tools/with_venv.sh <your command>
+
+ Also, make test will automatically use the virtualenv.
+ """
+ print(help % (venv, root))
+
+
+def main(argv):
+ root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+ if os.environ.get('tools_path'):
+ root = os.environ['tools_path']
+ venv = os.path.join(root, '.venv')
+ if os.environ.get('venv'):
+ venv = os.environ['venv']
+
+ pip_requires = os.path.join(root, 'requirements.txt')
+ test_requires = os.path.join(root, 'test-requirements.txt')
+ py_version = "python%s.%s" % (sys.version_info[0], sys.version_info[1])
+ project = 'Openstack'
+ install = install_venv.InstallVenv(root, venv, pip_requires, test_requires,
+ py_version, project)
+ options = install.parse_args(argv)
+ install.check_python_version()
+ install.check_dependencies()
+ install.create_virtualenv(no_site_packages=options.no_site_packages)
+ install.install_dependencies()
+ install.post_process()
+ print_help(venv, root)
+
+if __name__ == '__main__':
+ main(sys.argv)
diff --git a/tools/install_venv_common.py b/tools/install_venv_common.py
index f428c1e..0999e2c 100644
--- a/tools/install_venv_common.py
+++ b/tools/install_venv_common.py
@@ -114,9 +114,10 @@ class InstallVenv(object):
print('Installing dependencies with pip (this can take a while)...')
# First things first, make sure our venv has the latest pip and
- # setuptools.
- self.pip_install('pip>=1.3')
+ # setuptools and pbr
+ self.pip_install('pip>=1.4')
self.pip_install('setuptools')
+ self.pip_install('pbr')
self.pip_install('-r', self.requirements)
self.pip_install('-r', self.test_requirements)
@@ -201,12 +202,13 @@ class Fedora(Distro):
RHEL: https://bugzilla.redhat.com/958868
"""
- # Install "patch" program if it's not there
- if not self.check_pkg('patch'):
- self.die("Please install 'patch'.")
+ if os.path.exists('contrib/redhat-eventlet.patch'):
+ # Install "patch" program if it's not there
+ if not self.check_pkg('patch'):
+ self.die("Please install 'patch'.")
- # Apply the eventlet patch
- self.apply_patch(os.path.join(self.venv, 'lib', self.py_version,
- 'site-packages',
- 'eventlet/green/subprocess.py'),
- 'contrib/redhat-eventlet.patch')
+ # Apply the eventlet patch
+ self.apply_patch(os.path.join(self.venv, 'lib', self.py_version,
+ 'site-packages',
+ 'eventlet/green/subprocess.py'),
+ 'contrib/redhat-eventlet.patch')
diff --git a/tools/run_tests_common.sh b/tools/run_tests_common.sh
new file mode 100755
index 0000000..1cc4e6f
--- /dev/null
+++ b/tools/run_tests_common.sh
@@ -0,0 +1,253 @@
+#!/bin/bash
+
+set -eu
+
+function usage {
+ echo "Usage: $0 [OPTION]..."
+ echo "Run project's test suite(s)"
+ echo ""
+ echo " -V, --virtual-env Always use virtualenv. Install automatically if not present"
+ echo " -N, --no-virtual-env Don't use virtualenv. Run tests in local environment"
+ echo " -s, --no-site-packages Isolate the virtualenv from the global Python environment"
+ echo " -r, --recreate-db Recreate the test database (deprecated, as this is now the default)."
+ echo " -n, --no-recreate-db Don't recreate the test database."
+ echo " -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added."
+ echo " -u, --update Update the virtual environment with any newer package versions"
+ echo " -p, --pep8 Just run PEP8 and HACKING compliance check"
+ echo " -P, --no-pep8 Don't run static code checks"
+ echo " -c, --coverage Generate coverage report"
+ echo " -d, --debug Run tests with testtools instead of testr. This allows you to use the debugger."
+ echo " -h, --help Print this usage message"
+ echo " --hide-elapsed Don't print the elapsed time for each test along with slow test list"
+ echo " --virtual-env-path <path> Location of the virtualenv directory"
+ echo " Default: \$(pwd)"
+ echo " --virtual-env-name <name> Name of the virtualenv directory"
+ echo " Default: .venv"
+ echo " --tools-path <dir> Location of the tools directory"
+ echo " Default: \$(pwd)"
+ echo ""
+ echo "Note: with no options specified, the script will try to run the tests in a virtual environment,"
+ echo " If no virtualenv is found, the script will ask if you would like to create one. If you "
+ echo " prefer to run tests NOT in a virtual environment, simply pass the -N option."
+ exit
+}
+
+function process_options {
+ i=1
+ while [ $i -le $# ]; do
+ case "${!i}" in
+ -h|--help) usage;;
+ -V|--virtual-env) ALWAYS_VENV=1; NEVER_VENV=0;;
+ -N|--no-virtual-env) ALWAYS_VENV=0; NEVER_VENV=1;;
+ -s|--no-site-packages) NO_SITE_PACKAGES=1;;
+ -r|--recreate-db) RECREATE_DB=1;;
+ -n|--no-recreate-db) RECREATE_DB=0;;
+ -f|--force) FORCE=1;;
+ -u|--update) UPDATE=1;;
+ -p|--pep8) JUST_PEP8=1;;
+ -P|--no-pep8) NO_PEP8=1;;
+ -c|--coverage) COVERAGE=1;;
+ -d|--debug) DEBUG=1;;
+ --virtual-env-path)
+ (( i++ ))
+ VENV_PATH=${!i}
+ ;;
+ --virtual-env-name)
+ (( i++ ))
+ VENV_DIR=${!i}
+ ;;
+ --tools-path)
+ (( i++ ))
+ TOOLS_PATH=${!i}
+ ;;
+ -*) TESTOPTS="$TESTOPTS ${!i}";;
+ *) TESTRARGS="$TESTRARGS ${!i}"
+ esac
+ (( i++ ))
+ done
+}
+
+
+TOOLS_PATH=${TOOLS_PATH:-${PWD}}
+VENV_PATH=${VENV_PATH:-${PWD}}
+VENV_DIR=${VENV_NAME:-.venv}
+WITH_VENV=${TOOLS_PATH}/tools/with_venv.sh
+
+ALWAYS_VENV=0
+NEVER_VENV=0
+FORCE=0
+NO_SITE_PACKAGES=1
+INSTALLVENVOPTS=
+TESTRARGS=
+TESTOPTS=
+WRAPPER=""
+JUST_PEP8=0
+NO_PEP8=0
+COVERAGE=0
+DEBUG=0
+RECREATE_DB=1
+UPDATE=0
+
+LANG=en_US.UTF-8
+LANGUAGE=en_US:en
+LC_ALL=C
+
+process_options $@
+# Make our paths available to other scripts we call
+export VENV_PATH
+export TOOLS_PATH
+export VENV_DIR
+export VENV_NAME
+export WITH_VENV
+export VENV=${VENV_PATH}/${VENV_DIR}
+
+function init_testr {
+ if [ ! -d .testrepository ]; then
+ ${WRAPPER} testr init
+ fi
+}
+
+function run_tests {
+ # Cleanup *pyc
+ ${WRAPPER} find . -type f -name "*.pyc" -delete
+
+ if [ ${DEBUG} -eq 1 ]; then
+ if [ "${TESTOPTS}" = "" ] && [ "${TESTRARGS}" = "" ]; then
+ # Default to running all tests if specific test is not
+ # provided.
+ TESTRARGS="discover ./${TESTS_DIR}"
+ fi
+ ${WRAPPER} python -m testtools.run ${TESTOPTS} ${TESTRARGS}
+
+ # Short circuit because all of the testr and coverage stuff
+ # below does not make sense when running testtools.run for
+ # debugging purposes.
+ return $?
+ fi
+
+ if [ ${COVERAGE} -eq 1 ]; then
+ TESTRTESTS="${TESTRTESTS} --coverage"
+ else
+ TESTRTESTS="${TESTRTESTS}"
+ fi
+
+ # Just run the test suites in current environment
+ set +e
+ TESTRARGS=`echo "${TESTRARGS}" | sed -e's/^\s*\(.*\)\s*$/\1/'`
+
+ if [ ${WORKERS_COUNT} -ne 0 ]; then
+ TESTRTESTS="${TESTRTESTS} --testr-args='--concurrency=${WORKERS_COUNT} --subunit ${TESTOPTS} ${TESTRARGS}'"
+ else
+ TESTRTESTS="${TESTRTESTS} --testr-args='--subunit ${TESTOPTS} ${TESTRARGS}'"
+ fi
+
+ if [ setup.cfg -nt ${EGG_INFO_FILE} ]; then
+ ${WRAPPER} python setup.py egg_info
+ fi
+
+ echo "Running \`${WRAPPER} ${TESTRTESTS}\`"
+ if ${WRAPPER} which subunit-2to1 2>&1 > /dev/null; then
+ # subunit-2to1 is present, testr subunit stream should be in version 2
+ # format. Convert to version one before colorizing.
+ bash -c "${WRAPPER} ${TESTRTESTS} | ${WRAPPER} subunit-2to1 | ${WRAPPER} ${TOOLS_PATH}/tools/colorizer.py"
+ else
+ bash -c "${WRAPPER} ${TESTRTESTS} | ${WRAPPER} ${TOOLS_PATH}/tools/colorizer.py"
+ fi
+ RESULT=$?
+ set -e
+
+ copy_subunit_log
+
+ if [ $COVERAGE -eq 1 ]; then
+ echo "Generating coverage report in covhtml/"
+ ${WRAPPER} coverage combine
+ # Don't compute coverage for common code, which is tested elsewhere
+ # if we are not in `oslo-incubator` project
+ if [ ${OMIT_OSLO_FROM_COVERAGE} -eq 0 ]; then
+ OMIT_OSLO=""
+ else
+ OMIT_OSLO="--omit='${PROJECT_NAME}/openstack/common/*'"
+ fi
+ ${WRAPPER} coverage html --include="${PROJECT_NAME}/*" ${OMIT_OSLO} -d covhtml -i
+ fi
+
+ return ${RESULT}
+}
+
+function copy_subunit_log {
+ LOGNAME=`cat .testrepository/next-stream`
+ LOGNAME=$((${LOGNAME} - 1))
+ LOGNAME=".testrepository/${LOGNAME}"
+ cp ${LOGNAME} subunit.log
+}
+
+function run_pep8 {
+ echo "Running flake8 ..."
+ bash -c "${WRAPPER} flake8"
+}
+
+
+TESTRTESTS="python setup.py testr"
+
+if [ ${NO_SITE_PACKAGES} -eq 1 ]; then
+ INSTALLVENVOPTS="--no-site-packages"
+fi
+
+if [ ${NEVER_VENV} -eq 0 ]; then
+ # Remove the virtual environment if -f or --force used
+ if [ ${FORCE} -eq 1 ]; then
+ echo "Cleaning virtualenv..."
+ rm -rf ${VENV}
+ fi
+
+ # Update the virtual environment if -u or --update used
+ if [ ${UPDATE} -eq 1 ]; then
+ echo "Updating virtualenv..."
+ python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS}
+ fi
+
+ if [ -e ${VENV} ]; then
+ WRAPPER="${WITH_VENV}"
+ else
+ if [ ${ALWAYS_VENV} -eq 1 ]; then
+ # Automatically install the virtualenv
+ python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS}
+ WRAPPER="${WITH_VENV}"
+ else
+ echo -e "No virtual environment found...create one? (Y/n) \c"
+ read USE_VENV
+ if [ "x${USE_VENV}" = "xY" -o "x${USE_VENV}" = "x" -o "x${USE_VENV}" = "xy" ]; then
+ # Install the virtualenv and run the test suite in it
+ python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS}
+ WRAPPER=${WITH_VENV}
+ fi
+ fi
+ fi
+fi
+
+# Delete old coverage data from previous runs
+if [ ${COVERAGE} -eq 1 ]; then
+ ${WRAPPER} coverage erase
+fi
+
+if [ ${JUST_PEP8} -eq 1 ]; then
+ run_pep8
+ exit
+fi
+
+if [ ${RECREATE_DB} -eq 1 ]; then
+ rm -f tests.sqlite
+fi
+
+init_testr
+run_tests
+
+# NOTE(sirp): we only want to run pep8 when we're running the full-test suite,
+# not when we're running tests individually. To handle this, we need to
+# distinguish between options (testropts), which begin with a '-', and
+# arguments (testrargs).
+if [ -z "${TESTRARGS}" ]; then
+ if [ ${NO_PEP8} -eq 0 ]; then
+ run_pep8
+ fi
+fi
diff --git a/tools/with_venv.sh b/tools/with_venv.sh
new file mode 100755
index 0000000..7303990
--- /dev/null
+++ b/tools/with_venv.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+TOOLS_PATH=${TOOLS_PATH:-$(dirname $0)}
+VENV_PATH=${VENV_PATH:-${TOOLS_PATH}}
+VENV_DIR=${VENV_NAME:-/../.venv}
+TOOLS=${TOOLS_PATH}
+VENV=${VENV:-${VENV_PATH}/${VENV_DIR}}
+source ${VENV}/bin/activate && "$@"