diff options
107 files changed, 6564 insertions, 801 deletions
@@ -2,3 +2,4 @@ # <preferred e-mail> <other e-mail 1> # <preferred e-mail> <other e-mail 2> Zhongyue Luo <zhongyue.nah@intel.com> <lzyeval@gmail.com> +<yufang521247@gmail.com> <zhangyufang@360.cn> diff --git a/MAINTAINERS b/MAINTAINERS index e069373..a611a5c 100644 --- a/MAINTAINERS +++ b/MAINTAINERS @@ -78,6 +78,12 @@ M: S: Orphan F: cliutils.py +== compat == + +M: Chuck Short <chuck.short@canonical.com> +S: Maintained +F: py3kcompat + == context == M: @@ -109,12 +115,6 @@ M: S: Orphan F: eventlet_backdoor.py -== exception == - -M: -S: Obsolete -F: exception.py - == excutils == M: @@ -223,6 +223,12 @@ M: Michael Still <mikal@stillhq.com> S: Maintained F: processutils.py +== quota == + +M: Sergey Skripnick <sskripnick@mirantis.com> +S: Maintained +F: quota.py + == redhat-eventlet.patch == M: Mark McLoughlin <markmc@redhat.com> diff --git a/openstack/common/apiclient/auth.py b/openstack/common/apiclient/auth.py new file mode 100644 index 0000000..1744228 --- /dev/null +++ b/openstack/common/apiclient/auth.py @@ -0,0 +1,227 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# Copyright 2013 Spanish National Research Council. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import abc +import argparse +import logging +import os + +from stevedore import extension + +from openstack.common.apiclient import exceptions + + +logger = logging.getLogger(__name__) + + +_discovered_plugins = {} + + +def discover_auth_systems(): + """Discover the available auth-systems. + + This won't take into account the old style auth-systems. + """ + global _discovered_plugins + _discovered_plugins = {} + + def add_plugin(ext): + _discovered_plugins[ext.name] = ext.plugin + + ep_namespace = "openstack.common.apiclient.auth" + mgr = extension.ExtensionManager(ep_namespace) + mgr.map(add_plugin) + + +def load_auth_system_opts(parser): + """Load options needed by the available auth-systems into a parser. + + This function will try to populate the parser with options from the + available plugins. + """ + group = parser.add_argument_group("Common auth options") + BaseAuthPlugin.add_common_opts(group) + for name, auth_plugin in _discovered_plugins.iteritems(): + group = parser.add_argument_group( + "Auth-system '%s' options" % name, + conflict_handler="resolve") + auth_plugin.add_opts(group) + + +def load_plugin(auth_system): + try: + plugin_class = _discovered_plugins[auth_system] + except KeyError: + raise exceptions.AuthSystemNotFound(auth_system) + return plugin_class(auth_system=auth_system) + + +def load_plugin_from_args(args): + """Load requred plugin and populate it with options. + + Try to guess auth system if it is not specified. Systems are tried in + alphabetical order. + + :type args: argparse.Namespace + :raises: AuthorizationFailure + """ + auth_system = args.os_auth_system + if auth_system: + plugin = load_plugin(auth_system) + plugin.parse_opts(args) + plugin.sufficient_options() + return plugin + + for plugin_auth_system in sorted(_discovered_plugins.iterkeys()): + plugin_class = _discovered_plugins[plugin_auth_system] + plugin = plugin_class() + plugin.parse_opts(args) + try: + plugin.sufficient_options() + except exceptions.AuthPluginOptionsMissing: + continue + return plugin + raise exceptions.AuthPluginOptionsMissing(["auth_system"]) + + +class BaseAuthPlugin(object): + """Base class for authentication plugins. + + An authentication plugin needs to override at least the authenticate + method to be a valid plugin. + """ + + __metaclass__ = abc.ABCMeta + + auth_system = None + opt_names = [] + common_opt_names = [ + "auth_system", + "username", + "password", + "tenant_name", + "token", + "auth_url", + ] + + def __init__(self, auth_system=None, **kwargs): + self.auth_system = auth_system or self.auth_system + self.opts = dict((name, kwargs.get(name)) + for name in self.opt_names) + + @staticmethod + def _parser_add_opt(parser, opt): + """Add an option to parser in two variants. + + :param opt: option name (with underscores) + """ + dashed_opt = opt.replace("_", "-") + env_var = "OS_%s" % opt.upper() + arg_default = os.environ.get(env_var, "") + arg_help = "Defaults to env[%s]." % env_var + parser.add_argument( + "--os-%s" % dashed_opt, + metavar="<%s>" % dashed_opt, + default=arg_default, + help=arg_help) + parser.add_argument( + "--os_%s" % opt, + metavar="<%s>" % dashed_opt, + help=argparse.SUPPRESS) + + @classmethod + def add_opts(cls, parser): + """Populate the parser with the options for this plugin. + """ + for opt in cls.opt_names: + # use `BaseAuthPlugin.common_opt_names` since it is never + # changed in child classes + if opt not in BaseAuthPlugin.common_opt_names: + cls._parser_add_opt(parser, opt) + + @classmethod + def add_common_opts(cls, parser): + """Add options that are common for several plugins. + """ + for opt in cls.common_opt_names: + cls._parser_add_opt(parser, opt) + + @staticmethod + def get_opt(opt_name, args): + """Return option name and value. + + :param opt_name: name of the option, e.g., "username" + :param args: parsed arguments + """ + return (opt_name, getattr(args, "os_%s" % opt_name, None)) + + def parse_opts(self, args): + """Parse the actual auth-system options if any. + + This method is expected to populate the attribute `self.opts` with a + dict containing the options and values needed to make authentication. + """ + self.opts.update(dict(self.get_opt(opt_name, args) + for opt_name in self.opt_names)) + + def authenticate(self, http_client): + """Authenticate using plugin defined method. + + The method usually analyses `self.opts` and performs + a request to authentication server. + + :param http_client: client object that needs authentication + :type http_client: HTTPClient + :raises: AuthorizationFailure + """ + self.sufficient_options() + self._do_authenticate(http_client) + + @abc.abstractmethod + def _do_authenticate(self, http_client): + """Protected method for authentication. + """ + + def sufficient_options(self): + """Check if all required options are present. + + :raises: AuthPluginOptionsMissing + """ + missing = [opt + for opt in self.opt_names + if not self.opts.get(opt)] + if missing: + raise exceptions.AuthPluginOptionsMissing(missing) + + @abc.abstractmethod + def token_and_endpoint(self, endpoint_type, service_type): + """Return token and endpoint. + + :param service_type: Service type of the endpoint + :type service_type: string + :param endpoint_type: Type of endpoint. + Possible values: public or publicURL, + internal or internalURL, + admin or adminURL + :type endpoint_type: string + :returns: tuple of token and endpoint strings + :raises: EndpointException + """ diff --git a/openstack/common/apiclient/base.py b/openstack/common/apiclient/base.py new file mode 100644 index 0000000..0ecf1f5 --- /dev/null +++ b/openstack/common/apiclient/base.py @@ -0,0 +1,492 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 OpenStack LLC +# Copyright 2012 Grid Dynamics +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Base utilities to build API operation managers and objects on top of. +""" + +# E1102: %s is not callable +# pylint: disable=E1102 + +import abc +import urllib + +from openstack.common.apiclient import exceptions +from openstack.common import strutils + + +def getid(obj): + """Return id if argument is a Resource. + + Abstracts the common pattern of allowing both an object or an object's ID + (UUID) as a parameter when dealing with relationships. + """ + try: + if obj.uuid: + return obj.uuid + except AttributeError: + pass + try: + return obj.id + except AttributeError: + return obj + + +# TODO(aababilov): call run_hooks() in HookableMixin's child classes +class HookableMixin(object): + """Mixin so classes can register and run hooks.""" + _hooks_map = {} + + @classmethod + def add_hook(cls, hook_type, hook_func): + """Add a new hook of specified type. + + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param hook_func: hook function + """ + if hook_type not in cls._hooks_map: + cls._hooks_map[hook_type] = [] + + cls._hooks_map[hook_type].append(hook_func) + + @classmethod + def run_hooks(cls, hook_type, *args, **kwargs): + """Run all hooks of specified type. + + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param **args: args to be passed to every hook function + :param **kwargs: kwargs to be passed to every hook function + """ + hook_funcs = cls._hooks_map.get(hook_type) or [] + for hook_func in hook_funcs: + hook_func(*args, **kwargs) + + +class BaseManager(HookableMixin): + """Basic manager type providing common operations. + + Managers interact with a particular type of API (servers, flavors, images, + etc.) and provide CRUD operations for them. + """ + resource_class = None + + def __init__(self, client): + """Initializes BaseManager with `client`. + + :param client: instance of BaseClient descendant for HTTP requests + """ + super(BaseManager, self).__init__() + self.client = client + + def _list(self, url, response_key, obj_class=None, json=None): + """List the collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + :param obj_class: class for constructing the returned objects + (self.resource_class will be used by default) + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + """ + if json: + body = self.client.post(url, json=json).json() + else: + body = self.client.get(url).json() + + if obj_class is None: + obj_class = self.resource_class + + data = body[response_key] + # NOTE(ja): keystone returns values as list as {'values': [ ... ]} + # unlike other services which just return the list... + try: + data = data['values'] + except (KeyError, TypeError): + pass + + return [obj_class(self, res, loaded=True) for res in data if res] + + def _get(self, url, response_key): + """Get an object from collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'server' + """ + body = self.client.get(url).json() + return self.resource_class(self, body[response_key], loaded=True) + + def _head(self, url): + """Retrieve request headers for an object. + + :param url: a partial URL, e.g., '/servers' + """ + resp = self.client.head(url) + return resp.status_code == 204 + + def _post(self, url, json, response_key, return_raw=False): + """Create an object. + + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + :param return_raw: flag to force returning raw JSON instead of + Python object of self.resource_class + """ + body = self.client.post(url, json=json).json() + if return_raw: + return body[response_key] + return self.resource_class(self, body[response_key]) + + def _put(self, url, json=None, response_key=None): + """Update an object with PUT method. + + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + """ + resp = self.client.put(url, json=json) + # PUT requests may not return a body + if resp.content: + body = resp.json() + if response_key is not None: + return self.resource_class(self, body[response_key]) + else: + return self.resource_class(self, body) + + def _patch(self, url, json=None, response_key=None): + """Update an object with PATCH method. + + :param url: a partial URL, e.g., '/servers' + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + """ + body = self.client.patch(url, json=json).json() + if response_key is not None: + return self.resource_class(self, body[response_key]) + else: + return self.resource_class(self, body) + + def _delete(self, url): + """Delete an object. + + :param url: a partial URL, e.g., '/servers/my-server' + """ + return self.client.delete(url) + + +class ManagerWithFind(BaseManager): + """Manager with additional `find()`/`findall()` methods.""" + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def list(self): + pass + + def find(self, **kwargs): + """Find a single item with attributes matching ``**kwargs``. + + This isn't very efficient: it loads the entire list then filters on + the Python side. + """ + matches = self.findall(**kwargs) + num_matches = len(matches) + if num_matches == 0: + msg = "No %s matching %s." % (self.resource_class.__name__, kwargs) + raise exceptions.NotFound(msg) + elif num_matches > 1: + raise exceptions.NoUniqueMatch() + else: + return matches[0] + + def findall(self, **kwargs): + """Find all items with attributes matching ``**kwargs``. + + This isn't very efficient: it loads the entire list then filters on + the Python side. + """ + found = [] + searches = kwargs.items() + + for obj in self.list(): + try: + if all(getattr(obj, attr) == value + for (attr, value) in searches): + found.append(obj) + except AttributeError: + continue + + return found + + +class CrudManager(BaseManager): + """Base manager class for manipulating entities. + + Children of this class are expected to define a `collection_key` and `key`. + + - `collection_key`: Usually a plural noun by convention (e.g. `entities`); + used to refer collections in both URL's (e.g. `/v3/entities`) and JSON + objects containing a list of member resources (e.g. `{'entities': [{}, + {}, {}]}`). + - `key`: Usually a singular noun by convention (e.g. `entity`); used to + refer to an individual member of the collection. + + """ + collection_key = None + key = None + + def build_url(self, base_url=None, **kwargs): + """Builds a resource URL for the given kwargs. + + Given an example collection where `collection_key = 'entities'` and + `key = 'entity'`, the following URL's could be generated. + + By default, the URL will represent a collection of entities, e.g.:: + + /entities + + If kwargs contains an `entity_id`, then the URL will represent a + specific member, e.g.:: + + /entities/{entity_id} + + :param base_url: if provided, the generated URL will be appended to it + """ + url = base_url if base_url is not None else '' + + url += '/%s' % self.collection_key + + # do we have a specific entity? + entity_id = kwargs.get('%s_id' % self.key) + if entity_id is not None: + url += '/%s' % entity_id + + return url + + def _filter_kwargs(self, kwargs): + """Drop null values and handle ids.""" + for key, ref in kwargs.copy().iteritems(): + if ref is None: + kwargs.pop(key) + else: + if isinstance(ref, Resource): + kwargs.pop(key) + kwargs['%s_id' % key] = getid(ref) + return kwargs + + def create(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._post( + self.build_url(**kwargs), + {self.key: kwargs}, + self.key) + + def get(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._get( + self.build_url(**kwargs), + self.key) + + def head(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + return self._head(self.build_url(**kwargs)) + + def list(self, base_url=None, **kwargs): + """List the collection. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + return self._list( + '%(base_url)s%(query)s' % { + 'base_url': self.build_url(base_url=base_url, **kwargs), + 'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '', + }, + self.collection_key) + + def put(self, base_url=None, **kwargs): + """Update an element. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + return self._put(self.build_url(base_url=base_url, **kwargs)) + + def update(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + params = kwargs.copy() + params.pop('%s_id' % self.key) + + return self._patch( + self.build_url(**kwargs), + {self.key: params}, + self.key) + + def delete(self, **kwargs): + kwargs = self._filter_kwargs(kwargs) + + return self._delete( + self.build_url(**kwargs)) + + def find(self, base_url=None, **kwargs): + """Find a single item with attributes matching ``**kwargs``. + + :param base_url: if provided, the generated URL will be appended to it + """ + kwargs = self._filter_kwargs(kwargs) + + rl = self._list( + '%(base_url)s%(query)s' % { + 'base_url': self.build_url(base_url=base_url, **kwargs), + 'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '', + }, + self.collection_key) + num = len(rl) + + if num == 0: + msg = "No %s matching %s." % (self.resource_class.__name__, kwargs) + raise exceptions.NotFound(404, msg) + elif num > 1: + raise exceptions.NoUniqueMatch + else: + return rl[0] + + +class Extension(HookableMixin): + """Extension descriptor.""" + + SUPPORTED_HOOKS = ('__pre_parse_args__', '__post_parse_args__') + manager_class = None + + def __init__(self, name, module): + super(Extension, self).__init__() + self.name = name + self.module = module + self._parse_extension_module() + + def _parse_extension_module(self): + self.manager_class = None + for attr_name, attr_value in self.module.__dict__.items(): + if attr_name in self.SUPPORTED_HOOKS: + self.add_hook(attr_name, attr_value) + else: + try: + if issubclass(attr_value, BaseManager): + self.manager_class = attr_value + except TypeError: + pass + + def __repr__(self): + return "<Extension '%s'>" % self.name + + +class Resource(object): + """Base class for OpenStack resources (tenant, user, etc.). + + This is pretty much just a bag for attributes. + """ + + HUMAN_ID = False + NAME_ATTR = 'name' + + def __init__(self, manager, info, loaded=False): + """Populate and bind to a manager. + + :param manager: BaseManager object + :param info: dictionary representing resource attributes + :param loaded: prevent lazy-loading if set to True + """ + self.manager = manager + self._info = info + self._add_details(info) + self._loaded = loaded + + def __repr__(self): + reprkeys = sorted(k + for k in self.__dict__.keys() + if k[0] != '_' and k != 'manager') + info = ", ".join("%s=%s" % (k, getattr(self, k)) for k in reprkeys) + return "<%s %s>" % (self.__class__.__name__, info) + + @property + def human_id(self): + """Human-readable ID which can be used for bash completion. + """ + if self.NAME_ATTR in self.__dict__ and self.HUMAN_ID: + return strutils.to_slug(getattr(self, self.NAME_ATTR)) + return None + + def _add_details(self, info): + for (k, v) in info.iteritems(): + try: + setattr(self, k, v) + self._info[k] = v + except AttributeError: + # In this case we already defined the attribute on the class + pass + + def __getattr__(self, k): + if k not in self.__dict__: + #NOTE(bcwaldon): disallow lazy-loading if already loaded once + if not self.is_loaded(): + self.get() + return self.__getattr__(k) + + raise AttributeError(k) + else: + return self.__dict__[k] + + def get(self): + # set_loaded() first ... so if we have to bail, we know we tried. + self.set_loaded(True) + if not hasattr(self.manager, 'get'): + return + + new = self.manager.get(self.id) + if new: + self._add_details(new._info) + + def __eq__(self, other): + if not isinstance(other, Resource): + return NotImplemented + # two resources of different types are not equal + if not isinstance(other, self.__class__): + return False + if hasattr(self, 'id') and hasattr(other, 'id'): + return self.id == other.id + return self._info == other._info + + def is_loaded(self): + return self._loaded + + def set_loaded(self, val): + self._loaded = val diff --git a/openstack/common/apiclient/client.py b/openstack/common/apiclient/client.py new file mode 100644 index 0000000..cb3e1d7 --- /dev/null +++ b/openstack/common/apiclient/client.py @@ -0,0 +1,360 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 OpenStack LLC +# Copyright 2011 Piston Cloud Computing, Inc. +# Copyright 2013 Alessio Ababilov +# Copyright 2013 Grid Dynamics +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +OpenStack Client interface. Handles the REST calls and responses. +""" + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import logging +import time + +try: + import simplejson as json +except ImportError: + import json + +import requests + +from openstack.common.apiclient import exceptions +from openstack.common import importutils + + +_logger = logging.getLogger(__name__) + + +class HTTPClient(object): + """This client handles sending HTTP requests to OpenStack servers. + + Features: + - share authentication information between several clients to different + services (e.g., for compute and image clients); + - reissue authentication request for expired tokens; + - encode/decode JSON bodies; + - raise exeptions on HTTP errors; + - pluggable authentication; + - store authentication information in a keyring; + - store time spent for requests; + - register clients for particular services, so one can use + `http_client.identity` or `http_client.compute`; + - log requests and responses in a format that is easy to copy-and-paste + into terminal and send the same request with curl. + """ + + user_agent = "openstack.common.apiclient" + + def __init__(self, + auth_plugin, + region_name=None, + endpoint_type="publicURL", + original_ip=None, + verify=True, + cert=None, + timeout=None, + timings=False, + keyring_saver=None, + debug=False, + user_agent=None, + http=None): + self.auth_plugin = auth_plugin + + self.endpoint_type = endpoint_type + self.region_name = region_name + + self.original_ip = original_ip + self.timeout = timeout + self.verify = verify + self.cert = cert + + self.keyring_saver = keyring_saver + self.debug = debug + self.user_agent = user_agent or self.user_agent + + self.times = [] # [("item", starttime, endtime), ...] + self.timings = timings + + # requests within the same session can reuse TCP connections from pool + self.http = http or requests.Session() + + self.cached_token = None + + def _http_log_req(self, method, url, kwargs): + if not self.debug: + return + + string_parts = [ + "curl -i", + "-X '%s'" % method, + "'%s'" % url, + ] + + for element in kwargs['headers']: + header = "-H '%s: %s'" % (element, kwargs['headers'][element]) + string_parts.append(header) + + _logger.debug("REQ: %s" % " ".join(string_parts)) + if 'data' in kwargs: + _logger.debug("REQ BODY: %s\n" % (kwargs['data'])) + + def _http_log_resp(self, resp): + if not self.debug: + return + _logger.debug( + "RESP: [%s] %s\n", + resp.status_code, + resp.headers) + if resp._content_consumed: + _logger.debug( + "RESP BODY: %s\n", + resp.text) + + def serialize(self, kwargs): + if kwargs.get('json') is not None: + kwargs['headers']['Content-Type'] = 'application/json' + kwargs['data'] = json.dumps(kwargs['json']) + try: + del kwargs['json'] + except KeyError: + pass + + def get_timings(self): + return self.times + + def reset_timings(self): + self.times = [] + + def request(self, method, url, **kwargs): + """Send an http request with the specified characteristics. + + Wrapper around `requests.Session.request` to handle tasks such as + setting headers, JSON encoding/decoding, and error handling. + + :param method: method of HTTP request + :param url: URL of HTTP request + :param kwargs: any other parameter that can be passed to +' requests.Session.request (such as `headers`) or `json` + that will be encoded as JSON and used as `data` argument + """ + kwargs.setdefault("headers", kwargs.get("headers", {})) + kwargs["headers"]["User-Agent"] = self.user_agent + if self.original_ip: + kwargs["headers"]["Forwarded"] = "for=%s;by=%s" % ( + self.original_ip, self.user_agent) + if self.timeout is not None: + kwargs.setdefault("timeout", self.timeout) + kwargs.setdefault("verify", self.verify) + if self.cert is not None: + kwargs.setdefault("cert", self.cert) + self.serialize(kwargs) + + self._http_log_req(method, url, kwargs) + if self.timings: + start_time = time.time() + resp = self.http.request(method, url, **kwargs) + if self.timings: + self.times.append(("%s %s" % (method, url), + start_time, time.time())) + self._http_log_resp(resp) + + if resp.status_code >= 400: + _logger.debug( + "Request returned failure status: %s", + resp.status_code) + raise exceptions.from_response(resp, method, url) + + return resp + + @staticmethod + def concat_url(endpoint, url): + """Concatenate endpoint and final URL. + + E.g., "http://keystone/v2.0/" and "/tokens" are concatenated to + "http://keystone/v2.0/tokens". + + :param endpoint: the base URL + :param url: the final URL + """ + return "%s/%s" % (endpoint.rstrip("/"), url.strip("/")) + + def client_request(self, client, method, url, **kwargs): + """Send an http request using `client`'s endpoint and specified `url`. + + If request was rejected as unauthorized (possibly because the token is + expired), issue one authorization attempt and send the request once + again. + + :param client: instance of BaseClient descendant + :param method: method of HTTP request + :param url: URL of HTTP request + :param kwargs: any other parameter that can be passed to +' `HTTPClient.request` + """ + + filter_args = { + "endpoint_type": client.endpoint_type or self.endpoint_type, + "service_type": client.service_type, + } + token, endpoint = (self.cached_token, client.cached_endpoint) + just_authenticated = False + if not (token and endpoint): + try: + token, endpoint = self.auth_plugin.token_and_endpoint( + **filter_args) + except exceptions.EndpointException: + pass + if not (token and endpoint): + self.authenticate() + just_authenticated = True + token, endpoint = self.auth_plugin.token_and_endpoint( + **filter_args) + if not (token and endpoint): + raise exceptions.AuthorizationFailure( + "Cannot find endpoint or token for request") + + old_token_endpoint = (token, endpoint) + kwargs.setdefault("headers", {})["X-Auth-Token"] = token + self.cached_token = token + client.cached_endpoint = endpoint + # Perform the request once. If we get Unauthorized, then it + # might be because the auth token expired, so try to + # re-authenticate and try again. If it still fails, bail. + try: + return self.request( + method, self.concat_url(endpoint, url), **kwargs) + except exceptions.Unauthorized as unauth_ex: + if just_authenticated: + raise + self.cached_token = None + client.cached_endpoint = None + self.authenticate() + try: + token, endpoint = self.auth_plugin.token_and_endpoint( + **filter_args) + except exceptions.EndpointException: + raise unauth_ex + if (not (token and endpoint) or + old_token_endpoint == (token, endpoint)): + raise unauth_ex + self.cached_token = token + client.cached_endpoint = endpoint + kwargs["headers"]["X-Auth-Token"] = token + return self.request( + method, self.concat_url(endpoint, url), **kwargs) + + def add_client(self, base_client_instance): + """Add a new instance of :class:`BaseClient` descendant. + + `self` will store a reference to `base_client_instance`. + + Example: + + >>> def test_clients(): + ... from keystoneclient.auth import keystone + ... from openstack.common.apiclient import client + ... auth = keystone.KeystoneAuthPlugin( + ... username="user", password="pass", tenant_name="tenant", + ... auth_url="http://auth:5000/v2.0") + ... openstack_client = client.HTTPClient(auth) + ... # create nova client + ... from novaclient.v1_1 import client + ... client.Client(openstack_client) + ... # create keystone client + ... from keystoneclient.v2_0 import client + ... client.Client(openstack_client) + ... # use them + ... openstack_client.identity.tenants.list() + ... openstack_client.compute.servers.list() + """ + service_type = base_client_instance.service_type + if service_type and not hasattr(self, service_type): + setattr(self, service_type, base_client_instance) + + def authenticate(self): + self.auth_plugin.authenticate(self) + # Store the authentication results in the keyring for later requests + if self.keyring_saver: + self.keyring_saver.save(self) + + +class BaseClient(object): + """Top-level object to access the OpenStack API. + + This client uses :class:`HTTPClient` to send requests. :class:`HTTPClient` + will handle a bunch of issues such as authentication. + """ + + service_type = None + endpoint_type = None # "publicURL" will be used + cached_endpoint = None + + def __init__(self, http_client, extensions=None): + self.http_client = http_client + http_client.add_client(self) + + # Add in any extensions... + if extensions: + for extension in extensions: + if extension.manager_class: + setattr(self, extension.name, + extension.manager_class(self)) + + def client_request(self, method, url, **kwargs): + return self.http_client.client_request( + self, method, url, **kwargs) + + def head(self, url, **kwargs): + return self.client_request("HEAD", url, **kwargs) + + def get(self, url, **kwargs): + return self.client_request("GET", url, **kwargs) + + def post(self, url, **kwargs): + return self.client_request("POST", url, **kwargs) + + def put(self, url, **kwargs): + return self.client_request("PUT", url, **kwargs) + + def delete(self, url, **kwargs): + return self.client_request("DELETE", url, **kwargs) + + def patch(self, url, **kwargs): + return self.client_request("PATCH", url, **kwargs) + + @staticmethod + def get_class(api_name, version, version_map): + """Returns the client class for the requested API version + + :param api_name: the name of the API, e.g. 'compute', 'image', etc + :param version: the requested API version + :param version_map: a dict of client classes keyed by version + :rtype: a client class for the requested API version + """ + try: + client_path = version_map[str(version)] + except (KeyError, ValueError): + msg = "Invalid %s client version '%s'. must be one of: %s" % ( + (api_name, version, ', '.join(version_map.keys()))) + raise exceptions.UnsupportedVersion(msg) + + return importutils.import_class(client_path) diff --git a/openstack/common/apiclient/exceptions.py b/openstack/common/apiclient/exceptions.py index e70d37a..b03def7 100644 --- a/openstack/common/apiclient/exceptions.py +++ b/openstack/common/apiclient/exceptions.py @@ -121,7 +121,7 @@ class HttpError(ClientException): super(HttpError, self).__init__(formatted_string) -class HttpClientError(HttpError): +class HTTPClientError(HttpError): """Client-side HTTP error. Exception for cases in which the client seems to have erred. @@ -138,7 +138,7 @@ class HttpServerError(HttpError): message = "HTTP Server Error" -class BadRequest(HttpClientError): +class BadRequest(HTTPClientError): """HTTP 400 - Bad Request. The request cannot be fulfilled due to bad syntax. @@ -147,7 +147,7 @@ class BadRequest(HttpClientError): message = "Bad Request" -class Unauthorized(HttpClientError): +class Unauthorized(HTTPClientError): """HTTP 401 - Unauthorized. Similar to 403 Forbidden, but specifically for use when authentication @@ -157,7 +157,7 @@ class Unauthorized(HttpClientError): message = "Unauthorized" -class PaymentRequired(HttpClientError): +class PaymentRequired(HTTPClientError): """HTTP 402 - Payment Required. Reserved for future use. @@ -166,7 +166,7 @@ class PaymentRequired(HttpClientError): message = "Payment Required" -class Forbidden(HttpClientError): +class Forbidden(HTTPClientError): """HTTP 403 - Forbidden. The request was a valid request, but the server is refusing to respond @@ -176,7 +176,7 @@ class Forbidden(HttpClientError): message = "Forbidden" -class NotFound(HttpClientError): +class NotFound(HTTPClientError): """HTTP 404 - Not Found. The requested resource could not be found but may be available again @@ -186,7 +186,7 @@ class NotFound(HttpClientError): message = "Not Found" -class MethodNotAllowed(HttpClientError): +class MethodNotAllowed(HTTPClientError): """HTTP 405 - Method Not Allowed. A request was made of a resource using a request method not supported @@ -196,7 +196,7 @@ class MethodNotAllowed(HttpClientError): message = "Method Not Allowed" -class NotAcceptable(HttpClientError): +class NotAcceptable(HTTPClientError): """HTTP 406 - Not Acceptable. The requested resource is only capable of generating content not @@ -206,7 +206,7 @@ class NotAcceptable(HttpClientError): message = "Not Acceptable" -class ProxyAuthenticationRequired(HttpClientError): +class ProxyAuthenticationRequired(HTTPClientError): """HTTP 407 - Proxy Authentication Required. The client must first authenticate itself with the proxy. @@ -215,7 +215,7 @@ class ProxyAuthenticationRequired(HttpClientError): message = "Proxy Authentication Required" -class RequestTimeout(HttpClientError): +class RequestTimeout(HTTPClientError): """HTTP 408 - Request Timeout. The server timed out waiting for the request. @@ -224,7 +224,7 @@ class RequestTimeout(HttpClientError): message = "Request Timeout" -class Conflict(HttpClientError): +class Conflict(HTTPClientError): """HTTP 409 - Conflict. Indicates that the request could not be processed because of conflict @@ -234,7 +234,7 @@ class Conflict(HttpClientError): message = "Conflict" -class Gone(HttpClientError): +class Gone(HTTPClientError): """HTTP 410 - Gone. Indicates that the resource requested is no longer available and will @@ -244,7 +244,7 @@ class Gone(HttpClientError): message = "Gone" -class LengthRequired(HttpClientError): +class LengthRequired(HTTPClientError): """HTTP 411 - Length Required. The request did not specify the length of its content, which is @@ -254,7 +254,7 @@ class LengthRequired(HttpClientError): message = "Length Required" -class PreconditionFailed(HttpClientError): +class PreconditionFailed(HTTPClientError): """HTTP 412 - Precondition Failed. The server does not meet one of the preconditions that the requester @@ -264,7 +264,7 @@ class PreconditionFailed(HttpClientError): message = "Precondition Failed" -class RequestEntityTooLarge(HttpClientError): +class RequestEntityTooLarge(HTTPClientError): """HTTP 413 - Request Entity Too Large. The request is larger than the server is willing or able to process. @@ -281,7 +281,7 @@ class RequestEntityTooLarge(HttpClientError): super(RequestEntityTooLarge, self).__init__(*args, **kwargs) -class RequestUriTooLong(HttpClientError): +class RequestUriTooLong(HTTPClientError): """HTTP 414 - Request-URI Too Long. The URI provided was too long for the server to process. @@ -290,7 +290,7 @@ class RequestUriTooLong(HttpClientError): message = "Request-URI Too Long" -class UnsupportedMediaType(HttpClientError): +class UnsupportedMediaType(HTTPClientError): """HTTP 415 - Unsupported Media Type. The request entity has a media type which the server or resource does @@ -300,7 +300,7 @@ class UnsupportedMediaType(HttpClientError): message = "Unsupported Media Type" -class RequestedRangeNotSatisfiable(HttpClientError): +class RequestedRangeNotSatisfiable(HTTPClientError): """HTTP 416 - Requested Range Not Satisfiable. The client has asked for a portion of the file, but the server cannot @@ -310,7 +310,7 @@ class RequestedRangeNotSatisfiable(HttpClientError): message = "Requested Range Not Satisfiable" -class ExpectationFailed(HttpClientError): +class ExpectationFailed(HTTPClientError): """HTTP 417 - Expectation Failed. The server cannot meet the requirements of the Expect request-header field. @@ -319,7 +319,7 @@ class ExpectationFailed(HttpClientError): message = "Expectation Failed" -class UnprocessableEntity(HttpClientError): +class UnprocessableEntity(HTTPClientError): """HTTP 422 - Unprocessable Entity. The request was well-formed but was unable to be followed due to semantic @@ -440,7 +440,7 @@ def from_response(response, method, url): if 500 <= response.status_code < 600: cls = HttpServerError elif 400 <= response.status_code < 500: - cls = HttpClientError + cls = HTTPClientError else: cls = HttpError return cls(**kwargs) diff --git a/openstack/common/apiclient/fake_client.py b/openstack/common/apiclient/fake_client.py new file mode 100644 index 0000000..da125e2 --- /dev/null +++ b/openstack/common/apiclient/fake_client.py @@ -0,0 +1,172 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A fake server that "responds" to API methods with pre-canned responses. + +All of these responses come from the spec, so if for some reason the spec's +wrong the tests might raise AssertionError. I've indicated in comments the +places where actual behavior differs from the spec. +""" + +# W0102: Dangerous default value %s as argument +# pylint: disable=W0102 + +import json +import urlparse + +import requests + +from openstack.common.apiclient import client + + +def assert_has_keys(dct, required=[], optional=[]): + for k in required: + try: + assert k in dct + except AssertionError: + extra_keys = set(dct.keys()).difference(set(required + optional)) + raise AssertionError("found unexpected keys: %s" % + list(extra_keys)) + + +class TestResponse(requests.Response): + """Wrap requests.Response and provide a convenient initialization. + """ + + def __init__(self, data): + super(TestResponse, self).__init__() + self._content_consumed = True + if isinstance(data, dict): + self.status_code = data.get('status_code', 200) + # Fake the text attribute to streamline Response creation + text = data.get('text', "") + if isinstance(text, (dict, list)): + self._content = json.dumps(text) + default_headers = { + "Content-Type": "application/json", + } + else: + self._content = text + default_headers = {} + self.headers = data.get('headers') or default_headers + else: + self.status_code = data + + def __eq__(self, other): + return (self.status_code == other.status_code and + self.headers == other.headers and + self._content == other._content) + + +class FakeHTTPClient(client.HTTPClient): + + def __init__(self, *args, **kwargs): + self.callstack = [] + self.fixtures = kwargs.pop("fixtures", None) or {} + if not args and not "auth_plugin" in kwargs: + args = (None, ) + super(FakeHTTPClient, self).__init__(*args, **kwargs) + + def assert_called(self, method, url, body=None, pos=-1): + """Assert than an API method was just called. + """ + expected = (method, url) + called = self.callstack[pos][0:2] + assert self.callstack, \ + "Expected %s %s but no calls were made." % expected + + assert expected == called, 'Expected %s %s; got %s %s' % \ + (expected + called) + + if body is not None: + if self.callstack[pos][3] != body: + raise AssertionError('%r != %r' % + (self.callstack[pos][3], body)) + + def assert_called_anytime(self, method, url, body=None): + """Assert than an API method was called anytime in the test. + """ + expected = (method, url) + + assert self.callstack, \ + "Expected %s %s but no calls were made." % expected + + found = False + entry = None + for entry in self.callstack: + if expected == entry[0:2]: + found = True + break + + assert found, 'Expected %s %s; got %s' % \ + (method, url, self.callstack) + if body is not None: + assert entry[3] == body, "%s != %s" % (entry[3], body) + + self.callstack = [] + + def clear_callstack(self): + self.callstack = [] + + def authenticate(self): + pass + + def client_request(self, client, method, url, **kwargs): + # Check that certain things are called correctly + if method in ["GET", "DELETE"]: + assert "json" not in kwargs + + # Note the call + self.callstack.append( + (method, + url, + kwargs.get("headers") or {}, + kwargs.get("json") or kwargs.get("data"))) + try: + fixture = self.fixtures[url][method] + except KeyError: + pass + else: + return TestResponse({"headers": fixture[0], + "text": fixture[1]}) + + # Call the method + args = urlparse.parse_qsl(urlparse.urlparse(url)[4]) + kwargs.update(args) + munged_url = url.rsplit('?', 1)[0] + munged_url = munged_url.strip('/').replace('/', '_').replace('.', '_') + munged_url = munged_url.replace('-', '_') + + callback = "%s_%s" % (method.lower(), munged_url) + + if not hasattr(self, callback): + raise AssertionError('Called unknown API method: %s %s, ' + 'expected fakes method name: %s' % + (method, url, callback)) + + resp = getattr(self, callback)(**kwargs) + if len(resp) == 3: + status, headers, body = resp + else: + status, body = resp + headers = {} + return TestResponse({ + "status_code": status, + "text": body, + "headers": headers, + }) diff --git a/openstack/common/config/generator.py b/openstack/common/config/generator.py index 3d2809b..1ae6c4d 100644 --- a/openstack/common/config/generator.py +++ b/openstack/common/config/generator.py @@ -50,7 +50,6 @@ OPT_TYPES = { MULTISTROPT: 'multi valued', } -OPTION_COUNT = 0 OPTION_REGEX = re.compile(r"(%s)" % "|".join([STROPT, BOOLOPT, INTOPT, FLOATOPT, LISTOPT, MULTISTROPT])) @@ -97,8 +96,6 @@ def generate(srcfiles): for group, opts in opts_by_group.items(): print_group_opts(group, opts) - print("# Total option count: %d" % OPTION_COUNT) - def _import_module(mod_str): try: @@ -163,9 +160,7 @@ def _list_opts(obj): def print_group_opts(group, opts_by_module): print("[%s]" % group) print('') - global OPTION_COUNT for mod, opts in opts_by_module: - OPTION_COUNT += len(opts) print('#') print('# Options defined in %s' % mod) print('#') @@ -186,24 +181,24 @@ def _get_my_ip(): return None -def _sanitize_default(s): +def _sanitize_default(name, value): """Set up a reasonably sensible default for pybasedir, my_ip and host.""" - if s.startswith(sys.prefix): + if value.startswith(sys.prefix): # NOTE(jd) Don't use os.path.join, because it is likely to think the # second part is an absolute pathname and therefore drop the first # part. - s = os.path.normpath("/usr/" + s[len(sys.prefix):]) - elif s.startswith(BASEDIR): - return s.replace(BASEDIR, '/usr/lib/python/site-packages') - elif BASEDIR in s: - return s.replace(BASEDIR, '') - elif s == _get_my_ip(): + value = os.path.normpath("/usr/" + value[len(sys.prefix):]) + elif value.startswith(BASEDIR): + return value.replace(BASEDIR, '/usr/lib/python/site-packages') + elif BASEDIR in value: + return value.replace(BASEDIR, '') + elif value == _get_my_ip(): return '10.0.0.1' - elif s == socket.gethostname(): + elif value == socket.gethostname() and 'host' in name: return 'oslo' - elif s.strip() != s: - return '"%s"' % s - return s + elif value.strip() != value: + return '"%s"' % value + return value def _print_opt(opt): @@ -224,7 +219,8 @@ def _print_opt(opt): print('#%s=<None>' % opt_name) elif opt_type == STROPT: assert(isinstance(opt_default, basestring)) - print('#%s=%s' % (opt_name, _sanitize_default(opt_default))) + print('#%s=%s' % (opt_name, _sanitize_default(opt_name, + opt_default))) elif opt_type == BOOLOPT: assert(isinstance(opt_default, bool)) print('#%s=%s' % (opt_name, str(opt_default).lower())) diff --git a/openstack/common/db/exception.py b/openstack/common/db/exception.py index 69905da..3627de2 100644 --- a/openstack/common/db/exception.py +++ b/openstack/common/db/exception.py @@ -43,3 +43,9 @@ class DBDeadlock(DBError): class DBInvalidUnicodeParameter(Exception): message = _("Invalid Parameter: " "Unicode is not supported by the current database.") + + +class DbMigrationError(DBError): + """Wraps migration specific exception.""" + def __init__(self, message=None): + super(DbMigrationError, self).__init__(str(message)) diff --git a/openstack/common/db/sqlalchemy/migration.py b/openstack/common/db/sqlalchemy/migration.py index e643d8e..4154359 100644 --- a/openstack/common/db/sqlalchemy/migration.py +++ b/openstack/common/db/sqlalchemy/migration.py @@ -38,12 +38,53 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +import distutils.version as dist_version +import os import re +import migrate from migrate.changeset import ansisql from migrate.changeset.databases import sqlite +from migrate.versioning import util as migrate_util +import sqlalchemy from sqlalchemy.schema import UniqueConstraint +from openstack.common.db import exception +from openstack.common.db.sqlalchemy import session as db_session +from openstack.common.gettextutils import _ # noqa + + +@migrate_util.decorator +def patched_with_engine(f, *a, **kw): + url = a[0] + engine = migrate_util.construct_engine(url, **kw) + + try: + kw['engine'] = engine + return f(*a, **kw) + finally: + if isinstance(engine, migrate_util.Engine) and engine is not url: + migrate_util.log.debug('Disposing SQLAlchemy engine %s', engine) + engine.dispose() + + +# TODO(jkoelker) When migrate 0.7.3 is released and nova depends +# on that version or higher, this can be removed +MIN_PKG_VERSION = dist_version.StrictVersion('0.7.3') +if (not hasattr(migrate, '__version__') or + dist_version.StrictVersion(migrate.__version__) < MIN_PKG_VERSION): + migrate_util.with_engine = patched_with_engine + + +# NOTE(jkoelker) Delay importing migrate until we are patched +from migrate import exceptions as versioning_exceptions +from migrate.versioning import api as versioning_api +from migrate.versioning.repository import Repository + +_REPOSITORY = None + +get_engine = db_session.get_engine + def _get_unique_constraints(self, table): """Retrieve information about existing unique constraints of the table @@ -157,3 +198,81 @@ def patch_migrate(): _visit_migrate_unique_constraint constraint_cls.__bases__ = (ansisql.ANSIColumnDropper, sqlite.SQLiteConstraintGenerator) + + +def db_sync(abs_path, version=None, init_version=0): + """Upgrade or downgrade a database. + + Function runs the upgrade() or downgrade() functions in change scripts. + + :param abs_path: Absolute path to migrate repository. + :param version: Database will upgrade/downgrade until this version. + If None - database will update to the latest + available version. + :param init_version: Initial database version + """ + if version is not None: + try: + version = int(version) + except ValueError: + raise exception.DbMigrationError( + message=_("version should be an integer")) + + current_version = db_version(abs_path, init_version) + repository = _find_migrate_repo(abs_path) + if version is None or version > current_version: + return versioning_api.upgrade(get_engine(), repository, version) + else: + return versioning_api.downgrade(get_engine(), repository, + version) + + +def db_version(abs_path, init_version): + """Show the current version of the repository. + + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + try: + return versioning_api.db_version(get_engine(), repository) + except versioning_exceptions.DatabaseNotControlledError: + meta = sqlalchemy.MetaData() + engine = get_engine() + meta.reflect(bind=engine) + tables = meta.tables + if len(tables) == 0: + db_version_control(abs_path, init_version) + return versioning_api.db_version(get_engine(), repository) + else: + # Some pre-Essex DB's may not be version controlled. + # Require them to upgrade using Essex first. + raise exception.DbMigrationError( + message=_("Upgrade DB using Essex release first.")) + + +def db_version_control(abs_path, version=None): + """Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + versioning_api.version_control(get_engine(), repository, version) + return version + + +def _find_migrate_repo(abs_path): + """Get the project's change script repository + + :param abs_path: Absolute path to migrate repository + """ + global _REPOSITORY + if not os.path.exists(abs_path): + raise exception.DbMigrationError("Path %s not found" % abs_path) + if _REPOSITORY is None: + _REPOSITORY = Repository(abs_path) + return _REPOSITORY diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py index 59bcb90..236136e 100644 --- a/openstack/common/db/sqlalchemy/session.py +++ b/openstack/common/db/sqlalchemy/session.py @@ -279,13 +279,11 @@ database_opts = [ deprecated_opts=[cfg.DeprecatedOpt('sql_connection', group='DEFAULT'), cfg.DeprecatedOpt('sql_connection', - group='DATABASE')], - secret=True), + group='DATABASE')]), cfg.StrOpt('slave_connection', default='', help='The SQLAlchemy connection string used to connect to the ' - 'slave database', - secret=True), + 'slave database'), cfg.IntOpt('idle_timeout', default=3600, deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', @@ -478,6 +476,11 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name): if engine_name not in ["mysql", "sqlite", "postgresql"]: return + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. m = _DUP_KEY_RE_DB[engine_name].match(integrity_error.message) if not m: return @@ -510,6 +513,11 @@ def _raise_if_deadlock_error(operational_error, engine_name): re = _DEADLOCK_RE_DB.get(engine_name) if re is None: return + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. m = re.match(operational_error.message) if not m: return diff --git a/openstack/common/db/sqlalchemy/utils.py b/openstack/common/db/sqlalchemy/utils.py index 3ff7bdb..102f0e5 100755..100644 --- a/openstack/common/db/sqlalchemy/utils.py +++ b/openstack/common/db/sqlalchemy/utils.py @@ -18,6 +18,9 @@ # License for the specific language governing permissions and limitations # under the License. +import re + +from migrate.changeset import UniqueConstraint import sqlalchemy from sqlalchemy import Boolean from sqlalchemy import CheckConstraint @@ -37,13 +40,21 @@ from sqlalchemy.types import NullType from openstack.common.gettextutils import _ # noqa -from openstack.common import exception from openstack.common import log as logging from openstack.common import timeutils LOG = logging.getLogger(__name__) +_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+") + + +def sanitize_db_url(url): + match = _DBURL_REGEX.match(url) + if match: + return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):]) + return url + class InvalidSortKey(Exception): message = _("Sort key supplied was not valid.") @@ -174,6 +185,10 @@ def visit_insert_from_select(element, compiler, **kw): compiler.process(element.select)) +class ColumnError(Exception): + """Error raised when no column or an invalid column is found.""" + + def _get_not_supported_column(col_name_col_instance, column_name): try: column = col_name_col_instance[column_name] @@ -181,16 +196,53 @@ def _get_not_supported_column(col_name_col_instance, column_name): msg = _("Please specify column %s in col_name_col_instance " "param. It is required because column has unsupported " "type by sqlite).") - raise exception.OpenstackException(message=msg % column_name) + raise ColumnError(msg % column_name) if not isinstance(column, Column): msg = _("col_name_col_instance param has wrong type of " "column instance for column %s It should be instance " "of sqlalchemy.Column.") - raise exception.OpenstackException(message=msg % column_name) + raise ColumnError(msg % column_name) return column +def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns, + **col_name_col_instance): + """Drop unique constraint from table. + + This method drops UC from table and works for mysql, postgresql and sqlite. + In mysql and postgresql we are able to use "alter table" construction. + Sqlalchemy doesn't support some sqlite column types and replaces their + type with NullType in metadata. We process these columns and replace + NullType with the correct column type. + + :param migrate_engine: sqlalchemy engine + :param table_name: name of table that contains uniq constraint. + :param uc_name: name of uniq constraint that will be dropped. + :param columns: columns that are in uniq constraint. + :param col_name_col_instance: contains pair column_name=column_instance. + column_instance is instance of Column. These params + are required only for columns that have unsupported + types by sqlite. For example BigInteger. + """ + + meta = MetaData() + meta.bind = migrate_engine + t = Table(table_name, meta, autoload=True) + + if migrate_engine.name == "sqlite": + override_cols = [ + _get_not_supported_column(col_name_col_instance, col.name) + for col in t.columns + if isinstance(col.type, NullType) + ] + for col in override_cols: + t.columns.replace(col) + + uc = UniqueConstraint(*columns, table=t, name=uc_name) + uc.drop() + + def drop_old_duplicate_entries_from_table(migrate_engine, table_name, use_soft_delete, *uc_column_names): """Drop all old rows having the same values for columns in uc_columns. @@ -248,8 +300,7 @@ def _get_default_deleted_value(table): return 0 if isinstance(table.c.id.type, String): return "" - raise exception.OpenstackException( - message=_("Unsupported id columns type")) + raise ColumnError(_("Unsupported id columns type")) def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): diff --git a/openstack/common/deprecated/wsgi.py b/openstack/common/deprecated/wsgi.py index a9530b3..aff4b92 100644 --- a/openstack/common/deprecated/wsgi.py +++ b/openstack/common/deprecated/wsgi.py @@ -35,7 +35,6 @@ import webob.exc from xml.dom import minidom from xml.parsers import expat -from openstack.common import exception from openstack.common.gettextutils import _ # noqa from openstack.common import jsonutils from openstack.common import log as logging @@ -59,6 +58,18 @@ CONF.register_opts(socket_opts) LOG = logging.getLogger(__name__) +class MalformedRequestBody(Exception): + def __init__(self, reason): + super(MalformedRequestBody, self).__init__( + "Malformed message body: %s", reason) + + +class InvalidContentType(Exception): + def __init__(self, content_type): + super(InvalidContentType, self).__init__( + "Invalid content type %s", content_type) + + def run_server(application, port, **kwargs): """Run a WSGI server with the given application.""" sock = eventlet.listen(('0.0.0.0', port)) @@ -255,7 +266,7 @@ class Request(webob.Request): self.default_request_content_types) if content_type not in allowed_content_types: - raise exception.InvalidContentType(content_type=content_type) + raise InvalidContentType(content_type=content_type) return content_type @@ -294,10 +305,10 @@ class Resource(object): try: action, action_args, accept = self.deserialize_request(request) - except exception.InvalidContentType: + except InvalidContentType: msg = _("Unsupported Content-Type") return webob.exc.HTTPUnsupportedMediaType(explanation=msg) - except exception.MalformedRequestBody: + except MalformedRequestBody: msg = _("Malformed request body") return webob.exc.HTTPBadRequest(explanation=msg) @@ -530,7 +541,7 @@ class ResponseSerializer(object): try: return self.body_serializers[content_type] except (KeyError, TypeError): - raise exception.InvalidContentType(content_type=content_type) + raise InvalidContentType(content_type=content_type) class RequestHeadersDeserializer(ActionDispatcher): @@ -589,7 +600,7 @@ class RequestDeserializer(object): try: content_type = request.get_content_type() - except exception.InvalidContentType: + except InvalidContentType: LOG.debug(_("Unrecognized Content-Type provided in request")) raise @@ -599,7 +610,7 @@ class RequestDeserializer(object): try: deserializer = self.get_body_deserializer(content_type) - except exception.InvalidContentType: + except InvalidContentType: LOG.debug(_("Unable to deserialize body as provided Content-Type")) raise @@ -609,7 +620,7 @@ class RequestDeserializer(object): try: return self.body_deserializers[content_type] except (KeyError, TypeError): - raise exception.InvalidContentType(content_type=content_type) + raise InvalidContentType(content_type=content_type) def get_expected_content_type(self, request): return request.best_match_content_type(self.supported_content_types) @@ -651,7 +662,7 @@ class JSONDeserializer(TextDeserializer): return jsonutils.loads(datastring) except ValueError: msg = _("cannot understand JSON") - raise exception.MalformedRequestBody(reason=msg) + raise MalformedRequestBody(reason=msg) def default(self, datastring): return {'body': self._from_json(datastring)} @@ -676,7 +687,7 @@ class XMLDeserializer(TextDeserializer): return {node.nodeName: self._from_xml_node(node, plurals)} except expat.ExpatError: msg = _("cannot understand XML") - raise exception.MalformedRequestBody(reason=msg) + raise MalformedRequestBody(reason=msg) def _from_xml_node(self, node, listnames): """Convert a minidom node to a simple Python type. diff --git a/openstack/common/exception.py b/openstack/common/exception.py deleted file mode 100644 index 13c3bff..0000000 --- a/openstack/common/exception.py +++ /dev/null @@ -1,139 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2011 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Exceptions common to OpenStack projects -""" - -import logging - -from openstack.common.gettextutils import _ # noqa - -_FATAL_EXCEPTION_FORMAT_ERRORS = False - - -class Error(Exception): - def __init__(self, message=None): - super(Error, self).__init__(message) - - -class ApiError(Error): - def __init__(self, message='Unknown', code='Unknown'): - self.api_message = message - self.code = code - super(ApiError, self).__init__('%s: %s' % (code, message)) - - -class NotFound(Error): - pass - - -class UnknownScheme(Error): - - msg_fmt = "Unknown scheme '%s' found in URI" - - def __init__(self, scheme): - msg = self.msg_fmt % scheme - super(UnknownScheme, self).__init__(msg) - - -class BadStoreUri(Error): - - msg_fmt = "The Store URI %s was malformed. Reason: %s" - - def __init__(self, uri, reason): - msg = self.msg_fmt % (uri, reason) - super(BadStoreUri, self).__init__(msg) - - -class Duplicate(Error): - pass - - -class NotAuthorized(Error): - pass - - -class NotEmpty(Error): - pass - - -class Invalid(Error): - pass - - -class BadInputError(Exception): - """Error resulting from a client sending bad input to a server""" - pass - - -class MissingArgumentError(Error): - pass - - -class DatabaseMigrationError(Error): - pass - - -class ClientConnectionError(Exception): - """Error resulting from a client connecting to a server""" - pass - - -def wrap_exception(f): - def _wrap(*args, **kw): - try: - return f(*args, **kw) - except Exception as e: - if not isinstance(e, Error): - logging.exception(_('Uncaught exception')) - raise Error(str(e)) - raise - _wrap.func_name = f.func_name - return _wrap - - -class OpenstackException(Exception): - """Base Exception class. - - To correctly use this class, inherit from it and define - a 'msg_fmt' property. That message will get printf'd - with the keyword arguments provided to the constructor. - """ - msg_fmt = "An unknown exception occurred" - - def __init__(self, **kwargs): - try: - self._error_string = self.msg_fmt % kwargs - - except Exception: - if _FATAL_EXCEPTION_FORMAT_ERRORS: - raise - else: - # at least get the core message out if something happened - self._error_string = self.msg_fmt - - def __str__(self): - return self._error_string - - -class MalformedRequestBody(OpenstackException): - msg_fmt = "Malformed message body: %(reason)s" - - -class InvalidContentType(OpenstackException): - msg_fmt = "Invalid content type %(content_type)s" diff --git a/openstack/common/excutils.py b/openstack/common/excutils.py index abe6f87..664b2e4 100644 --- a/openstack/common/excutils.py +++ b/openstack/common/excutils.py @@ -77,7 +77,8 @@ def forever_retry_uncaught_exceptions(infunc): try: return infunc(*args, **kwargs) except Exception as exc: - if exc.message == last_exc_message: + this_exc_message = unicode(exc) + if this_exc_message == last_exc_message: exc_count += 1 else: exc_count = 1 @@ -85,12 +86,12 @@ def forever_retry_uncaught_exceptions(infunc): # the exception message changes cur_time = int(time.time()) if (cur_time - last_log_time > 60 or - exc.message != last_exc_message): + this_exc_message != last_exc_message): logging.exception( _('Unexpected exception occurred %d time(s)... ' 'retrying.') % exc_count) last_log_time = cur_time - last_exc_message = exc.message + last_exc_message = this_exc_message exc_count = 0 # This should be a very rare event. In case it isn't, do # a sleep. diff --git a/openstack/common/fixture/config.py b/openstack/common/fixture/config.py new file mode 100644 index 0000000..cf52a66 --- /dev/null +++ b/openstack/common/fixture/config.py @@ -0,0 +1,45 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Copyright 2013 Mirantis, Inc. +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import fixtures +from oslo.config import cfg + + +class Config(fixtures.Fixture): + """Override some configuration values. + + The keyword arguments are the names of configuration options to + override and their values. + + If a group argument is supplied, the overrides are applied to + the specified configuration option group. + + All overrides are automatically cleared at the end of the current + test by the reset() method, which is registred by addCleanup(). + """ + + def __init__(self, conf=cfg.CONF): + self.conf = conf + + def setUp(self): + super(Config, self).setUp() + self.addCleanup(self.conf.reset) + + def config(self, **kw): + group = kw.pop('group', None) + for k, v in kw.iteritems(): + self.conf.set_override(k, v, group) diff --git a/openstack/common/gettextutils.py b/openstack/common/gettextutils.py index 635a434..321fdd0 100644 --- a/openstack/common/gettextutils.py +++ b/openstack/common/gettextutils.py @@ -1,8 +1,8 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2012 Red Hat, Inc. -# All Rights Reserved. # Copyright 2013 IBM Corp. +# All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -31,17 +31,36 @@ import os import re import UserString +from babel import localedata import six _localedir = os.environ.get('oslo'.upper() + '_LOCALEDIR') _t = gettext.translation('oslo', localedir=_localedir, fallback=True) +_AVAILABLE_LANGUAGES = [] +USE_LAZY = False + + +def enable_lazy(): + """Convenience function for configuring _() to use lazy gettext + + Call this at the start of execution to enable the gettextutils._ + function to use lazy gettext functionality. This is useful if + your project is importing _ directly instead of using the + gettextutils.install() way of importing the _ function. + """ + global USE_LAZY + USE_LAZY = True + def _(msg): - return _t.ugettext(msg) + if USE_LAZY: + return Message(msg, 'oslo') + else: + return _t.ugettext(msg) -def install(domain): +def install(domain, lazy=False): """Install a _() function using the given translation domain. Given a translation domain, install a _() function using gettext's @@ -51,41 +70,45 @@ def install(domain): overriding the default localedir (e.g. /usr/share/locale) using a translation-domain-specific environment variable (e.g. NOVA_LOCALEDIR). - """ - gettext.install(domain, - localedir=os.environ.get(domain.upper() + '_LOCALEDIR'), - unicode=True) - - -""" -Lazy gettext functionality. - -The following is an attempt to introduce a deferred way -to do translations on messages in OpenStack. We attempt to -override the standard _() function and % (format string) operation -to build Message objects that can later be translated when we have -more information. Also included is an example LogHandler that -translates Messages to an associated locale, effectively allowing -many logs, each with their own locale. -""" - - -def get_lazy_gettext(domain): - """Assemble and return a lazy gettext function for a given domain. - Factory method for a project/module to get a lazy gettext function - for its own translation domain (i.e. nova, glance, cinder, etc.) + :param domain: the translation domain + :param lazy: indicates whether or not to install the lazy _() function. + The lazy _() introduces a way to do deferred translation + of messages by installing a _ that builds Message objects, + instead of strings, which can then be lazily translated into + any available locale. """ - - def _lazy_gettext(msg): - """Create and return a Message object. - - Message encapsulates a string so that we can translate it later when - needed. - """ - return Message(msg, domain) - - return _lazy_gettext + if lazy: + # NOTE(mrodden): Lazy gettext functionality. + # + # The following introduces a deferred way to do translations on + # messages in OpenStack. We override the standard _() function + # and % (format string) operation to build Message objects that can + # later be translated when we have more information. + # + # Also included below is an example LocaleHandler that translates + # Messages to an associated locale, effectively allowing many logs, + # each with their own locale. + + def _lazy_gettext(msg): + """Create and return a Message object. + + Lazy gettext function for a given domain, it is a factory method + for a project/module to get a lazy gettext function for its own + translation domain (i.e. nova, glance, cinder, etc.) + + Message encapsulates a string so that we can translate + it later when needed. + """ + return Message(msg, domain) + + import __builtin__ + __builtin__.__dict__['_'] = _lazy_gettext + else: + localedir = '%s_LOCALEDIR' % domain.upper() + gettext.install(domain, + localedir=os.environ.get(localedir), + unicode=True) class Message(UserString.UserString, object): @@ -130,7 +153,7 @@ class Message(UserString.UserString, object): # look for %(blah) fields in string; # ignore %% and deal with the # case where % is first character on the line - keys = re.findall('(?:[^%]|^)%\((\w*)\)[a-z]', full_msg) + keys = re.findall('(?:[^%]|^)?%\((\w*)\)[a-z]', full_msg) # if we don't find any %(blah) blocks but have a %s if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg): @@ -232,6 +255,45 @@ class Message(UserString.UserString, object): return UserString.UserString.__getattribute__(self, name) +def get_available_languages(domain): + """Lists the available languages for the given translation domain. + + :param domain: the domain to get languages for + """ + if _AVAILABLE_LANGUAGES: + return _AVAILABLE_LANGUAGES + + localedir = '%s_LOCALEDIR' % domain.upper() + find = lambda x: gettext.find(domain, + localedir=os.environ.get(localedir), + languages=[x]) + + # NOTE(mrodden): en_US should always be available (and first in case + # order matters) since our in-line message strings are en_US + _AVAILABLE_LANGUAGES.append('en_US') + # NOTE(luisg): Babel <1.0 used a function called list(), which was + # renamed to locale_identifiers() in >=1.0, the requirements master list + # requires >=0.9.6, uncapped, so defensively work with both. We can remove + # this check when the master list updates to >=1.0, and all projects udpate + list_identifiers = (getattr(localedata, 'list', None) or + getattr(localedata, 'locale_identifiers')) + locale_identifiers = list_identifiers() + for i in locale_identifiers: + if find(i) is not None: + _AVAILABLE_LANGUAGES.append(i) + return _AVAILABLE_LANGUAGES + + +def get_localized_message(message, user_locale): + """Gets a localized version of the given message in the given locale.""" + if (isinstance(message, Message)): + if user_locale: + message.locale = user_locale + return unicode(message) + else: + return message + + class LocaleHandler(logging.Handler): """Handler that can have a locale associated to translate Messages. diff --git a/openstack/common/local.py b/openstack/common/local.py index f1bfc82..e82f17d 100644 --- a/openstack/common/local.py +++ b/openstack/common/local.py @@ -15,16 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -"""Greenthread local storage of variables using weak references""" +"""Local storage of variables using weak references""" +import threading import weakref -from eventlet import corolocal - -class WeakLocal(corolocal.local): +class WeakLocal(threading.local): def __getattribute__(self, attr): - rval = corolocal.local.__getattribute__(self, attr) + rval = super(WeakLocal, self).__getattribute__(attr) if rval: # NOTE(mikal): this bit is confusing. What is stored is a weak # reference, not the value itself. We therefore need to lookup @@ -34,7 +33,7 @@ class WeakLocal(corolocal.local): def __setattr__(self, attr, value): value = weakref.ref(value) - return corolocal.local.__setattr__(self, attr, value) + return super(WeakLocal, self).__setattr__(attr, value) # NOTE(mikal): the name "store" should be deprecated in the future @@ -45,4 +44,4 @@ store = WeakLocal() # "strong" store will hold a reference to the object so that it never falls out # of scope. weak_store = WeakLocal() -strong_store = corolocal.local +strong_store = threading.local() diff --git a/openstack/common/log.py b/openstack/common/log.py index 465886b..c7e729c 100644 --- a/openstack/common/log.py +++ b/openstack/common/log.py @@ -29,8 +29,6 @@ It also allows setting of formatting information through conf. """ -import ConfigParser -import cStringIO import inspect import itertools import logging @@ -41,6 +39,7 @@ import sys import traceback from oslo.config import cfg +from six import moves from openstack.common.gettextutils import _ # noqa from openstack.common import importutils @@ -348,7 +347,7 @@ class LogConfigError(Exception): def _load_log_config(log_config): try: logging.config.fileConfig(log_config) - except ConfigParser.Error as exc: + except moves.configparser.Error as exc: raise LogConfigError(log_config, str(exc)) @@ -521,7 +520,7 @@ class ContextFormatter(logging.Formatter): if not record: return logging.Formatter.formatException(self, exc_info) - stringbuffer = cStringIO.StringIO() + stringbuffer = moves.StringIO() traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], None, stringbuffer) lines = stringbuffer.getvalue().split('\n') diff --git a/openstack/common/middleware/base.py b/openstack/common/middleware/base.py index 7236731..2099549 100644 --- a/openstack/common/middleware/base.py +++ b/openstack/common/middleware/base.py @@ -28,11 +28,7 @@ class Middleware(object): @classmethod def factory(cls, global_conf, **local_conf): """Factory method for paste.deploy.""" - - def filter(app): - return cls(app) - - return filter + return cls def __init__(self, application): self.application = application diff --git a/openstack/common/middleware/sizelimit.py b/openstack/common/middleware/sizelimit.py index ecbdde1..23ba9b6 100644 --- a/openstack/common/middleware/sizelimit.py +++ b/openstack/common/middleware/sizelimit.py @@ -71,9 +71,6 @@ class LimitingReader(object): class RequestBodySizeLimiter(base.Middleware): """Limit the size of incoming requests.""" - def __init__(self, *args, **kwargs): - super(RequestBodySizeLimiter, self).__init__(*args, **kwargs) - @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): if req.content_length > CONF.max_request_body_size: diff --git a/openstack/common/notifier/log_notifier.py b/openstack/common/notifier/log_notifier.py index d3ef0ae..96072ed 100644 --- a/openstack/common/notifier/log_notifier.py +++ b/openstack/common/notifier/log_notifier.py @@ -25,7 +25,7 @@ CONF = cfg.CONF def notify(_context, message): """Notifies the recipient of the desired event given the model. - Log notifications using openstack's default logging system. + Log notifications using OpenStack's default logging system. """ priority = message.get('priority', diff --git a/openstack/common/notifier/rpc_notifier.py b/openstack/common/notifier/rpc_notifier.py index 6bfc333..db47a8a 100644 --- a/openstack/common/notifier/rpc_notifier.py +++ b/openstack/common/notifier/rpc_notifier.py @@ -24,7 +24,7 @@ LOG = logging.getLogger(__name__) notification_topic_opt = cfg.ListOpt( 'notification_topics', default=['notifications', ], - help='AMQP topic used for openstack notifications') + help='AMQP topic used for OpenStack notifications') CONF = cfg.CONF CONF.register_opt(notification_topic_opt) diff --git a/openstack/common/notifier/rpc_notifier2.py b/openstack/common/notifier/rpc_notifier2.py index 55dd780..505a94e 100644 --- a/openstack/common/notifier/rpc_notifier2.py +++ b/openstack/common/notifier/rpc_notifier2.py @@ -26,7 +26,7 @@ LOG = logging.getLogger(__name__) notification_topic_opt = cfg.ListOpt( 'topics', default=['notifications', ], - help='AMQP topic(s) used for openstack notifications') + help='AMQP topic(s) used for OpenStack notifications') opt_group = cfg.OptGroup(name='rpc_notifier2', title='Options for rpc_notifier2') diff --git a/openstack/common/policy.py b/openstack/common/policy.py index 00531e5..ffb8668 100644 --- a/openstack/common/policy.py +++ b/openstack/common/policy.py @@ -115,12 +115,18 @@ class Rules(dict): def __missing__(self, key): """Implements the default rule handling.""" + if isinstance(self.default_rule, dict): + raise KeyError(key) + # If the default rule isn't actually defined, do something # reasonably intelligent if not self.default_rule or self.default_rule not in self: raise KeyError(key) - return self[self.default_rule] + if isinstance(self.default_rule, BaseCheck): + return self.default_rule + elif isinstance(self.default_rule, six.string_types): + return self[self.default_rule] def __str__(self): """Dumps a string representation of the rules.""" @@ -153,7 +159,7 @@ class Enforcer(object): """ def __init__(self, policy_file=None, rules=None, default_rule=None): - self.rules = Rules(rules) + self.rules = Rules(rules, default_rule) self.default_rule = default_rule or CONF.policy_default_rule self.policy_path = None @@ -172,13 +178,14 @@ class Enforcer(object): "got %s instead") % type(rules)) if overwrite: - self.rules = Rules(rules) + self.rules = Rules(rules, self.default_rule) else: - self.update(rules) + self.rules.update(rules) def clear(self): """Clears Enforcer rules, policy's cache and policy's path.""" self.set_rules({}) + self.default_rule = None self.policy_path = None def load_rules(self, force_reload=False): @@ -194,8 +201,7 @@ class Enforcer(object): reloaded, data = fileutils.read_cached_file(self.policy_path, force_reload=force_reload) - - if reloaded: + if reloaded or not self.rules: rules = Rules.load_json(data, self.default_rule) self.set_rules(rules) LOG.debug(_("Rules successfully reloaded")) @@ -215,7 +221,7 @@ class Enforcer(object): if policy_file: return policy_file - raise cfg.ConfigFilesNotFoundError(path=CONF.policy_file) + raise cfg.ConfigFilesNotFoundError((CONF.policy_file,)) def enforce(self, rule, target, creds, do_raise=False, exc=None, *args, **kwargs): @@ -398,7 +404,7 @@ class AndCheck(BaseCheck): """ for rule in self.rules: - if not rule(target, cred): + if not rule(target, cred, enforcer): return False return True @@ -441,7 +447,7 @@ class OrCheck(BaseCheck): """ for rule in self.rules: - if rule(target, cred): + if rule(target, cred, enforcer): return True return False diff --git a/openstack/common/processutils.py b/openstack/common/processutils.py index 13f6222..06fe411 100644 --- a/openstack/common/processutils.py +++ b/openstack/common/processutils.py @@ -19,6 +19,7 @@ System-level utilities and helper functions. """ +import logging as stdlib_logging import os import random import shlex @@ -102,6 +103,9 @@ def execute(*cmd, **kwargs): :param shell: whether or not there should be a shell used to execute this command. Defaults to false. :type shell: boolean + :param loglevel: log level for execute commands. + :type loglevel: int. (Should be stdlib_logging.DEBUG or + stdlib_logging.INFO) :returns: (stdout, stderr) from process execution :raises: :class:`UnknownArgumentError` on receiving unknown arguments @@ -116,6 +120,7 @@ def execute(*cmd, **kwargs): run_as_root = kwargs.pop('run_as_root', False) root_helper = kwargs.pop('root_helper', '') shell = kwargs.pop('shell', False) + loglevel = kwargs.pop('loglevel', stdlib_logging.DEBUG) if isinstance(check_exit_code, bool): ignore_exit_code = not check_exit_code @@ -139,7 +144,7 @@ def execute(*cmd, **kwargs): while attempts > 0: attempts -= 1 try: - LOG.debug(_('Running cmd (subprocess): %s'), ' '.join(cmd)) + LOG.log(loglevel, _('Running cmd (subprocess): %s'), ' '.join(cmd)) _PIPE = subprocess.PIPE # pylint: disable=E1101 if os.name == 'nt': @@ -164,7 +169,7 @@ def execute(*cmd, **kwargs): obj.stdin.close() # pylint: disable=E1101 _returncode = obj.returncode # pylint: disable=E1101 if _returncode: - LOG.debug(_('Result was %s') % _returncode) + LOG.log(loglevel, _('Result was %s') % _returncode) if not ignore_exit_code and _returncode not in check_exit_code: (stdout, stderr) = result raise ProcessExecutionError(exit_code=_returncode, @@ -176,7 +181,7 @@ def execute(*cmd, **kwargs): if not attempts: raise else: - LOG.debug(_('%r failed. Retrying.'), cmd) + LOG.log(loglevel, _('%r failed. Retrying.'), cmd) if delay_on_retry: greenthread.sleep(random.randint(20, 200) / 100.0) finally: diff --git a/openstack/common/py3kcompat/__init__.py b/openstack/common/py3kcompat/__init__.py new file mode 100644 index 0000000..be894cf --- /dev/null +++ b/openstack/common/py3kcompat/__init__.py @@ -0,0 +1,17 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Copyright 2013 Canonical Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# diff --git a/openstack/common/py3kcompat/urlutils.py b/openstack/common/py3kcompat/urlutils.py new file mode 100644 index 0000000..04b3418 --- /dev/null +++ b/openstack/common/py3kcompat/urlutils.py @@ -0,0 +1,49 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Copyright 2013 Canonical Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +""" +Python2/Python3 compatibility layer for OpenStack +""" + +import six + +if six.PY3: + # python3 + import urllib.parse + + urlencode = urllib.parse.urlencode + urljoin = urllib.parse.urljoin + quote = urllib.parse.quote + parse_qsl = urllib.parse.parse_qsl + urlparse = urllib.parse.urlparse + urlsplit = urllib.parse.urlsplit + urlunsplit = urllib.parse.urlunsplit +else: + # python2 + import urllib + import urlparse + + urlencode = urllib.urlencode + quote = urllib.quote + + parse = urlparse + parse_qsl = parse.parse_qsl + urljoin = parse.urljoin + urlparse = parse.urlparse + urlsplit = parse.urlsplit + urlunsplit = parse.urlunsplit diff --git a/openstack/common/quota.py b/openstack/common/quota.py new file mode 100644 index 0000000..43f8b01 --- /dev/null +++ b/openstack/common/quota.py @@ -0,0 +1,1175 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Common quotas""" + +import datetime + +from oslo.config import cfg + +from openstack.common.gettextutils import _ # noqa +from openstack.common import importutils +from openstack.common import log as logging +from openstack.common import timeutils + +LOG = logging.getLogger(__name__) + +common_quota_opts = [ + cfg.BoolOpt('use_default_quota_class', + default=True, + help='whether to use default quota class for default quota'), + cfg.StrOpt('quota_driver', + default='openstack.common.quota.DbQuotaDriver', + help='default driver to use for quota checks'), + cfg.IntOpt('until_refresh', + default=0, + help='count of reservations until usage is refreshed'), + cfg.IntOpt('max_age', + default=0, + help='number of seconds between subsequent usage refreshes'), + cfg.IntOpt('reservation_expire', + default=86400, + help='number of seconds until a reservation expires'), +] + +CONF = cfg.CONF +CONF.register_opts(common_quota_opts) + + +class QuotaException(Exception): + """Base exception for quota. + + To correctly use this class, inherit from it and define + a 'msg_fmt' property. That msg_fmt will get printf'd + with the keyword arguments provided to the constructor. + + """ + msg_fmt = _("Quota exception occurred.") + code = 500 + headers = {} + safe = False + + def __init__(self, message=None, **kwargs): + self.kwargs = {'code': self.code} + self.kwargs.update(kwargs) + if not message: + try: + message = self.msg_fmt % self.kwargs + except Exception: + # kwargs doesn't match a variable in the message + # log the issue and the kwargs + LOG.exception(_('Exception in string format operation')) + for name, value in kwargs.iteritems(): + LOG.error("%s: %s" % (name, value)) + # at least get the core message out if something happened + message = self.msg_fmt + super(QuotaException, self).__init__(message) + + def format_message(self): + return unicode(self) + + +class QuotaError(QuotaException): + msg_fmt = _("Quota exceeded") + ": code=%(code)s" + code = 413 + headers = {'Retry-After': 0} + safe = True + + +class InvalidQuotaValue(QuotaException): + msg_fmt = _("Change would make usage less than 0 for the following " + "resources: %(unders)s") + + +class OverQuota(QuotaException): + msg_fmt = _("Quota exceeded for resources: %(overs)s") + + +class QuotaExists(QuotaException): + message = _("Quota exists") + + +class QuotaResourceUnknown(QuotaException): + msg_fmt = _("Unknown quota resources %(unknown)s.") + + +class QuotaNotFound(QuotaException): + code = 404 + message = _("Quota could not be found") + + +class QuotaUsageNotFound(QuotaNotFound): + msg_fmt = _("Quota usage for project %(project_id)s could not be found.") + + +class ProjectQuotaNotFound(QuotaNotFound): + msg_fmt = _("Quota for project %(project_id)s could not be found.") + + +class ProjectUserQuotaNotFound(QuotaNotFound): + msg_fmt = _("Quota for user %(user_id)s in project %(project_id)s " + "could not be found.") + + +class QuotaClassNotFound(QuotaNotFound): + msg_fmt = _("Quota class %(class_name)s could not be found.") + + +class ReservationNotFound(QuotaNotFound): + msg_fmt = _("Quota reservation %(uuid)s could not be found.") + + +class InvalidReservationExpiration(QuotaException): + code = 400 + msg_fmt = _("Invalid reservation expiration %(expire)s.") + + +class DbQuotaDriver(object): + """Database quota driver. + + Driver to perform necessary checks to enforce quotas and obtain + quota information. The default driver utilizes the local + database. + """ + + def __init__(self, db): + self.db = db + + def get_by_project_and_user(self, context, project_id, user_id, resource): + """Get a specific quota by project and user.""" + + return self.db.quota_get(context, project_id, user_id, resource) + + def get_by_project(self, context, project_id, resource_name): + """Get a specific quota by project.""" + + return self.db.quota_get(context, project_id, resource_name) + + def get_by_class(self, context, quota_class, resource_name): + """Get a specific quota by quota class.""" + + return self.db.quota_class_get(context, quota_class, resource_name) + + def get_default(self, context, resource): + """Get a specific default quota for a resource.""" + + default_quotas = self.db.quota_class_get_default(context) + return default_quotas.get(resource.name, resource.default) + + def get_defaults(self, context, resources): + """Given a list of resources, retrieve the default quotas. + + Use the class quotas named `_DEFAULT_QUOTA_NAME` as default quotas, + if it exists. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + """ + + quotas = {} + default_quotas = {} + if CONF.use_default_quota_class: + default_quotas = self.db.quota_class_get_default(context) + for resource in resources.values(): + if resource.name not in default_quotas: + LOG.deprecated(_("Default quota for resource: %(res)s is set " + "by the default quota flag: quota_%(res)s, " + "it is now deprecated. Please use the " + "the default quota class for default " + "quota.") % {'res': resource.name}) + quotas[resource.name] = default_quotas.get(resource.name, + resource.default) + + return quotas + + def get_class_quotas(self, context, resources, quota_class, + defaults=True): + """Given a list of resources, get quotas for the given quota class. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param quota_class: The name of the quota class to return + quotas for. + :param defaults: If True, the default value will be reported + if there is no specific value for the + resource. + """ + + quotas = {} + default_quotas = {} + class_quotas = self.db.quota_class_get_all_by_name(context, + quota_class) + if defaults: + default_quotas = self.db.quota_class_get_default(context) + for resource in resources.values(): + if resource.name in class_quotas: + quotas[resource.name] = class_quotas[resource.name] + continue + + if defaults: + quotas[resource.name] = default_quotas.get(resource.name, + resource.default) + + return quotas + + def _process_quotas(self, context, resources, project_id, quotas, + quota_class=None, defaults=True, usages=None, + remains=False): + """Get the quotas for the appropriate class. + + If the project ID matches the one in the context, we use the + quota_class from the context, otherwise, we use the provided + quota_class (if any) + """ + + modified_quotas = {} + if project_id == context.project_id: + quota_class = context.quota_class + if quota_class: + class_quotas = self.db.quota_class_get_all_by_name(context, + quota_class) + else: + class_quotas = {} + + default_quotas = self.get_defaults(context, resources) + + for resource in resources.values(): + # Omit default/quota class values + if not defaults and resource.name not in quotas: + continue + class_quota = class_quotas.get(resource.name, + default_quotas[resource.name]) + limit = quotas.get(resource.name, class_quota) + modified_quotas[resource.name] = dict(limit=limit) + + # Include usages if desired. This is optional because one + # internal consumer of this interface wants to access the + # usages directly from inside a transaction. + if usages: + usage = usages.get(resource.name, {}) + modified_quotas[resource.name].update( + in_use=usage.get('in_use', 0), + reserved=usage.get('reserved', 0), + ) + # Initialize remains quotas. + if remains: + modified_quotas[resource.name].update(remains=limit) + + if remains: + all_quotas = self.db.quota_get_all(context, project_id) + for quota in all_quotas: + if quota['resource'] in modified_quotas: + modified_quotas[quota['resource']]['remains'] -= \ + quota['hard_limit'] + + return modified_quotas + + def get_user_quotas(self, context, resources, project_id, user_id, + quota_class=None, defaults=True, + usages=True): + """Get user quotas for given user and project. + + Given a list of resources, retrieve the quotas for the given + user and project. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param project_id: The ID of the project to return quotas for. + :param user_id: The ID of the user to return quotas for. + :param quota_class: If project_id != context.project_id, the + quota class cannot be determined. This + parameter allows it to be specified. It + will be ignored if project_id == + context.project_id. + :param defaults: If True, the quota class value (or the + default value, if there is no value from the + quota class) will be reported if there is no + specific value for the resource. + :param usages: If True, the current in_use and reserved counts + will also be returned. + """ + user_quotas = self.db.quota_get_all_by_project_and_user( + context, project_id, user_id) + user_usages = None + if usages: + user_usages = self.db.quota_usage_get_all_by_project_and_user( + context, project_id, user_id) + return self._process_quotas(context, resources, project_id, + user_quotas, quota_class, + defaults=defaults, usages=user_usages) + + def get_project_quotas(self, context, resources, project_id, + quota_class=None, defaults=True, + usages=True, remains=False): + """Given a list of resources, get the quotas for the given project. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param project_id: The ID of the project to return quotas for. + :param quota_class: If project_id != context.project_id, the + quota class cannot be determined. This + parameter allows it to be specified. It + will be ignored if project_id == + context.project_id. + :param defaults: If True, the quota class value (or the + default value, if there is no value from the + quota class) will be reported if there is no + specific value for the resource. + :param usages: If True, the current in_use and reserved counts + will also be returned. + :param remains: If True, the current remains of the project will + will be returned. + """ + project_quotas = self.db.quota_get_all_by_project( + context, project_id) + project_usages = None + if usages: + project_usages = self.db.quota_usage_get_all_by_project( + context, project_id) + return self._process_quotas(context, resources, project_id, + project_quotas, quota_class, + defaults=defaults, usages=project_usages, + remains=remains) + + def get_settable_quotas(self, context, resources, project_id, + user_id=None): + """Get settable quotas for given user and project. + + Given a list of resources, retrieve the range of settable quotas for + the given user or project. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param project_id: The ID of the project to return quotas for. + :param user_id: The ID of the user to return quotas for. + """ + settable_quotas = {} + project_quotas = self.get_project_quotas(context, resources, + project_id, remains=True) + if user_id: + user_quotas = self.get_user_quotas(context, resources, + project_id, user_id) + setted_quotas = self.db.quota_get_all_by_project_and_user( + context, project_id, user_id) + for key, value in user_quotas.items(): + maximum = project_quotas[key]['remains'] +\ + setted_quotas.get(key, 0) + settable_quotas[key] = dict( + minimum=value['in_use'] + value['reserved'], + maximum=maximum, + ) + else: + for key, value in project_quotas.items(): + minimum = max(int(value['limit'] - value['remains']), + int(value['in_use'] + value['reserved'])) + settable_quotas[key] = dict(minimum=minimum, maximum=-1) + return settable_quotas + + def _get_quotas(self, context, resources, keys, has_sync, project_id=None, + user_id=None): + """Get guotas for resources identified by keys. + + A helper method which retrieves the quotas for the specific + resources identified by keys, and which apply to the current + context. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param keys: A list of the desired quotas to retrieve. + :param has_sync: If True, indicates that the resource must + have a sync attribute; if False, indicates + that the resource must NOT have a sync + attribute. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + + # Filter resources + if has_sync: + sync_filt = lambda x: hasattr(x, 'sync') + else: + sync_filt = lambda x: not hasattr(x, 'sync') + desired = set(keys) + sub_resources = dict((k, v) for k, v in resources.items() + if k in desired and sync_filt(v)) + + # Make sure we accounted for all of them... + if len(keys) != len(sub_resources): + unknown = desired - set(sub_resources.keys()) + raise QuotaResourceUnknown(unknown=sorted(unknown)) + + if user_id: + # Grab and return the quotas (without usages) + quotas = self.get_user_quotas(context, sub_resources, + project_id, user_id, + context.quota_class, usages=False) + else: + # Grab and return the quotas (without usages) + quotas = self.get_project_quotas(context, sub_resources, + project_id, + context.quota_class, + usages=False) + + return dict((k, v['limit']) for k, v in quotas.items()) + + def limit_check(self, context, resources, values, project_id=None, + user_id=None): + """Check simple quota limits. + + For limits--those quotas for which there is no usage + synchronization function--this method checks that a set of + proposed values are permitted by the limit restriction. + + This method will raise a QuotaResourceUnknown exception if a + given resource is unknown or if it is not a simple limit + resource. + + If any of the proposed values is over the defined quota, an + OverQuota exception will be raised with the sorted list of the + resources which are too high. Otherwise, the method returns + nothing. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param values: A dictionary of the values to check against the + quota. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + + # Ensure no value is less than zero + unders = [key for key, val in values.items() if val < 0] + if unders: + raise InvalidQuotaValue(unders=sorted(unders)) + + # If project_id is None, then we use the project_id in context + if project_id is None: + project_id = context.project_id + # If user id is None, then we use the user_id in context + if user_id is None: + user_id = context.user_id + + # Get the applicable quotas + quotas = self._get_quotas(context, resources, values.keys(), + has_sync=False, project_id=project_id) + user_quotas = self._get_quotas(context, resources, values.keys(), + has_sync=False, project_id=project_id, + user_id=user_id) + + # Check the quotas and construct a list of the resources that + # would be put over limit by the desired values + overs = [key for key, val in values.items() + if (quotas[key] >= 0 and quotas[key] < val) or + (user_quotas[key] >= 0 and user_quotas[key] < val)] + if overs: + raise OverQuota(overs=sorted(overs), quotas=quotas, + usages={}) + + def reserve(self, context, resources, deltas, expire=None, + project_id=None, user_id=None): + """Check quotas and reserve resources. + + For counting quotas--those quotas for which there is a usage + synchronization function--this method checks quotas against + current usage and the desired deltas. + + This method will raise a QuotaResourceUnknown exception if a + given resource is unknown or if it does not have a usage + synchronization function. + + If any of the proposed values is over the defined quota, an + OverQuota exception will be raised with the sorted list of the + resources which are too high. Otherwise, the method returns a + list of reservation UUIDs which were created. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param deltas: A dictionary of the proposed delta changes. + :param expire: An optional parameter specifying an expiration + time for the reservations. If it is a simple + number, it is interpreted as a number of + seconds and added to the current time; if it is + a datetime.timedelta object, it will also be + added to the current time. A datetime.datetime + object will be interpreted as the absolute + expiration time. If None is specified, the + default expiration time set by + --default-reservation-expire will be used (this + value will be treated as a number of seconds). + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + + # Set up the reservation expiration + if expire is None: + expire = CONF.reservation_expire + if isinstance(expire, (int, long)): + expire = datetime.timedelta(seconds=expire) + if isinstance(expire, datetime.timedelta): + expire = timeutils.utcnow() + expire + if not isinstance(expire, datetime.datetime): + raise InvalidReservationExpiration(expire=expire) + + # If project_id is None, then we use the project_id in context + if project_id is None: + project_id = context.project_id + # If user_id is None, then we use the project_id in context + if user_id is None: + user_id = context.user_id + + # Get the applicable quotas. + # NOTE(Vek): We're not worried about races at this point. + # Yes, the admin may be in the process of reducing + # quotas, but that's a pretty rare thing. + quotas = self._get_quotas(context, resources, deltas.keys(), + has_sync=True, project_id=project_id) + user_quotas = self._get_quotas(context, resources, deltas.keys(), + has_sync=True, project_id=project_id, + user_id=user_id) + + # NOTE(Vek): Most of the work here has to be done in the DB + # API, because we have to do it in a transaction, + # which means access to the session. Since the + # session isn't available outside the DBAPI, we + # have to do the work there. + return self.db.quota_reserve(context, resources, quotas, user_quotas, + deltas, expire, + CONF.until_refresh, CONF.max_age, + project_id=project_id, user_id=user_id) + + def commit(self, context, reservations, project_id=None, user_id=None): + """Commit reservations. + + :param context: The request context, for access checks. + :param reservations: A list of the reservation UUIDs, as + returned by the reserve() method. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + # If project_id is None, then we use the project_id in context + if project_id is None: + project_id = context.project_id + # If user_id is None, then we use the user_id in context + if user_id is None: + user_id = context.user_id + + self.db.reservation_commit(context, reservations, + project_id=project_id, user_id=user_id) + + def rollback(self, context, reservations, project_id=None, user_id=None): + """Roll back reservations. + + :param context: The request context, for access checks. + :param reservations: A list of the reservation UUIDs, as + returned by the reserve() method. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + # If project_id is None, then we use the project_id in context + if project_id is None: + project_id = context.project_id + # If user_id is None, then we use the user_id in context + if user_id is None: + user_id = context.user_id + + self.db.reservation_rollback(context, reservations, + project_id=project_id, user_id=user_id) + + def usage_reset(self, context, resources): + """Reset the usage records. + + Reset usages for a particular user on a list of resources. + This will force that user's usage records to be refreshed + the next time a reservation is made. + + Note: this does not affect the currently outstanding + reservations the user has; those reservations must be + committed or rolled back (or expired). + + :param context: The request context, for access checks. + :param resources: A list of the resource names for which the + usage must be reset. + """ + + # We need an elevated context for the calls to + # quota_usage_update() + elevated = context.elevated() + + for resource in resources: + try: + # Reset the usage to -1, which will force it to be + # refreshed + self.db.quota_usage_update(elevated, context.project_id, + context.user_id, + resource, in_use=-1) + except QuotaUsageNotFound: + # That means it'll be refreshed anyway + pass + + def destroy_all_by_project_and_user(self, context, project_id, user_id): + """Destroy objects by project and user. + + Destroy all quotas, usages, and reservations associated with a + project and user. + + :param context: The request context, for access checks. + :param project_id: The ID of the project being deleted. + :param user_id: The ID of the user being deleted. + """ + + self.db.quota_destroy_all_by_project_and_user(context, project_id, + user_id) + + def destroy_all_by_project(self, context, project_id): + """Destroy quotas, usages, and reservations for given project. + + :param context: The request context, for access checks. + :param project_id: The ID of the project being deleted. + """ + + self.db.quota_destroy_all_by_project(context, project_id) + + def expire(self, context): + """Expire reservations. + + Explores all currently existing reservations and rolls back + any that have expired. + + :param context: The request context, for access checks. + """ + + self.db.reservation_expire(context) + + +class BaseResource(object): + """Describe a single resource for quota checking.""" + + def __init__(self, name, flag=None): + """Initializes a Resource. + + :param name: The name of the resource, i.e., "volumes". + :param flag: The name of the flag or configuration option + which specifies the default value of the quota + for this resource. + """ + + self.name = name + self.flag = flag + + def quota(self, driver, context, **kwargs): + """Given a driver and context, obtain the quota for this resource. + + :param driver: A quota driver. + :param context: The request context. + :param project_id: The project to obtain the quota value for. + If not provided, it is taken from the + context. If it is given as None, no + project-specific quota will be searched + for. + :param quota_class: The quota class corresponding to the + project, or for which the quota is to be + looked up. If not provided, it is taken + from the context. If it is given as None, + no quota class-specific quota will be + searched for. Note that the quota class + defaults to the value in the context, + which may not correspond to the project if + project_id is not the same as the one in + the context. + """ + + # Get the project ID + project_id = kwargs.get('project_id', context.project_id) + + # Ditto for the quota class + quota_class = kwargs.get('quota_class', context.quota_class) + + # Look up the quota for the project + if project_id: + try: + return driver.get_by_project(context, project_id, self.name) + except ProjectQuotaNotFound: + pass + + # Try for the quota class + if quota_class: + try: + return driver.get_by_class(context, quota_class, self.name) + except QuotaClassNotFound: + pass + + # OK, return the default + return driver.get_default(context, self) + + @property + def default(self): + """Return the default value of the quota.""" + + return CONF[self.flag] if self.flag else -1 + + +class ReservableResource(BaseResource): + """Describe a reservable resource.""" + + def __init__(self, name, sync, flag=None): + """Initializes a ReservableResource. + + Reservable resources are those resources which directly + correspond to objects in the database, i.e., instances, + cores, etc. + + Usage synchronization function must be associated with each + object. This function will be called to determine the current + counts of one or more resources. This association is done in + database backend. See QUOTA_SYNC_FUNCTIONS in db/sqlalchemy/api.py. + + The usage synchronization function will be passed three + arguments: an admin context, the project ID, and an opaque + session object, which should in turn be passed to the + underlying database function. Synchronization functions + should return a dictionary mapping resource names to the + current in_use count for those resources; more than one + resource and resource count may be returned. Note that + synchronization functions may be associated with more than one + ReservableResource. + + :param name: The name of the resource, i.e., "volumes". + :param sync: A dbapi methods name which returns a dictionary + to resynchronize the in_use count for one or more + resources, as described above. + :param flag: The name of the flag or configuration option + which specifies the default value of the quota + for this resource. + """ + + super(ReservableResource, self).__init__(name, flag=flag) + self.sync = sync + + +class AbsoluteResource(BaseResource): + """Describe a non-reservable resource.""" + + pass + + +class CountableResource(AbsoluteResource): + """Countable resource. + + Describe a resource where the counts aren't based solely on the + project ID. + """ + + def __init__(self, name, count, flag=None): + """Initializes a CountableResource. + + Countable resources are those resources which directly + correspond to objects in the database, i.e., volumes, gigabytes, + etc., but for which a count by project ID is inappropriate. A + CountableResource must be constructed with a counting + function, which will be called to determine the current counts + of the resource. + + The counting function will be passed the context, along with + the extra positional and keyword arguments that are passed to + Quota.count(). It should return an integer specifying the + count. + + Note that this counting is not performed in a transaction-safe + manner. This resource class is a temporary measure to provide + required functionality, until a better approach to solving + this problem can be evolved. + + :param name: The name of the resource, i.e., "volumes". + :param count: A callable which returns the count of the + resource. The arguments passed are as described + above. + :param flag: The name of the flag or configuration option + which specifies the default value of the quota + for this resource. + """ + + super(CountableResource, self).__init__(name, flag=flag) + self.count = count + + +class QuotaEngine(object): + """Represent the set of recognized quotas.""" + + def __init__(self, db, quota_driver_class=None): + """Initialize a Quota object.""" + self.db = db + self._resources = {} + self._driver_cls = quota_driver_class + self.__driver = None + + @property + def _driver(self): + if self.__driver: + return self.__driver + if not self._driver_cls: + self._driver_cls = CONF.quota_driver + if isinstance(self._driver_cls, basestring): + self._driver_cls = importutils.import_object(self._driver_cls, + self.db) + self.__driver = self._driver_cls + return self.__driver + + def __contains__(self, resource): + return resource in self._resources + + def register_resource(self, resource): + """Register a resource.""" + + self._resources[resource.name] = resource + + def register_resources(self, resources): + """Register a list of resources.""" + + for resource in resources: + self.register_resource(resource) + + def get_by_project_and_user(self, context, project_id, user_id, resource): + """Get a specific quota by project and user.""" + + return self._driver.get_by_project_and_user(context, project_id, + user_id, resource) + + def get_by_project(self, context, project_id, resource_name): + """Get a specific quota by project.""" + + return self._driver.get_by_project(context, project_id, resource_name) + + def get_by_class(self, context, quota_class, resource_name): + """Get a specific quota by quota class.""" + + return self._driver.get_by_class(context, quota_class, resource_name) + + def get_default(self, context, resource): + """Get a specific default quota for a resource.""" + + return self._driver.get_default(context, resource) + + def get_defaults(self, context): + """Retrieve the default quotas. + + :param context: The request context, for access checks. + """ + + return self._driver.get_defaults(context, self.resources) + + def get_class_quotas(self, context, quota_class, defaults=True): + """Retrieve the quotas for the given quota class. + + :param context: The request context, for access checks. + :param quota_class: The name of the quota class to return + quotas for. + :param defaults: If True, the default value will be reported + if there is no specific value for the + resource. + """ + + return self._driver.get_class_quotas(context, self.resources, + quota_class, defaults=defaults) + + def get_user_quotas(self, context, project_id, user_id, quota_class=None, + defaults=True, usages=True): + """Retrieve the quotas for the given user and project. + + :param context: The request context, for access checks. + :param project_id: The ID of the project to return quotas for. + :param user_id: The ID of the user to return quotas for. + :param quota_class: If project_id != context.project_id, the + quota class cannot be determined. This + parameter allows it to be specified. + :param defaults: If True, the quota class value (or the + default value, if there is no value from the + quota class) will be reported if there is no + specific value for the resource. + :param usages: If True, the current in_use and reserved counts + will also be returned. + """ + + return self._driver.get_user_quotas(context, self._resources, + project_id, user_id, + quota_class=quota_class, + defaults=defaults, + usages=usages) + + def get_project_quotas(self, context, project_id, quota_class=None, + defaults=True, usages=True, remains=False): + """Retrieve the quotas for the given project. + + :param context: The request context, for access checks. + :param project_id: The ID of the project to return quotas for. + :param quota_class: If project_id != context.project_id, the + quota class cannot be determined. This + parameter allows it to be specified. + :param defaults: If True, the quota class value (or the + default value, if there is no value from the + quota class) will be reported if there is no + specific value for the resource. + :param usages: If True, the current in_use and reserved counts + will also be returned. + :param remains: If True, the current remains of the project will + will be returned. + """ + + return self._driver.get_project_quotas(context, self._resources, + project_id, + quota_class=quota_class, + defaults=defaults, + usages=usages, + remains=remains) + + def get_settable_quotas(self, context, project_id, user_id=None): + """Get settable quotas for given user and project. + + Given a list of resources, retrieve the range of settable quotas for + the given user or project. + + :param context: The request context, for access checks. + :param resources: A dictionary of the registered resources. + :param project_id: The ID of the project to return quotas for. + :param user_id: The ID of the user to return quotas for. + """ + + return self._driver.get_settable_quotas(context, self._resources, + project_id, + user_id=user_id) + + def count(self, context, resource, *args, **kwargs): + """Count a resource. + + For countable resources, invokes the count() function and + returns its result. Arguments following the context and + resource are passed directly to the count function declared by + the resource. + + :param context: The request context, for access checks. + :param resource: The name of the resource, as a string. + """ + + # Get the resource + res = self.resources.get(resource) + if not res or not hasattr(res, 'count'): + raise QuotaResourceUnknown(unknown=[resource]) + + return res.count(context, *args, **kwargs) + + def limit_check(self, context, project_id=None, user_id=None, **values): + """Check simple quota limits. + + For limits--those quotas for which there is no usage + synchronization function--this method checks that a set of + proposed values are permitted by the limit restriction. The + values to check are given as keyword arguments, where the key + identifies the specific quota limit to check, and the value is + the proposed value. + + This method will raise a QuotaResourceUnknown exception if a + given resource is unknown or if it is not a simple limit + resource. + + If any of the proposed values is over the defined quota, an + OverQuota exception will be raised with the sorted list of the + resources which are too high. Otherwise, the method returns + nothing. + + :param context: The request context, for access checks. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + :param user_id: Specify the user_id if current context + is admin and admin wants to impact on + common user. + """ + + return self._driver.limit_check(context, self._resources, values, + project_id=project_id, user_id=user_id) + + def reserve(self, context, expire=None, project_id=None, user_id=None, + **deltas): + """Check quotas and reserve resources. + + For counting quotas--those quotas for which there is a usage + synchronization function--this method checks quotas against + current usage and the desired deltas. The deltas are given as + keyword arguments, and current usage and other reservations + are factored into the quota check. + + This method will raise a QuotaResourceUnknown exception if a + given resource is unknown or if it does not have a usage + synchronization function. + + If any of the proposed values is over the defined quota, an + OverQuota exception will be raised with the sorted list of the + resources which are too high. Otherwise, the method returns a + list of reservation UUIDs which were created. + + :param context: The request context, for access checks. + :param expire: An optional parameter specifying an expiration + time for the reservations. If it is a simple + number, it is interpreted as a number of + seconds and added to the current time; if it is + a datetime.timedelta object, it will also be + added to the current time. A datetime.datetime + object will be interpreted as the absolute + expiration time. If None is specified, the + default expiration time set by + --default-reservation-expire will be used (this + value will be treated as a number of seconds). + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + """ + + reservations = self._driver.reserve(context, self._resources, deltas, + expire=expire, + project_id=project_id, + user_id=user_id) + + LOG.debug(_("Created reservations %s"), reservations) + + return reservations + + def commit(self, context, reservations, project_id=None, user_id=None): + """Commit reservations. + + :param context: The request context, for access checks. + :param reservations: A list of the reservation UUIDs, as + returned by the reserve() method. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + """ + + try: + self._driver.commit(context, reservations, project_id=project_id, + user_id=user_id) + except Exception: + # NOTE(Vek): Ignoring exceptions here is safe, because the + # usage resynchronization and the reservation expiration + # mechanisms will resolve the issue. The exception is + # logged, however, because this is less than optimal. + LOG.exception(_("Failed to commit reservations %s"), reservations) + return + LOG.debug(_("Committed reservations %s"), reservations) + + def rollback(self, context, reservations, project_id=None, user_id=None): + """Roll back reservations. + + :param context: The request context, for access checks. + :param reservations: A list of the reservation UUIDs, as + returned by the reserve() method. + :param project_id: Specify the project_id if current context + is admin and admin wants to impact on + common user's tenant. + """ + + try: + self._driver.rollback(context, reservations, project_id=project_id, + user_id=user_id) + except Exception: + # NOTE(Vek): Ignoring exceptions here is safe, because the + # usage resynchronization and the reservation expiration + # mechanisms will resolve the issue. The exception is + # logged, however, because this is less than optimal. + LOG.exception(_("Failed to roll back reservations %s"), + reservations) + return + LOG.debug(_("Rolled back reservations %s"), reservations) + + def usage_reset(self, context, resources): + """Reset the usage records. + + Reset usages for a particular user on a list of resources. + This will force that user's usage records to be refreshed + the next time a reservation is made. + + Note: this does not affect the currently outstanding + reservations the user has; those reservations must be + committed or rolled back (or expired). + + :param context: The request context, for access checks. + :param resources: A list of the resource names for which the + usage must be reset. + """ + + self._driver.usage_reset(context, resources) + + def destroy_all_by_project_and_user(self, context, project_id, user_id): + """Destroy all objects for given user and project. + + Destroy all quotas, usages, and reservations associated with a + project and user. + + :param context: The request context, for access checks. + :param project_id: The ID of the project being deleted. + :param user_id: The ID of the user being deleted. + """ + + self._driver.destroy_all_by_project_and_user(context, + project_id, user_id) + + def destroy_all_by_project(self, context, project_id): + """Destroy all quotas, usages, and reservations for given project. + + :param context: The request context, for access checks. + :param project_id: The ID of the project being deleted. + """ + + self._driver.destroy_all_by_project(context, project_id) + + def expire(self, context): + """Expire reservations. + + Explores all currently existing reservations and rolls back + any that have expired. + + :param context: The request context, for access checks. + """ + + self._driver.expire(context) + + @property + def resource_names(self): + return sorted(self.resources.keys()) + + @property + def resources(self): + return self._resources diff --git a/openstack/common/rootwrap/cmd.py b/openstack/common/rootwrap/cmd.py index 500f6c9..473dafb 100755..100644 --- a/openstack/common/rootwrap/cmd.py +++ b/openstack/common/rootwrap/cmd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright (c) 2011 OpenStack Foundation. diff --git a/openstack/common/rpc/__init__.py b/openstack/common/rpc/__init__.py index e39f294..104b059 100644 --- a/openstack/common/rpc/__init__.py +++ b/openstack/common/rpc/__init__.py @@ -56,8 +56,7 @@ rpc_opts = [ help='Seconds to wait before a cast expires (TTL). ' 'Only supported by impl_zmq.'), cfg.ListOpt('allowed_rpc_exception_modules', - default=['openstack.common.exception', - 'nova.exception', + default=['nova.exception', 'cinder.exception', 'exceptions', ], diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py index 1afd2ab..38f2515 100644 --- a/openstack/common/rpc/amqp.py +++ b/openstack/common/rpc/amqp.py @@ -300,8 +300,13 @@ def pack_context(msg, context): for args at some point. """ - context_d = dict([('_context_%s' % key, value) - for (key, value) in context.to_dict().iteritems()]) + if isinstance(context, dict): + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.iteritems()]) + else: + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py index 6b1ae93..61ab415 100644 --- a/openstack/common/rpc/impl_kombu.py +++ b/openstack/common/rpc/impl_kombu.py @@ -490,12 +490,8 @@ class Connection(object): # future with this? ssl_params['cert_reqs'] = ssl.CERT_REQUIRED - if not ssl_params: - # Just have the default behavior - return True - else: - # Return the extended behavior - return ssl_params + # Return the extended behavior or just have the default behavior + return ssl_params or True def _connect(self, params): """Connect to rabbit. Re-establish any queues that may have diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py index 99c4619..e54beb4 100644 --- a/openstack/common/rpc/impl_qpid.py +++ b/openstack/common/rpc/impl_qpid.py @@ -320,7 +320,7 @@ class DirectPublisher(Publisher): def __init__(self, conf, session, msg_id): """Init a 'direct' publisher.""" super(DirectPublisher, self).__init__(session, msg_id, - {"type": "Direct"}) + {"type": "direct"}) class TopicPublisher(Publisher): diff --git a/openstack/common/rpc/matchmaker.py b/openstack/common/rpc/matchmaker.py index e80ab37..a94f542 100644 --- a/openstack/common/rpc/matchmaker.py +++ b/openstack/common/rpc/matchmaker.py @@ -248,9 +248,7 @@ class DirectBinding(Binding): that it maps directly to a host, thus direct. """ def test(self, key): - if '.' in key: - return True - return False + return '.' in key class TopicBinding(Binding): @@ -262,17 +260,13 @@ class TopicBinding(Binding): matches that of a direct exchange. """ def test(self, key): - if '.' not in key: - return True - return False + return '.' not in key class FanoutBinding(Binding): """Match on fanout keys, where key starts with 'fanout.' string.""" def test(self, key): - if key.startswith('fanout~'): - return True - return False + return key.startswith('fanout~') class StubExchange(Exchange): diff --git a/openstack/common/rpc/matchmaker_ring.py b/openstack/common/rpc/matchmaker_ring.py index 45a095f..6b488ce 100644 --- a/openstack/common/rpc/matchmaker_ring.py +++ b/openstack/common/rpc/matchmaker_ring.py @@ -63,9 +63,7 @@ class RingExchange(mm.Exchange): self.ring0[k] = itertools.cycle(self.ring[k]) def _ring_has(self, key): - if key in self.ring0: - return True - return False + return key in self.ring0 class RoundRobinRingExchange(RingExchange): diff --git a/openstack/common/rpc/securemessage.py b/openstack/common/rpc/securemessage.py new file mode 100644 index 0000000..c5530a6 --- /dev/null +++ b/openstack/common/rpc/securemessage.py @@ -0,0 +1,521 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import base64 +import collections +import os +import struct +import time + +import requests + +from oslo.config import cfg + +from openstack.common.crypto import utils as cryptoutils +from openstack.common import jsonutils +from openstack.common import log as logging + +secure_message_opts = [ + cfg.BoolOpt('enabled', default=True, + help='Whether Secure Messaging (Signing) is enabled,' + ' defaults to enabled'), + cfg.BoolOpt('enforced', default=False, + help='Whether Secure Messaging (Signing) is enforced,' + ' defaults to not enforced'), + cfg.BoolOpt('encrypt', default=False, + help='Whether Secure Messaging (Encryption) is enabled,' + ' defaults to not enabled'), + cfg.StrOpt('secret_keys_file', + help='Path to the file containing the keys, takes precedence' + ' over secret_key'), + cfg.MultiStrOpt('secret_key', + help='A list of keys: (ex: name:<base64 encoded key>),' + ' ignored if secret_keys_file is set'), + cfg.StrOpt('kds_endpoint', + help='KDS endpoint (ex: http://kds.example.com:35357/v3)'), +] +secure_message_group = cfg.OptGroup('secure_messages', + title='Secure Messaging options') + +LOG = logging.getLogger(__name__) + + +class SecureMessageException(Exception): + """Generic Exception for Secure Messages.""" + + msg = "An unknown Secure Message related exception occurred." + + def __init__(self, msg=None): + if msg is None: + msg = self.msg + super(SecureMessageException, self).__init__(msg) + + +class SharedKeyNotFound(SecureMessageException): + """No shared key was found and no other external authentication mechanism + is available. + """ + + msg = "Shared Key for [%s] Not Found. (%s)" + + def __init__(self, name, errmsg): + super(SharedKeyNotFound, self).__init__(self.msg % (name, errmsg)) + + +class InvalidMetadata(SecureMessageException): + """The metadata is invalid.""" + + msg = "Invalid metadata: %s" + + def __init__(self, err): + super(InvalidMetadata, self).__init__(self.msg % err) + + +class InvalidSignature(SecureMessageException): + """Signature validation failed.""" + + msg = "Failed to validate signature (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidSignature, self).__init__(self.msg % (src, dst)) + + +class UnknownDestinationName(SecureMessageException): + """The Destination name is unknown to us.""" + + msg = "Invalid destination name (%s)" + + def __init__(self, name): + super(UnknownDestinationName, self).__init__(self.msg % name) + + +class InvalidEncryptedTicket(SecureMessageException): + """The Encrypted Ticket could not be successfully handled.""" + + msg = "Invalid Ticket (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidEncryptedTicket, self).__init__(self.msg % (src, dst)) + + +class InvalidExpiredTicket(SecureMessageException): + """The ticket received is already expired.""" + + msg = "Expired ticket (source=%s, destination=%s)" + + def __init__(self, src, dst): + super(InvalidExpiredTicket, self).__init__(self.msg % (src, dst)) + + +class CommunicationError(SecureMessageException): + """The Communication with the KDS failed.""" + + msg = "Communication Error (target=%s): %s" + + def __init__(self, target, errmsg): + super(CommunicationError, self).__init__(self.msg % (target, errmsg)) + + +class InvalidArgument(SecureMessageException): + """Bad initialization argument.""" + + msg = "Invalid argument: %s" + + def __init__(self, errmsg): + super(InvalidArgument, self).__init__(self.msg % errmsg) + + +Ticket = collections.namedtuple('Ticket', ['skey', 'ekey', 'esek']) + + +class KeyStore(object): + """A storage class for Signing and Encryption Keys. + + This class creates an object that holds Generic Keys like Signing + Keys, Encryption Keys, Encrypted SEK Tickets ... + """ + + def __init__(self): + self._kvps = dict() + + def _get_key_name(self, source, target, ktype): + return (source, target, ktype) + + def _put(self, src, dst, ktype, expiration, data): + name = self._get_key_name(src, dst, ktype) + self._kvps[name] = (expiration, data) + + def _get(self, src, dst, ktype): + name = self._get_key_name(src, dst, ktype) + if name in self._kvps: + expiration, data = self._kvps[name] + if expiration > time.time(): + return data + else: + del self._kvps[name] + + return None + + def clear(self): + """Wipes the store clear of all data.""" + self._kvps.clear() + + def put_ticket(self, source, target, skey, ekey, esek, expiration): + """Puts a sek pair in the cache. + + :param source: Client name + :param target: Target name + :param skey: The Signing Key + :param ekey: The Encription Key + :param esek: The token encrypted with the target key + :param expiration: Expiration time in seconds since Epoch + """ + keys = Ticket(skey, ekey, esek) + self._put(source, target, 'ticket', expiration, keys) + + def get_ticket(self, source, target): + """Returns a Ticket (skey, ekey, esek) namedtuple for the + source/target pair. + """ + return self._get(source, target, 'ticket') + + +_KEY_STORE = KeyStore() + + +class _KDSClient(object): + + USER_AGENT = 'oslo-incubator/rpc' + + def __init__(self, endpoint=None, timeout=None): + """A KDS Client class.""" + + self._endpoint = endpoint + if timeout is not None: + self.timeout = float(timeout) + else: + self.timeout = None + + def _do_get(self, url, request): + req_kwargs = dict() + req_kwargs['headers'] = dict() + req_kwargs['headers']['User-Agent'] = self.USER_AGENT + req_kwargs['headers']['Content-Type'] = 'application/json' + req_kwargs['data'] = jsonutils.dumps({'request': request}) + if self.timeout is not None: + req_kwargs['timeout'] = self.timeout + + try: + resp = requests.get(url, **req_kwargs) + except requests.ConnectionError as e: + err = "Unable to establish connection. %s" % e + raise CommunicationError(url, err) + + return resp + + def _get_reply(self, url, resp): + if resp.text: + try: + body = jsonutils.loads(resp.text) + reply = body['reply'] + except (KeyError, TypeError, ValueError): + msg = "Failed to decode reply: %s" % resp.text + raise CommunicationError(url, msg) + else: + msg = "No reply data was returned." + raise CommunicationError(url, msg) + + return reply + + def _get_ticket(self, request, url=None, redirects=10): + """Send an HTTP request. + + Wraps around 'requests' to handle redirects and common errors. + """ + if url is None: + if not self._endpoint: + raise CommunicationError(url, 'Endpoint not configured') + url = self._endpoint + '/kds/ticket' + + while redirects: + resp = self._do_get(url, request) + if resp.status_code in (301, 302, 305): + # Redirected. Reissue the request to the new location. + url = resp.headers['location'] + redirects -= 1 + continue + elif resp.status_code != 200: + msg = "Request returned failure status: %s (%s)" + err = msg % (resp.status_code, resp.text) + raise CommunicationError(url, err) + + return self._get_reply(url, resp) + + raise CommunicationError(url, "Too many redirections, giving up!") + + def get_ticket(self, source, target, crypto, key): + + # prepare metadata + md = {'requestor': source, + 'target': target, + 'timestamp': time.time(), + 'nonce': struct.unpack('Q', os.urandom(8))[0]} + metadata = base64.b64encode(jsonutils.dumps(md)) + + # sign metadata + signature = crypto.sign(key, metadata) + + # HTTP request + reply = self._get_ticket({'metadata': metadata, + 'signature': signature}) + + # verify reply + signature = crypto.sign(key, (reply['metadata'] + reply['ticket'])) + if signature != reply['signature']: + raise InvalidEncryptedTicket(md['source'], md['destination']) + md = jsonutils.loads(base64.b64decode(reply['metadata'])) + if ((md['source'] != source or + md['destination'] != target or + md['expiration'] < time.time())): + raise InvalidEncryptedTicket(md['source'], md['destination']) + + # return ticket data + tkt = jsonutils.loads(crypto.decrypt(key, reply['ticket'])) + + return tkt, md['expiration'] + + +# we need to keep a global nonce, as this value should never repeat non +# matter how many SecureMessage objects we create +_NONCE = None + + +def _get_nonce(): + """We keep a single counter per instance, as it is so huge we can't + possibly cycle through within 1/100 of a second anyway. + """ + + global _NONCE + # Lazy initialize, for now get a random value, multiply by 2^32 and + # use it as the nonce base. The counter itself will rotate after + # 2^32 increments. + if _NONCE is None: + _NONCE = [struct.unpack('I', os.urandom(4))[0], 0] + + # Increment counter and wrap at 2^32 + _NONCE[1] += 1 + if _NONCE[1] > 0xffffffff: + _NONCE[1] = 0 + + # Return base + counter + return long((_NONCE[0] * 0xffffffff)) + _NONCE[1] + + +class SecureMessage(object): + """A Secure Message object. + + This class creates a signing/encryption facility for RPC messages. + It encapsulates all the necessary crypto primitives to insulate + regular code from the intricacies of message authentication, validation + and optionally encryption. + + :param topic: The topic name of the queue + :param host: The server name, together with the topic it forms a unique + name that is used to source signing keys, and verify + incoming messages. + :param conf: a ConfigOpts object + :param key: (optional) explicitly pass in endpoint private key. + If not provided it will be sourced from the service config + :param key_store: (optional) Storage class for local caching + :param encrypt: (defaults to False) Whether to encrypt messages + :param enctype: (defaults to AES) Cipher to use + :param hashtype: (defaults to SHA256) Hash function to use for signatures + """ + + def __init__(self, topic, host, conf, key=None, key_store=None, + encrypt=None, enctype='AES', hashtype='SHA256'): + + conf.register_group(secure_message_group) + conf.register_opts(secure_message_opts, group='secure_messages') + + self._name = '%s.%s' % (topic, host) + self._key = key + self._conf = conf.secure_messages + self._encrypt = self._conf.encrypt if (encrypt is None) else encrypt + self._crypto = cryptoutils.SymmetricCrypto(enctype, hashtype) + self._hkdf = cryptoutils.HKDF(hashtype) + self._kds = _KDSClient(self._conf.kds_endpoint) + + if self._key is None: + self._key = self._init_key(topic, self._name) + if self._key is None: + err = "Secret Key (or key file) is missing or malformed" + raise SharedKeyNotFound(self._name, err) + + self._key_store = key_store or _KEY_STORE + + def _init_key(self, topic, name): + keys = None + if self._conf.secret_keys_file: + with open(self._conf.secret_keys_file, 'r') as f: + keys = f.readlines() + elif self._conf.secret_key: + keys = self._conf.secret_key + + if keys is None: + return None + + for k in keys: + if k[0] == '#': + continue + if ':' not in k: + break + svc, key = k.split(':', 1) + if svc == topic or svc == name: + return base64.b64decode(key) + + return None + + def _split_key(self, key, size): + sig_key = key[:size] + enc_key = key[size:] + return sig_key, enc_key + + def _decode_esek(self, key, source, target, timestamp, esek): + """This function decrypts the esek buffer passed in and returns a + KeyStore to be used to check and decrypt the received message. + + :param key: The key to use to decrypt the ticket (esek) + :param source: The name of the source service + :param traget: The name of the target service + :param timestamp: The incoming message timestamp + :param esek: a base64 encoded encrypted block containing a JSON string + """ + rkey = None + + try: + s = self._crypto.decrypt(key, esek) + j = jsonutils.loads(s) + + rkey = base64.b64decode(j['key']) + expiration = j['timestamp'] + j['ttl'] + if j['timestamp'] > timestamp or timestamp > expiration: + raise InvalidExpiredTicket(source, target) + + except Exception: + raise InvalidEncryptedTicket(source, target) + + info = '%s,%s,%s' % (source, target, str(j['timestamp'])) + + sek = self._hkdf.expand(rkey, info, len(key) * 2) + + return self._split_key(sek, len(key)) + + def _get_ticket(self, target): + """This function will check if we already have a SEK for the specified + target in the cache, or will go and try to fetch a new SEK from the key + server. + + :param target: The name of the target service + """ + ticket = self._key_store.get_ticket(self._name, target) + + if ticket is not None: + return ticket + + tkt, expiration = self._kds.get_ticket(self._name, target, + self._crypto, self._key) + + self._key_store.put_ticket(self._name, target, + base64.b64decode(tkt['skey']), + base64.b64decode(tkt['ekey']), + tkt['esek'], expiration) + return self._key_store.get_ticket(self._name, target) + + def encode(self, version, target, json_msg): + """This is the main encoding function. + + It takes a target and a message and returns a tuple consisting of a + JSON serialized metadata object, a JSON serialized (and optionally + encrypted) message, and a signature. + + :param version: the current envelope version + :param target: The name of the target service (usually with hostname) + :param json_msg: a serialized json message object + """ + ticket = self._get_ticket(target) + + metadata = jsonutils.dumps({'source': self._name, + 'destination': target, + 'timestamp': time.time(), + 'nonce': _get_nonce(), + 'esek': ticket.esek, + 'encryption': self._encrypt}) + + message = json_msg + if self._encrypt: + message = self._crypto.encrypt(ticket.ekey, message) + + signature = self._crypto.sign(ticket.skey, + version + metadata + message) + + return (metadata, message, signature) + + def decode(self, version, metadata, message, signature): + """This is the main decoding function. + + It takes a version, metadata, message and signature strings and + returns a tuple with a (decrypted) message and metadata or raises + an exception in case of error. + + :param version: the current envelope version + :param metadata: a JSON serialized object with metadata for validation + :param message: a JSON serialized (base64 encoded encrypted) message + :param signature: a base64 encoded signature + """ + md = jsonutils.loads(metadata) + + check_args = ('source', 'destination', 'timestamp', + 'nonce', 'esek', 'encryption') + for arg in check_args: + if arg not in md: + raise InvalidMetadata('Missing metadata "%s"' % arg) + + if md['destination'] != self._name: + # TODO(simo) handle group keys by checking target + raise UnknownDestinationName(md['destination']) + + try: + skey, ekey = self._decode_esek(self._key, + md['source'], md['destination'], + md['timestamp'], md['esek']) + except InvalidExpiredTicket: + raise + except Exception: + raise InvalidMetadata('Failed to decode ESEK for %s/%s' % ( + md['source'], md['destination'])) + + sig = self._crypto.sign(skey, version + metadata + message) + + if sig != signature: + raise InvalidSignature(md['source'], md['destination']) + + if md['encryption'] is True: + msg = self._crypto.decrypt(ekey, message) + else: + msg = message + + return (md, msg) diff --git a/openstack/common/rpc/zmq_receiver.py b/openstack/common/rpc/zmq_receiver.py index e74da22..000a7dd 100755..100644 --- a/openstack/common/rpc/zmq_receiver.py +++ b/openstack/common/rpc/zmq_receiver.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2011 OpenStack Foundation diff --git a/openstack/common/service.py b/openstack/common/service.py index cb71af2..6da9751 100644 --- a/openstack/common/service.py +++ b/openstack/common/service.py @@ -81,6 +81,15 @@ class Launcher(object): """ self.services.wait() + def restart(self): + """Reload config files and restart service. + + :returns: None + + """ + cfg.CONF.reload_config_files() + self.services.restart() + class SignalExit(SystemExit): def __init__(self, signo, exccode=1): @@ -93,24 +102,31 @@ class ServiceLauncher(Launcher): # Allow the process to be killed again and die from natural causes signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGHUP, signal.SIG_DFL) raise SignalExit(signo) - def wait(self): + def handle_signal(self): signal.signal(signal.SIGTERM, self._handle_signal) signal.signal(signal.SIGINT, self._handle_signal) + signal.signal(signal.SIGHUP, self._handle_signal) + + def _wait_for_exit_or_signal(self): + status = None + signo = 0 LOG.debug(_('Full set of CONF:')) CONF.log_opt_values(LOG, std_logging.DEBUG) - status = None try: super(ServiceLauncher, self).wait() except SignalExit as exc: signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[exc.signo] + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[exc.signo] LOG.info(_('Caught %s, exiting'), signame) status = exc.code + signo = exc.signo except SystemExit as exc: status = exc.code finally: @@ -121,7 +137,16 @@ class ServiceLauncher(Launcher): except Exception: # We're shutting down, so it doesn't matter at this point. LOG.exception(_('Exception during rpc cleanup.')) - return status + + return status, signo + + def wait(self): + while True: + self.handle_signal() + status, signo = self._wait_for_exit_or_signal() + if signo != signal.SIGHUP: + return status + self.restart() class ServiceWrapper(object): @@ -139,9 +164,12 @@ class ProcessLauncher(object): self.running = True rfd, self.writepipe = os.pipe() self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r') + self.handle_signal() + def handle_signal(self): signal.signal(signal.SIGTERM, self._handle_signal) signal.signal(signal.SIGINT, self._handle_signal) + signal.signal(signal.SIGHUP, self._handle_signal) def _handle_signal(self, signo, frame): self.sigcaught = signo @@ -150,6 +178,7 @@ class ProcessLauncher(object): # Allow the process to be killed again and die from natural causes signal.signal(signal.SIGTERM, signal.SIG_DFL) signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGHUP, signal.SIG_DFL) def _pipe_watcher(self): # This will block until the write end is closed when the parent @@ -160,16 +189,47 @@ class ProcessLauncher(object): sys.exit(1) - def _child_process(self, service): + def _child_process_handle_signal(self): # Setup child signal handlers differently def _sigterm(*args): signal.signal(signal.SIGTERM, signal.SIG_DFL) raise SignalExit(signal.SIGTERM) + def _sighup(*args): + signal.signal(signal.SIGHUP, signal.SIG_DFL) + raise SignalExit(signal.SIGHUP) + signal.signal(signal.SIGTERM, _sigterm) + signal.signal(signal.SIGHUP, _sighup) # Block SIGINT and let the parent send us a SIGTERM signal.signal(signal.SIGINT, signal.SIG_IGN) + def _child_wait_for_exit_or_signal(self, launcher): + status = None + signo = 0 + + try: + launcher.wait() + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + signo = exc.signo + except SystemExit as exc: + status = exc.code + except BaseException: + LOG.exception(_('Unhandled exception')) + status = 2 + finally: + launcher.stop() + + return status, signo + + def _child_process(self, service): + self._child_process_handle_signal() + # Reopen the eventlet hub to make sure we don't share an epoll # fd with parent and/or siblings, which would be bad eventlet.hubs.use_hub() @@ -184,7 +244,7 @@ class ProcessLauncher(object): launcher = Launcher() launcher.launch_service(service) - launcher.wait() + return launcher def _start_child(self, wrap): if len(wrap.forktimes) > wrap.workers: @@ -205,21 +265,13 @@ class ProcessLauncher(object): # NOTE(johannes): All exceptions are caught to ensure this # doesn't fallback into the loop spawning children. It would # be bad for a child to spawn more children. - status = 0 - try: - self._child_process(wrap.service) - except SignalExit as exc: - signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[exc.signo] - LOG.info(_('Caught %s, exiting'), signame) - status = exc.code - except SystemExit as exc: - status = exc.code - except BaseException: - LOG.exception(_('Unhandled exception')) - status = 2 - finally: - wrap.service.stop() + launcher = self._child_process(wrap.service) + while True: + self._child_process_handle_signal() + status, signo = self._child_wait_for_exit_or_signal(launcher) + if signo != signal.SIGHUP: + break + launcher.restart() os._exit(status) @@ -265,12 +317,7 @@ class ProcessLauncher(object): wrap.children.remove(pid) return wrap - def wait(self): - """Loop waiting on children to die and respawning as necessary.""" - - LOG.debug(_('Full set of CONF:')) - CONF.log_opt_values(LOG, std_logging.DEBUG) - + def _respawn_children(self): while self.running: wrap = self._wait_child() if not wrap: @@ -279,14 +326,30 @@ class ProcessLauncher(object): # (see bug #1095346) eventlet.greenthread.sleep(.01) continue - while self.running and len(wrap.children) < wrap.workers: self._start_child(wrap) - if self.sigcaught: - signame = {signal.SIGTERM: 'SIGTERM', - signal.SIGINT: 'SIGINT'}[self.sigcaught] - LOG.info(_('Caught %s, stopping children'), signame) + def wait(self): + """Loop waiting on children to die and respawning as necessary.""" + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + while True: + self.handle_signal() + self._respawn_children() + if self.sigcaught: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT', + signal.SIGHUP: 'SIGHUP'}[self.sigcaught] + LOG.info(_('Caught %s, stopping children'), signame) + if self.sigcaught != signal.SIGHUP: + break + + for pid in self.children: + os.kill(pid, signal.SIGHUP) + self.running = True + self.sigcaught = None for pid in self.children: try: @@ -311,6 +374,10 @@ class Service(object): # signal that the service is done shutting itself down: self._done = event.Event() + def reset(self): + # NOTE(Fengqian): docs for Event.reset() recommend against using it + self._done = event.Event() + def start(self): pass @@ -353,6 +420,13 @@ class Services(object): def wait(self): self.tg.wait() + def restart(self): + self.stop() + self.done = event.Event() + for restart_service in self.services: + restart_service.reset() + self.tg.add_thread(self.run_service, restart_service, self.done) + @staticmethod def run_service(service, done): """Service start wrapper. diff --git a/openstack/common/test.py b/openstack/common/test.py new file mode 100644 index 0000000..7f400e5 --- /dev/null +++ b/openstack/common/test.py @@ -0,0 +1,52 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010-2011 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Common utilities used in testing""" + +import os + +import fixtures +import testtools + + +class BaseTestCase(testtools.TestCase): + + def setUp(self): + super(BaseTestCase, self).setUp() + self._set_timeout() + self._fake_output() + self.useFixture(fixtures.FakeLogger('openstack.common')) + + def _set_timeout(self): + test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0) + try: + test_timeout = int(test_timeout) + except ValueError: + # If timeout value is invalid do not set a timeout. + test_timeout = 0 + if test_timeout > 0: + self.useFixture(fixtures.Timeout(test_timeout, gentle=True)) + + def _fake_output(self): + if (os.environ.get('OS_STDOUT_CAPTURE') == 'True' or + os.environ.get('OS_STDOUT_CAPTURE') == '1'): + stdout = self.useFixture(fixtures.StringStream('stdout')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout)) + if (os.environ.get('OS_STDERR_CAPTURE') == 'True' or + os.environ.get('OS_STDERR_CAPTURE') == '1'): + stderr = self.useFixture(fixtures.StringStream('stderr')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) diff --git a/openstack/common/timeutils.py b/openstack/common/timeutils.py index bd60489..aa9f708 100644 --- a/openstack/common/timeutils.py +++ b/openstack/common/timeutils.py @@ -49,9 +49,9 @@ def parse_isotime(timestr): try: return iso8601.parse_date(timestr) except iso8601.ParseError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) except TypeError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) def strtime(at=None, fmt=PERFECT_TIME_FORMAT): diff --git a/pypi/setup.py b/pypi/setup.py index feb38c8..ea9a5be 100644 --- a/pypi/setup.py +++ b/pypi/setup.py @@ -33,7 +33,7 @@ setuptools.setup( ], keywords='openstack', author='OpenStack', - author_email='openstack@lists.launchpad.net', + author_email='openstack@lists.openstack.org', url='http://www.openstack.org/', license='Apache Software License', zip_safe=True, diff --git a/requirements.txt b/requirements.txt index 2c512f1..656241b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ WebOb==1.2.3 eventlet>=0.12.0 greenlet>=0.3.2 lxml +requests>=1.1,<1.2.3 routes==1.12.3 iso8601>=0.1.4 anyjson>=0.3.3 @@ -17,6 +18,7 @@ qpid-python six netaddr pycrypto>=2.6 +Babel>=0.9.6 -f http://tarballs.openstack.org/oslo.config/oslo.config-1.2.0a3.tar.gz#egg=oslo.config-1.2.0a3 oslo.config>=1.2.0a3 diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 0000000..bd850af --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Current scriipt is the simple wrapper on common tools/run_tests_common.sh +# scrip. It pass project specific variables to common script. +# +# Optins list (from tools/run_tests_common.sh). +# Use `./run_tests.sh -h` `./run_tests.sh --help` to get help message +# +# -V, --virtual-env Always use virtualenv. Install automatically if not present +# -N, --no-virtual-env Don't use virtualenv. Run tests in local environment +# -s, --no-site-packages Isolate the virtualenv from the global Python environment +# -r, --recreate-db Recreate the test database (deprecated, as this is now the default). +# -n, --no-recreate-db Don't recreate the test database. +# -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added. +# -u, --update Update the virtual environment with any newer package versions +# -p, --pep8 Just run PEP8 and HACKING compliance check +# -P, --no-pep8 Don't run static code checks +# -c, --coverage Generate coverage report +# -d, --debug Run tests with testtools instead of testr. This allows you to use the debugger. +# -h, --help Print this usage message +# --hide-elapsed Don't print the elapsed time for each test along with slow test list +# --virtual-env-path <path> Location of the virtualenv directory. Default: \$(pwd) +# --virtual-env-name <name> Name of the virtualenv directory. Default: .venv +# --tools-path <dir> Location of the tools directory. Default: \$(pwd) +# +# Note: with no options specified, the script will try to run the tests in a +# virtual environment, if no virtualenv is found, the script will ask if +# you would like to create one. If you prefer to run tests NOT in a +# virtual environment, simply pass the -N option. + + +# On Linux, testrepository will inspect /proc/cpuinfo to determine how many +# CPUs are present in the machine, and run one worker per CPU. +# Set workers_count=0 if you want to run one worker per CPU. +# Make our paths available to run_tests_common.sh using `export` statement +# export WORKERS_COUNT=0 + +# there are no possibility to run some oslo tests with concurrency > 1 +# or separately due to dependencies between tests (see bug 1192207) +export WORKERS_COUNT=1 +# option include {PROJECT_NAME}/* directory to coverage report if `-c` or +# `--coverage` uses +export PROJECT_NAME="openstack" +# option exclude "${PROJECT_NAME}/openstack/common/*" from coverage report +# if equals to 1 +export OMIT_OSLO_FROM_COVERAGE=0 +# path to directory with tests +export TESTS_DIR="tests/" +export EGG_INFO_FILE="openstack.common.egg-info/entry_points.txt" + +# run common test script +tools/run_tests_common.sh $* diff --git a/test-requirements.txt b/test-requirements.txt index 3d88f90..f98d81a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,7 +2,7 @@ coverage discover fixtures>=0.3.12 flake8==2.0 -hacking>=0.5.6,<0.7 +hacking>=0.5.6,<0.8 mock mox==0.5.3 mysql-python diff --git a/tests/unit/apiclient/test_auth.py b/tests/unit/apiclient/test_auth.py new file mode 100644 index 0000000..15d9ef9 --- /dev/null +++ b/tests/unit/apiclient/test_auth.py @@ -0,0 +1,181 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import argparse + +import fixtures +import mock +import requests + +from stevedore import extension + +try: + import json +except ImportError: + import simplejson as json + +from openstack.common.apiclient import auth +from openstack.common.apiclient import client +from openstack.common.apiclient import fake_client +from openstack.common import test + + +TEST_REQUEST_BASE = { + 'verify': True, +} + + +def mock_http_request(resp=None): + """Mock an HTTP Request.""" + if not resp: + resp = { + "access": { + "token": { + "expires": "12345", + "id": "FAKE_ID", + "tenant": { + "id": "FAKE_TENANT_ID", + } + }, + "serviceCatalog": [ + { + "type": "compute", + "endpoints": [ + { + "region": "RegionOne", + "adminURL": "http://localhost:8774/v1.1", + "internalURL": "http://localhost:8774/v1.1", + "publicURL": "http://localhost:8774/v1.1/", + }, + ], + }, + ], + }, + } + + auth_response = fake_client.TestResponse({ + "status_code": 200, + "text": json.dumps(resp), + }) + return mock.Mock(return_value=(auth_response)) + + +def requested_headers(cs): + """Return requested passed headers.""" + return { + 'User-Agent': cs.user_agent, + 'Content-Type': 'application/json', + } + + +class BaseFakePlugin(auth.BaseAuthPlugin): + def _do_authenticate(self, http_client): + pass + + def token_and_endpoint(self, endpoint_type, service_type): + pass + + +class GlobalFunctionsTest(test.BaseTestCase): + + def test_load_auth_system_opts(self): + self.useFixture(fixtures.MonkeyPatch( + "os.environ", + {"OS_TENANT_NAME": "fake-project", + "OS_USERNAME": "fake-username"})) + parser = argparse.ArgumentParser() + auth.load_auth_system_opts(parser) + options = parser.parse_args( + ["--os-auth-url=fake-url", "--os_auth_system=fake-system"]) + self.assertTrue(options.os_tenant_name, "fake-project") + self.assertTrue(options.os_username, "fake-username") + self.assertTrue(options.os_auth_url, "fake-url") + self.assertTrue(options.os_auth_system, "fake-system") + + +class MockEntrypoint(object): + def __init__(self, name, plugin): + self.name = name + self.plugin = plugin + + +class AuthPluginTest(test.BaseTestCase): + @mock.patch.object(requests.Session, "request") + @mock.patch.object(extension.ExtensionManager, "map") + def test_auth_system_success(self, mock_mgr_map, mock_request): + """Test that we can authenticate using the auth system.""" + class FakePlugin(BaseFakePlugin): + def authenticate(self, cls): + cls.request( + "POST", "http://auth/tokens", + json={"fake": "me"}, allow_redirects=True) + + mock_mgr_map.side_effect = ( + lambda func: func(MockEntrypoint("fake", FakePlugin))) + + mock_request.side_effect = mock_http_request() + + auth.discover_auth_systems() + plugin = auth.load_plugin("fake") + cs = client.HTTPClient(auth_plugin=plugin) + cs.authenticate() + + headers = requested_headers(cs) + + mock_request.assert_called_with( + "POST", + "http://auth/tokens", + headers=headers, + data='{"fake": "me"}', + allow_redirects=True, + **TEST_REQUEST_BASE) + + @mock.patch.object(extension.ExtensionManager, "map") + def test_discover_auth_system_options(self, mock_mgr_map): + """Test that we can load the auth system options.""" + class FakePlugin(BaseFakePlugin): + @classmethod + def add_opts(cls, parser): + parser.add_argument('--auth_system_opt', + default=False, + action='store_true', + help="Fake option") + + mock_mgr_map.side_effect = ( + lambda func: func(MockEntrypoint("fake", FakePlugin))) + + parser = argparse.ArgumentParser() + auth.discover_auth_systems() + auth.load_auth_system_opts(parser) + opts, _args = parser.parse_known_args(['--auth_system_opt']) + + self.assertTrue(opts.auth_system_opt) + + @mock.patch.object(extension.ExtensionManager, "map") + def test_parse_auth_system_options(self, mock_mgr_map): + """Test that we can parse the auth system options.""" + class FakePlugin(BaseFakePlugin): + opt_names = ["fake_argument"] + + mock_mgr_map.side_effect = ( + lambda func: func(MockEntrypoint("fake", FakePlugin))) + + auth.discover_auth_systems() + plugin = auth.load_plugin("fake") + + plugin.parse_opts([]) + self.assertIn("fake_argument", plugin.opts) diff --git a/tests/unit/apiclient/test_base.py b/tests/unit/apiclient/test_base.py new file mode 100644 index 0000000..86f2941 --- /dev/null +++ b/tests/unit/apiclient/test_base.py @@ -0,0 +1,239 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from openstack.common.apiclient import base +from openstack.common.apiclient import client +from openstack.common.apiclient import exceptions +from openstack.common.apiclient import fake_client +from openstack.common import test + + +class HumanResource(base.Resource): + HUMAN_ID = True + + +class HumanResourceManager(base.ManagerWithFind): + resource_class = HumanResource + + def list(self): + return self._list("/human_resources", "human_resources") + + def get(self, human_resource): + return self._get( + "/human_resources/%s" % base.getid(human_resource), + "human_resource") + + def update(self, human_resource, name): + body = { + "human_resource": { + "name": name, + }, + } + return self._put( + "/human_resources/%s" % base.getid(human_resource), + body, + "human_resource") + + +class CrudResource(base.Resource): + pass + + +class CrudResourceManager(base.CrudManager): + """Manager class for manipulating Identity crud_resources.""" + resource_class = CrudResource + collection_key = 'crud_resources' + key = 'crud_resource' + + def get(self, crud_resource): + return super(CrudResourceManager, self).get( + crud_resource_id=base.getid(crud_resource)) + + +class FakeHTTPClient(fake_client.FakeHTTPClient): + crud_resource_json = {"id": "1", "domain_id": "my-domain"} + + def get_human_resources(self, **kw): + return (200, {}, {'human_resources': [ + {'id': 1, 'name': '256 MB Server'}, + {'id': 2, 'name': '512 MB Server'}, + {'id': 'aa1', 'name': '128 MB Server'} + ]}) + + def get_human_resources_1(self, **kw): + res = self.get_human_resources()[2]['human_resources'][0] + return (200, {}, {'human_resource': res}) + + def put_human_resources_1(self, **kw): + kw = kw["json"]["human_resource"].copy() + kw["id"] = "1" + return (200, {}, {'human_resource': kw}) + + def post_crud_resources(self, **kw): + return (200, {}, {"crud_resource": {"id": "1"}}) + + def get_crud_resources(self, **kw): + crud_resources = [] + if kw.get("domain_id") == self.crud_resource_json["domain_id"]: + crud_resources = [self.crud_resource_json] + else: + crud_resources = [] + return (200, {}, {"crud_resources": crud_resources}) + + def get_crud_resources_1(self, **kw): + return (200, {}, {"crud_resource": self.crud_resource_json}) + + def head_crud_resources_1(self, **kw): + return (204, {}, None) + + def patch_crud_resources_1(self, **kw): + self.crud_resource_json.update(kw) + return (200, {}, {"crud_resource": self.crud_resource_json}) + + def delete_crud_resources_1(self, **kw): + return (202, {}, None) + + +class TestClient(client.BaseClient): + + service_type = "test" + + def __init__(self, http_client, extensions=None): + super(TestClient, self).__init__( + http_client, extensions=extensions) + + self.human_resources = HumanResourceManager(self) + self.crud_resources = CrudResourceManager(self) + + +class ResourceTest(test.BaseTestCase): + + def test_resource_repr(self): + r = base.Resource(None, dict(foo="bar", baz="spam")) + self.assertEqual(repr(r), "<Resource baz=spam, foo=bar>") + + def test_getid(self): + class TmpObject(base.Resource): + id = "4" + self.assertEqual(base.getid(TmpObject(None, {})), "4") + + def test_human_id(self): + r = base.Resource(None, {"name": "1"}) + self.assertEqual(r.human_id, None) + r = HumanResource(None, {"name": "1"}) + self.assertEqual(r.human_id, "1") + + +class BaseManagerTest(test.BaseTestCase): + + def setUp(self): + super(BaseManagerTest, self).setUp() + self.http_client = FakeHTTPClient() + self.tc = TestClient(self.http_client) + + def test_resource_lazy_getattr(self): + f = HumanResource(self.tc.human_resources, {'id': 1}) + self.assertEqual(f.name, '256 MB Server') + self.http_client.assert_called('GET', '/human_resources/1') + + # Missing stuff still fails after a second get + self.assertRaises(AttributeError, getattr, f, 'blahblah') + + def test_eq(self): + # Two resources of the same type with the same id: equal + r1 = base.Resource(None, {'id': 1, 'name': 'hi'}) + r2 = base.Resource(None, {'id': 1, 'name': 'hello'}) + self.assertEqual(r1, r2) + + # Two resources of different types: never equal + r1 = base.Resource(None, {'id': 1}) + r2 = HumanResource(None, {'id': 1}) + self.assertNotEqual(r1, r2) + + # Two resources with no ID: equal if their info is equal + r1 = base.Resource(None, {'name': 'joe', 'age': 12}) + r2 = base.Resource(None, {'name': 'joe', 'age': 12}) + self.assertEqual(r1, r2) + + def test_findall_invalid_attribute(self): + # Make sure findall with an invalid attribute doesn't cause errors. + # The following should not raise an exception. + self.tc.human_resources.findall(vegetable='carrot') + + # However, find() should raise an error + self.assertRaises(exceptions.NotFound, + self.tc.human_resources.find, + vegetable='carrot') + + def test_update(self): + name = "new-name" + human_resource = self.tc.human_resources.update("1", name) + self.assertEqual(human_resource.id, "1") + self.assertEqual(human_resource.name, name) + + +class CrudManagerTest(test.BaseTestCase): + + domain_id = "my-domain" + crud_resource_id = "1" + + def setUp(self): + super(CrudManagerTest, self).setUp() + self.http_client = FakeHTTPClient() + self.tc = TestClient(self.http_client) + + def test_create(self): + crud_resource = self.tc.crud_resources.create() + self.assertEqual(crud_resource.id, self.crud_resource_id) + + def test_list(self, domain=None, user=None): + crud_resources = self.tc.crud_resources.list( + base_url=None, + domain_id=self.domain_id) + self.assertEqual(len(crud_resources), 1) + self.assertEqual(crud_resources[0].id, self.crud_resource_id) + self.assertEqual(crud_resources[0].domain_id, self.domain_id) + crud_resources = self.tc.crud_resources.list( + base_url=None, + domain_id="another-domain", + another_attr=None) + self.assertEqual(len(crud_resources), 0) + + def test_get(self): + crud_resource = self.tc.crud_resources.get(self.crud_resource_id) + self.assertEqual(crud_resource.id, self.crud_resource_id) + fake_client.assert_has_keys( + crud_resource._info, + required=["id", "domain_id"], + optional=["missing-attr"]) + + def test_update(self): + crud_resource = self.tc.crud_resources.update( + crud_resource_id=self.crud_resource_id, + domain_id=self.domain_id) + self.assertEqual(crud_resource.id, self.crud_resource_id) + self.assertEqual(crud_resource.domain_id, self.domain_id) + + def test_delete(self): + resp = self.tc.crud_resources.delete( + crud_resource_id=self.crud_resource_id) + self.assertEqual(resp.status_code, 202) + + def test_head(self): + ret = self.tc.crud_resources.head( + crud_resource_id=self.crud_resource_id) + self.assertTrue(ret) diff --git a/tests/unit/apiclient/test_client.py b/tests/unit/apiclient/test_client.py new file mode 100644 index 0000000..b48ba12 --- /dev/null +++ b/tests/unit/apiclient/test_client.py @@ -0,0 +1,137 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import mock +import requests + +from openstack.common.apiclient import auth +from openstack.common.apiclient import client +from openstack.common.apiclient import exceptions +from openstack.common import test + + +class TestClient(client.BaseClient): + service_type = "test" + + +class FakeAuthPlugin(auth.BaseAuthPlugin): + auth_system = "fake" + attempt = -1 + + def _do_authenticate(self, http_client): + pass + + def token_and_endpoint(self, endpoint_type, service_type): + self.attempt = self.attempt + 1 + return ("token-%s" % self.attempt, "/endpoint-%s" % self.attempt) + + +class ClientTest(test.BaseTestCase): + + def test_client_with_timeout(self): + http_client = client.HTTPClient(None, timeout=2) + self.assertEqual(http_client.timeout, 2) + mock_request = mock.Mock() + mock_request.return_value = requests.Response() + mock_request.return_value.status_code = 200 + with mock.patch("requests.Session.request", mock_request): + http_client.request("GET", "/", json={"1": "2"}) + requests.Session.request.assert_called_with( + "GET", + "/", + timeout=2, + headers=mock.ANY, + verify=mock.ANY, + data=mock.ANY) + + def test_concat_url(self): + self.assertEqual(client.HTTPClient.concat_url("/a", "/b"), "/a/b") + self.assertEqual(client.HTTPClient.concat_url("/a", "b"), "/a/b") + self.assertEqual(client.HTTPClient.concat_url("/a/", "/b"), "/a/b") + + def test_client_request(self): + http_client = client.HTTPClient(FakeAuthPlugin()) + mock_request = mock.Mock() + mock_request.return_value = requests.Response() + mock_request.return_value.status_code = 200 + with mock.patch("requests.Session.request", mock_request): + http_client.client_request( + TestClient(http_client), "GET", "/resource", json={"1": "2"}) + requests.Session.request.assert_called_with( + "GET", + "/endpoint-0/resource", + headers={ + "User-Agent": http_client.user_agent, + "Content-Type": "application/json", + "X-Auth-Token": "token-0" + }, + data='{"1": "2"}', + verify=True) + + def test_client_request_reissue(self): + reject_token = None + + def fake_request(method, url, **kwargs): + if kwargs["headers"]["X-Auth-Token"] == reject_token: + raise exceptions.Unauthorized(method=method, url=url) + return "%s %s" % (method, url) + + http_client = client.HTTPClient(FakeAuthPlugin()) + test_client = TestClient(http_client) + http_client.request = fake_request + + self.assertEqual( + http_client.client_request( + test_client, "GET", "/resource"), + "GET /endpoint-0/resource") + reject_token = "token-0" + self.assertEqual( + http_client.client_request( + test_client, "GET", "/resource"), + "GET /endpoint-1/resource") + + +class FakeClient1(object): + pass + + +class FakeClient21(object): + pass + + +class GetClientClassTestCase(test.BaseTestCase): + version_map = { + "1": "%s.FakeClient1" % __name__, + "2.1": "%s.FakeClient21" % __name__, + } + + def test_get_int(self): + self.assertEqual( + client.BaseClient.get_class("fake", 1, self.version_map), + FakeClient1) + + def test_get_str(self): + self.assertEqual( + client.BaseClient.get_class("fake", "2.1", self.version_map), + FakeClient21) + + def test_unsupported_version(self): + self.assertRaises( + exceptions.UnsupportedVersion, + client.BaseClient.get_class, + "fake", "7", self.version_map) diff --git a/tests/unit/apiclient/test_exceptions.py b/tests/unit/apiclient/test_exceptions.py index 34cae73..39753ec 100644 --- a/tests/unit/apiclient/test_exceptions.py +++ b/tests/unit/apiclient/test_exceptions.py @@ -13,9 +13,8 @@ # License for the specific language governing permissions and limitations # under the License. -from tests import utils - from openstack.common.apiclient import exceptions +from openstack.common import test class FakeResponse(object): @@ -29,7 +28,7 @@ class FakeResponse(object): return self.json_data -class ExceptionsArgsTest(utils.BaseTestCase): +class ExceptionsArgsTest(test.BaseTestCase): def assert_exception(self, ex_cls, method, url, status_code, json_data): ex = exceptions.from_response( @@ -61,7 +60,7 @@ class ExceptionsArgsTest(utils.BaseTestCase): json_data = {"error": {"message": "fake unknown message", "details": "fake unknown details"}} self.assert_exception( - exceptions.HttpClientError, method, url, status_code, json_data) + exceptions.HTTPClientError, method, url, status_code, json_data) status_code = 600 self.assert_exception( exceptions.HttpError, method, url, status_code, json_data) diff --git a/tests/unit/crypto/test_utils.py b/tests/unit/crypto/test_utils.py index 3a39100..a6cb6a2 100644 --- a/tests/unit/crypto/test_utils.py +++ b/tests/unit/crypto/test_utils.py @@ -18,10 +18,10 @@ Unit Tests for crypto utils. """ from openstack.common.crypto import utils as cryptoutils -from tests import utils as test_utils +from openstack.common import test -class CryptoUtilsTestCase(test_utils.BaseTestCase): +class CryptoUtilsTestCase(test.BaseTestCase): # Uses Tests from RFC5869 def _test_HKDF(self, ikm, prk, okm, length, diff --git a/tests/unit/db/sqlalchemy/test_migrate.py b/tests/unit/db/sqlalchemy/test_migrate.py index 6724b5c..3e74a88 100644 --- a/tests/unit/db/sqlalchemy/test_migrate.py +++ b/tests/unit/db/sqlalchemy/test_migrate.py @@ -28,7 +28,7 @@ def uniques(*constraints): Convert a sequence of UniqueConstraint instances into a set of tuples of form (constraint_name, (constraint_columns)) so that - assertEquals() will be able to compare sets of unique constraints + assertEqual() will be able to compare sets of unique constraints """ @@ -70,7 +70,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): sa.UniqueConstraint(table.c.a, table.c.b, name='unique_a_b'), sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'), ) - self.assertEquals(should_be, existing) + self.assertEqual(should_be, existing) def test_add_unique_constraint(self): table = self.reflected_table @@ -82,7 +82,7 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'), sa.UniqueConstraint(table.c.a, table.c.c, name='unique_a_c'), ) - self.assertEquals(should_be, existing) + self.assertEqual(should_be, existing) def test_drop_unique_constraint(self): table = self.reflected_table @@ -92,4 +92,4 @@ class TestSqliteUniqueConstraints(test_base.DbTestCase): should_be = uniques( sa.UniqueConstraint(table.c.b, table.c.c, name='unique_b_c'), ) - self.assertEquals(should_be, existing) + self.assertEqual(should_be, existing) diff --git a/tests/unit/db/sqlalchemy/test_migration_common.py b/tests/unit/db/sqlalchemy/test_migration_common.py new file mode 100644 index 0000000..bb27212 --- /dev/null +++ b/tests/unit/db/sqlalchemy/test_migration_common.py @@ -0,0 +1,154 @@ +# Copyright 2013 Mirantis Inc. +# All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + + +import contextlib +import os +import tempfile + +from migrate import exceptions as migrate_exception +from migrate.versioning import api as versioning_api +import mock +import sqlalchemy + +from openstack.common.db import exception as db_exception +from openstack.common.db.sqlalchemy import migration +from openstack.common.db.sqlalchemy import session as db_session +from tests.unit.db.sqlalchemy import base + + +class TestMigrationCommon(base.DbTestCase): + def setUp(self): + super(TestMigrationCommon, self).setUp() + + migration._REPOSITORY = None + self.path = tempfile.mkdtemp('test_migration') + self.return_value = '/home/openstack/migrations' + self.init_version = 1 + self.test_version = 123 + + self.patcher_repo = mock.patch.object(migration, 'Repository') + self.repository = self.patcher_repo.start() + self.repository.return_value = self.return_value + + self.mock_api_db = mock.patch.object(versioning_api, 'db_version') + self.mock_api_db_version = self.mock_api_db.start() + self.mock_api_db_version.return_value = self.test_version + + def tearDown(self): + os.rmdir(self.path) + self.mock_api_db.stop() + self.patcher_repo.stop() + super(TestMigrationCommon, self).tearDown() + + def test_find_migrate_repo_path_not_found(self): + self.assertRaises( + db_exception.DbMigrationError, + migration._find_migrate_repo, + "/foo/bar/", + ) + self.assertIsNone(migration._REPOSITORY) + + def test_find_migrate_repo_called_once(self): + my_repository = migration._find_migrate_repo(self.path) + + self.repository.assert_called_once_with(self.path) + self.assertEqual(migration._REPOSITORY, self.return_value) + self.assertEqual(my_repository, self.return_value) + + def test_find_migrate_repo_called_few_times(self): + repository1 = migration._find_migrate_repo(self.path) + repository2 = migration._find_migrate_repo(self.path) + + self.repository.assert_called_once_with(self.path) + self.assertEqual(migration._REPOSITORY, self.return_value) + self.assertEqual(repository1, self.return_value) + self.assertEqual(repository2, self.return_value) + + def test_db_version_control(self): + with contextlib.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'version_control'), + ) as (mock_find_repo, mock_version_control): + mock_find_repo.return_value = self.return_value + + version = migration.db_version_control( + self.path, self.test_version) + + self.assertEqual(version, self.test_version) + mock_version_control.assert_called_once_with( + db_session.get_engine(), self.return_value, self.test_version) + + def test_db_version_return(self): + ret_val = migration.db_version(self.path, self.init_version) + self.assertEqual(ret_val, self.test_version) + + def test_db_version_raise_not_controlled_error_first(self): + with mock.patch.object(migration, 'db_version_control') as mock_ver: + + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), + self.test_version] + + ret_val = migration.db_version(self.path, self.init_version) + self.assertEqual(ret_val, self.test_version) + mock_ver.assert_called_once_with(self.path, self.init_version) + + def test_db_version_raise_not_controlled_error_no_tables(self): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = \ + migrate_exception.DatabaseNotControlledError('oups') + my_meta = mock.MagicMock() + my_meta.tables = {'a': 1, 'b': 2} + mock_meta.return_value = my_meta + + self.assertRaises( + db_exception.DbMigrationError, migration.db_version, + self.path, self.init_version) + + def test_db_sync_wrong_version(self): + self.assertRaises( + db_exception.DbMigrationError, migration.db_sync, self.path, 'foo') + + def test_db_sync_upgrade(self): + init_ver = 55 + with contextlib.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'upgrade') + ) as (mock_find_repo, mock_upgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version - 1 + + migration.db_sync(self.path, self.test_version, init_ver) + + mock_upgrade.assert_called_once_with( + db_session.get_engine(), self.return_value, self.test_version) + + def test_db_sync_downgrade(self): + with contextlib.nested( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(versioning_api, 'downgrade') + ) as (mock_find_repo, mock_downgrade): + + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version + 1 + + migration.db_sync(self.path, self.test_version) + + mock_downgrade.assert_called_once_with( + db_session.get_engine(), self.return_value, self.test_version) diff --git a/tests/unit/db/sqlalchemy/test_migrations.py b/tests/unit/db/sqlalchemy/test_migrations.py index 8428b1c..ee2929c 100644 --- a/tests/unit/db/sqlalchemy/test_migrations.py +++ b/tests/unit/db/sqlalchemy/test_migrations.py @@ -355,10 +355,10 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin): versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1) upgraded = [mock.call(None, v, with_data=True) for v in versions] - self.assertEquals(self._migrate_up.call_args_list, upgraded) + self.assertEqual(self._migrate_up.call_args_list, upgraded) downgraded = [mock.call(None, v - 1) for v in reversed(versions)] - self.assertEquals(self._migrate_down.call_args_list, downgraded) + self.assertEqual(self._migrate_down.call_args_list, downgraded) @mock.patch.object(WalkVersionsMixin, '_migrate_up') @mock.patch.object(WalkVersionsMixin, '_migrate_down') @@ -376,7 +376,7 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin): upgraded.extend( [mock.call(self.engine, v) for v in reversed(versions)] ) - self.assertEquals(upgraded, self._migrate_up.call_args_list) + self.assertEqual(upgraded, self._migrate_up.call_args_list) downgraded_1 = [ mock.call(self.engine, v - 1, with_data=True) for v in versions @@ -386,7 +386,7 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin): downgraded_2.append(mock.call(self.engine, v - 1)) downgraded_2.append(mock.call(self.engine, v - 1)) downgraded = downgraded_1 + downgraded_2 - self.assertEquals(self._migrate_down.call_args_list, downgraded) + self.assertEqual(self._migrate_down.call_args_list, downgraded) @mock.patch.object(WalkVersionsMixin, '_migrate_up') @mock.patch.object(WalkVersionsMixin, '_migrate_down') @@ -402,12 +402,12 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin): for v in versions: upgraded.append(mock.call(self.engine, v, with_data=True)) upgraded.append(mock.call(self.engine, v)) - self.assertEquals(upgraded, self._migrate_up.call_args_list) + self.assertEqual(upgraded, self._migrate_up.call_args_list) downgraded = [ mock.call(self.engine, v - 1, with_data=True) for v in versions ] - self.assertEquals(self._migrate_down.call_args_list, downgraded) + self.assertEqual(self._migrate_down.call_args_list, downgraded) @mock.patch.object(WalkVersionsMixin, '_migrate_up') @mock.patch.object(WalkVersionsMixin, '_migrate_down') @@ -422,4 +422,4 @@ class TestWalkVersions(test_utils.BaseTestCase, WalkVersionsMixin): upgraded = [ mock.call(self.engine, v, with_data=True) for v in versions ] - self.assertEquals(upgraded, self._migrate_up.call_args_list) + self.assertEqual(upgraded, self._migrate_up.call_args_list) diff --git a/tests/unit/db/sqlalchemy/test_models.py b/tests/unit/db/sqlalchemy/test_models.py index 04905a6..89bf83f 100644 --- a/tests/unit/db/sqlalchemy/test_models.py +++ b/tests/unit/db/sqlalchemy/test_models.py @@ -16,10 +16,10 @@ # under the License. from openstack.common.db.sqlalchemy import models -from tests import utils as test_utils +from openstack.common import test -class ModelBaseTest(test_utils.BaseTestCase): +class ModelBaseTest(test.BaseTestCase): def test_modelbase_has_dict_methods(self): dict_methods = ('__getitem__', @@ -73,7 +73,7 @@ class ModelBaseTest(test_utils.BaseTestCase): self.assertEqual(min_items, found_items) -class TimestampMixinTest(test_utils.BaseTestCase): +class TimestampMixinTest(test.BaseTestCase): def test_timestampmixin_attr(self): diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py index 48d6cf7..01141e3 100644 --- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py +++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py @@ -53,14 +53,14 @@ sql_connection_debug=60 sql_connection_trace=True """)]) self.conf(['--config-file', paths[0]]) - self.assertEquals(self.conf.database.connection, 'x://y.z') - self.assertEquals(self.conf.database.min_pool_size, 10) - self.assertEquals(self.conf.database.max_pool_size, 20) - self.assertEquals(self.conf.database.max_retries, 30) - self.assertEquals(self.conf.database.retry_interval, 40) - self.assertEquals(self.conf.database.max_overflow, 50) - self.assertEquals(self.conf.database.connection_debug, 60) - self.assertEquals(self.conf.database.connection_trace, True) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) def test_session_parameters(self): paths = self.create_tempfiles([('test', """[database] @@ -75,15 +75,15 @@ connection_trace=True pool_timeout=7 """)]) self.conf(['--config-file', paths[0]]) - self.assertEquals(self.conf.database.connection, 'x://y.z') - self.assertEquals(self.conf.database.min_pool_size, 10) - self.assertEquals(self.conf.database.max_pool_size, 20) - self.assertEquals(self.conf.database.max_retries, 30) - self.assertEquals(self.conf.database.retry_interval, 40) - self.assertEquals(self.conf.database.max_overflow, 50) - self.assertEquals(self.conf.database.connection_debug, 60) - self.assertEquals(self.conf.database.connection_trace, True) - self.assertEquals(self.conf.database.pool_timeout, 7) + self.assertEqual(self.conf.database.connection, 'x://y.z') + self.assertEqual(self.conf.database.min_pool_size, 10) + self.assertEqual(self.conf.database.max_pool_size, 20) + self.assertEqual(self.conf.database.max_retries, 30) + self.assertEqual(self.conf.database.retry_interval, 40) + self.assertEqual(self.conf.database.max_overflow, 50) + self.assertEqual(self.conf.database.connection_debug, 60) + self.assertEqual(self.conf.database.connection_trace, True) + self.assertEqual(self.conf.database.pool_timeout, 7) def test_dbapi_database_deprecated_parameters(self): paths = self.create_tempfiles([('test', @@ -98,15 +98,14 @@ pool_timeout=7 'sqlalchemy_pool_timeout=5\n' )]) self.conf(['--config-file', paths[0]]) - self.assertEquals(self.conf.database.connection, - 'fake_connection') - self.assertEquals(self.conf.database.idle_timeout, 100) - self.assertEquals(self.conf.database.min_pool_size, 99) - self.assertEquals(self.conf.database.max_pool_size, 199) - self.assertEquals(self.conf.database.max_retries, 22) - self.assertEquals(self.conf.database.retry_interval, 17) - self.assertEquals(self.conf.database.max_overflow, 101) - self.assertEquals(self.conf.database.pool_timeout, 5) + self.assertEqual(self.conf.database.connection, 'fake_connection') + self.assertEqual(self.conf.database.idle_timeout, 100) + self.assertEqual(self.conf.database.min_pool_size, 99) + self.assertEqual(self.conf.database.max_pool_size, 199) + self.assertEqual(self.conf.database.max_retries, 22) + self.assertEqual(self.conf.database.retry_interval, 17) + self.assertEqual(self.conf.database.max_overflow, 101) + self.assertEqual(self.conf.database.pool_timeout, 5) class SessionErrorWrapperTestCase(test_base.DbTestCase): diff --git a/tests/unit/db/sqlalchemy/test_utils.py b/tests/unit/db/sqlalchemy/test_utils.py index 15a1e25..0d8d87c 100644 --- a/tests/unit/db/sqlalchemy/test_utils.py +++ b/tests/unit/db/sqlalchemy/test_utils.py @@ -15,20 +15,37 @@ # License for the specific language governing permissions and limitations # under the License. +import warnings + +from migrate.changeset import UniqueConstraint import sqlalchemy from sqlalchemy.dialects import mysql from sqlalchemy import Boolean, Index, Integer, DateTime, String -from sqlalchemy import MetaData, Table, Column +from sqlalchemy import MetaData, Table, Column, ForeignKey from sqlalchemy.engine import reflection +from sqlalchemy.exc import SAWarning from sqlalchemy.sql import select from sqlalchemy.types import UserDefinedType, NullType from openstack.common.db.sqlalchemy import utils -from openstack.common import exception from tests.unit.db.sqlalchemy import test_migrations from tests import utils as testutils +class TestSanitizeDbUrl(testutils.BaseTestCase): + + def test_url_with_cred(self): + db_url = 'myproto://johndoe:secret@localhost/myschema' + expected = 'myproto://****:****@localhost/myschema' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(expected, actual) + + def test_url_with_no_cred(self): + db_url = 'sqlite:///mysqlitefile' + actual = utils.sanitize_db_url(db_url) + self.assertEqual(db_url, actual) + + class CustomType(UserDefinedType): """Dummy column type for testing unsupported types.""" def get_col_spec(self): @@ -317,7 +334,7 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase): Column('deleted', Boolean)) table.create() - self.assertRaises(exception.OpenstackException, + self.assertRaises(utils.ColumnError, utils.change_deleted_column_type_to_id_type, engine, table_name) @@ -358,7 +375,7 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase): Column('deleted', Integer)) table.create() - self.assertRaises(exception.OpenstackException, + self.assertRaises(utils.ColumnError, utils.change_deleted_column_type_to_boolean, engine, table_name) @@ -371,3 +388,131 @@ class TestMigrationUtils(test_migrations.BaseMigrationTestCase): # but sqlalchemy will set it to NullType. self.assertTrue(isinstance(table.c.foo.type, NullType)) self.assertTrue(isinstance(table.c.deleted.type, Boolean)) + + def test_utils_drop_unique_constraint(self): + table_name = "__test_tmp_table__" + uc_name = 'uniq_foo' + values = [ + {'id': 1, 'a': 3, 'foo': 10}, + {'id': 2, 'a': 2, 'foo': 20}, + {'id': 3, 'a': 1, 'foo': 30}, + ] + for key, engine in self.engines.items(): + meta = MetaData() + meta.bind = engine + test_table = Table( + table_name, meta, + Column('id', Integer, primary_key=True, nullable=False), + Column('a', Integer), + Column('foo', Integer), + UniqueConstraint('a', name='uniq_a'), + UniqueConstraint('foo', name=uc_name), + ) + test_table.create() + + engine.execute(test_table.insert(), values) + # NOTE(boris-42): This method is generic UC dropper. + utils.drop_unique_constraint(engine, table_name, uc_name, 'foo') + + s = test_table.select().order_by(test_table.c.id) + rows = engine.execute(s).fetchall() + + for i in xrange(0, len(values)): + v = values[i] + self.assertEqual((v['id'], v['a'], v['foo']), rows[i]) + + # NOTE(boris-42): Update data about Table from DB. + meta = MetaData() + meta.bind = engine + test_table = Table(table_name, meta, autoload=True) + constraints = filter( + lambda c: c.name == uc_name, test_table.constraints) + self.assertEqual(len(constraints), 0) + self.assertEqual(len(test_table.constraints), 1) + + test_table.drop() + + def test_util_drop_unique_constraint_with_not_supported_sqlite_type(self): + table_name = "__test_tmp_table__" + uc_name = 'uniq_foo' + values = [ + {'id': 1, 'a': 3, 'foo': 10}, + {'id': 2, 'a': 2, 'foo': 20}, + {'id': 3, 'a': 1, 'foo': 30} + ] + + engine = self.engines['sqlite'] + meta = MetaData(bind=engine) + + test_table = Table( + table_name, meta, + Column('id', Integer, primary_key=True, nullable=False), + Column('a', Integer), + Column('foo', CustomType, default=0), + UniqueConstraint('a', name='uniq_a'), + UniqueConstraint('foo', name=uc_name), + ) + test_table.create() + + engine.execute(test_table.insert(), values) + warnings.simplefilter("ignore", SAWarning) + # NOTE(boris-42): Missing info about column `foo` that has + # unsupported type CustomType. + self.assertRaises(utils.ColumnError, + utils.drop_unique_constraint, + engine, table_name, uc_name, 'foo') + + # NOTE(boris-42): Wrong type of foo instance. it should be + # instance of sqlalchemy.Column. + self.assertRaises(utils.ColumnError, + utils.drop_unique_constraint, + engine, table_name, uc_name, 'foo', foo=Integer()) + + foo = Column('foo', CustomType, default=0) + utils.drop_unique_constraint( + engine, table_name, uc_name, 'foo', foo=foo) + + s = test_table.select().order_by(test_table.c.id) + rows = engine.execute(s).fetchall() + + for i in xrange(0, len(values)): + v = values[i] + self.assertEqual((v['id'], v['a'], v['foo']), rows[i]) + + # NOTE(boris-42): Update data about Table from DB. + meta = MetaData(bind=engine) + test_table = Table(table_name, meta, autoload=True) + constraints = filter( + lambda c: c.name == uc_name, test_table.constraints) + self.assertEqual(len(constraints), 0) + self.assertEqual(len(test_table.constraints), 1) + test_table.drop() + + def test_drop_unique_constraint_in_sqlite_fk_recreate(self): + engine = self.engines['sqlite'] + meta = MetaData() + meta.bind = engine + parent_table = Table( + 'table0', meta, + Column('id', Integer, primary_key=True), + Column('foo', Integer), + ) + parent_table.create() + table_name = 'table1' + table = Table( + table_name, meta, + Column('id', Integer, primary_key=True), + Column('baz', Integer), + Column('bar', Integer, ForeignKey("table0.id")), + UniqueConstraint('baz', name='constr1') + ) + table.create() + utils.drop_unique_constraint(engine, table_name, 'constr1', 'baz') + + insp = reflection.Inspector.from_engine(engine) + f_keys = insp.get_foreign_keys(table_name) + self.assertEqual(len(f_keys), 1) + f_key = f_keys[0] + self.assertEqual(f_key['referred_table'], 'table0') + self.assertEqual(f_key['referred_columns'], ['id']) + self.assertEqual(f_key['constrained_columns'], ['bar']) diff --git a/tests/unit/db/test_api.py b/tests/unit/db/test_api.py index 2a8db3b..dab2198 100644 --- a/tests/unit/db/test_api.py +++ b/tests/unit/db/test_api.py @@ -41,8 +41,8 @@ class DBAPITestCase(test_utils.BaseTestCase): )]) self.conf(['--config-file', paths[0]]) - self.assertEquals(self.conf.database.backend, 'test_123') - self.assertEquals(self.conf.database.use_tpool, True) + self.assertEqual(self.conf.database.backend, 'test_123') + self.assertEqual(self.conf.database.use_tpool, True) def test_dbapi_parameters(self): paths = self.create_tempfiles([('test', @@ -52,8 +52,8 @@ class DBAPITestCase(test_utils.BaseTestCase): )]) self.conf(['--config-file', paths[0]]) - self.assertEquals(self.conf.database.backend, 'test_123') - self.assertEquals(self.conf.database.use_tpool, True) + self.assertEqual(self.conf.database.backend, 'test_123') + self.assertEqual(self.conf.database.use_tpool, True) def test_dbapi_api_class_method_and_tpool_false(self): backend_mapping = {'test_known': 'tests.unit.db.test_api'} diff --git a/tests/unit/deprecated/test_wsgi.py b/tests/unit/deprecated/test_wsgi.py index 72aeae7..7e71837 100644 --- a/tests/unit/deprecated/test_wsgi.py +++ b/tests/unit/deprecated/test_wsgi.py @@ -25,7 +25,6 @@ import six import webob from openstack.common.deprecated import wsgi -from openstack.common import exception from tests import utils @@ -44,7 +43,7 @@ class RequestTest(utils.BaseTestCase): request = wsgi.Request.blank('/tests/123', method='POST') request.headers["Content-Type"] = "text/html" request.body = "asdf<br />" - self.assertRaises(exception.InvalidContentType, + self.assertRaises(wsgi.InvalidContentType, request.get_content_type) def test_content_type_with_charset(self): @@ -311,7 +310,7 @@ class ResponseSerializerTest(utils.BaseTestCase): self.body_serializers[ctype]) def test_get_serializer_unknown_content_type(self): - self.assertRaises(exception.InvalidContentType, + self.assertRaises(wsgi.InvalidContentType, self.serializer.get_body_serializer, 'application/unknown') @@ -335,7 +334,7 @@ class ResponseSerializerTest(utils.BaseTestCase): self.assertEqual(response.status_int, 404) def test_serialize_response_dict_to_unknown_content_type(self): - self.assertRaises(exception.InvalidContentType, + self.assertRaises(wsgi.InvalidContentType, self.serializer.serialize, {}, 'application/unknown') @@ -371,7 +370,7 @@ class RequestDeserializerTest(utils.BaseTestCase): self.body_deserializers['application/xml']) def test_get_deserializer_unknown_content_type(self): - self.assertRaises(exception.InvalidContentType, + self.assertRaises(wsgi.InvalidContentType, self.deserializer.get_body_deserializer, 'application/unknown') @@ -523,7 +522,7 @@ class WSGIServerTest(utils.BaseTestCase): server.start() response = urllib2.urlopen('http://127.0.0.1:%d/' % server.port) - self.assertEquals(greetings, response.read()) + self.assertEqual(greetings, response.read()) server.stop() @@ -541,7 +540,7 @@ class WSGIServerTest(utils.BaseTestCase): server.start() response = urllib2.urlopen('http://127.0.0.1:%d/v1.0/' % server.port) - self.assertEquals(greetings, response.read()) + self.assertEqual(greetings, response.read()) server.stop() @@ -595,7 +594,7 @@ class WSGIServerWithSSLTest(utils.BaseTestCase): server.start() response = urllib2.urlopen('https://127.0.0.1:%d/v1.0/' % server.port) - self.assertEquals(greetings, response.read()) + self.assertEqual(greetings, response.read()) server.stop() @@ -618,6 +617,6 @@ class WSGIServerWithSSLTest(utils.BaseTestCase): server.start() response = urllib2.urlopen('https://[::1]:%d/v1.0/' % server.port) - self.assertEquals(greetings, response.read()) + self.assertEqual(greetings, response.read()) server.stop() diff --git a/tests/unit/fixture/__init__.py b/tests/unit/fixture/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/unit/fixture/__init__.py diff --git a/tests/unit/fixture/test_config.py b/tests/unit/fixture/test_config.py new file mode 100644 index 0000000..3368272 --- /dev/null +++ b/tests/unit/fixture/test_config.py @@ -0,0 +1,45 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Copyright 2013 Mirantis, Inc. +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from oslo.config import cfg + +from openstack.common.fixture import config +from tests.utils import BaseTestCase + +conf = cfg.CONF + + +class ConfigTestCase(BaseTestCase): + def setUp(self): + super(ConfigTestCase, self).setUp() + self.config = self.useFixture(config.Config(conf)).config + self.config_fixture = config.Config(conf) + conf.register_opt(cfg.StrOpt( + 'testing_option', default='initial_value')) + + def test_overriden_value(self): + self.assertEqual(conf.get('testing_option'), 'initial_value') + self.config(testing_option='changed_value') + self.assertEqual(conf.get('testing_option'), + self.config_fixture.conf.get('testing_option')) + + def test_cleanup(self): + self.config(testing_option='changed_value') + self.assertEqual(self.config_fixture.conf.get('testing_option'), + 'changed_value') + self.config_fixture.conf.reset() + self.assertEqual(conf.get('testing_option'), 'initial_value') diff --git a/tests/unit/middleware/test_context.py b/tests/unit/middleware/test_context.py index 5bf5ae4..daffa60 100644 --- a/tests/unit/middleware/test_context.py +++ b/tests/unit/middleware/test_context.py @@ -20,10 +20,10 @@ import mock import openstack.common.context from openstack.common.middleware import context -from tests import utils +from openstack.common import test -class ContextMiddlewareTest(utils.BaseTestCase): +class ContextMiddlewareTest(test.BaseTestCase): def test_process_request(self): req = mock.Mock() @@ -59,7 +59,7 @@ class ContextMiddlewareTest(utils.BaseTestCase): import_class.assert_called_with(mock.sentinel.arg) -class FilterFactoryTest(utils.BaseTestCase): +class FilterFactoryTest(test.BaseTestCase): def test_filter_factory(self): global_conf = dict(sentinel=mock.sentinel.global_conf) diff --git a/tests/unit/middleware/test_correlation_id.py b/tests/unit/middleware/test_correlation_id.py index dc83cc7..11939e9 100644 --- a/tests/unit/middleware/test_correlation_id.py +++ b/tests/unit/middleware/test_correlation_id.py @@ -36,7 +36,7 @@ class CorrelationIdMiddlewareTest(utils.BaseTestCase): middleware = correlation_id.CorrelationIdMiddleware(app) middleware(req) - self.assertEquals(req.headers.get("X_CORRELATION_ID"), "fake_uuid") + self.assertEqual(req.headers.get("X_CORRELATION_ID"), "fake_uuid") def test_process_request_should_not_regenerate_correlation_id(self): app = mock.Mock() @@ -46,5 +46,4 @@ class CorrelationIdMiddlewareTest(utils.BaseTestCase): middleware = correlation_id.CorrelationIdMiddleware(app) middleware(req) - self.assertEquals(req.headers.get("X_CORRELATION_ID"), - "correlation_id") + self.assertEqual(req.headers.get("X_CORRELATION_ID"), "correlation_id") diff --git a/tests/unit/middleware/test_sizelimit.py b/tests/unit/middleware/test_sizelimit.py index 7579659..3666b54 100644 --- a/tests/unit/middleware/test_sizelimit.py +++ b/tests/unit/middleware/test_sizelimit.py @@ -32,7 +32,7 @@ class TestLimitingReader(utils.BaseTestCase): for chunk in sizelimit.LimitingReader(data, BYTES): bytes_read += len(chunk) - self.assertEquals(bytes_read, BYTES) + self.assertEqual(bytes_read, BYTES) bytes_read = 0 data = six.StringIO("*" * BYTES) @@ -42,7 +42,7 @@ class TestLimitingReader(utils.BaseTestCase): bytes_read += 1 byte = reader.read(1) - self.assertEquals(bytes_read, BYTES) + self.assertEqual(bytes_read, BYTES) def test_limiting_reader_fails(self): BYTES = 1024 diff --git a/tests/unit/rpc/amqp.py b/tests/unit/rpc/amqp.py index 83713c7..76d6946 100644 --- a/tests/unit/rpc/amqp.py +++ b/tests/unit/rpc/amqp.py @@ -22,6 +22,7 @@ Unit Tests for AMQP-based remote procedure calls import logging from eventlet import greenthread +import mock from oslo.config import cfg from openstack.common import jsonutils @@ -177,3 +178,13 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase): conn.close() self.assertTrue(self.exc_raised) + + def test_context_dict_type_check(self): + """Test that context is handled properly depending on the type.""" + fake_context = {'fake': 'context'} + mock_msg = mock.MagicMock() + rpc_amqp.pack_context(mock_msg, fake_context) + + # assert first arg in args was a dict type + args = mock_msg.update.call_args[0] + self.assertIsInstance(args[0], dict) diff --git a/tests/unit/rpc/common.py b/tests/unit/rpc/common.py index 2c343fe..4cc279c 100644 --- a/tests/unit/rpc/common.py +++ b/tests/unit/rpc/common.py @@ -26,7 +26,6 @@ import time import eventlet from oslo.config import cfg -from openstack.common import exception from openstack.common.gettextutils import _ # noqa from openstack.common.rpc import common as rpc_common from openstack.common.rpc import dispatcher as rpc_dispatcher @@ -37,6 +36,13 @@ FLAGS = cfg.CONF LOG = logging.getLogger(__name__) +class ApiError(Exception): + def __init__(self, message='Unknown', code='Unknown'): + self.api_message = message + self.code = code + super(ApiError, self).__init__('%s: %s' % (code, message)) + + class BaseRpcTestCase(test_utils.BaseTestCase): def setUp(self, supports_timeouts=True, topic='test', @@ -424,7 +430,7 @@ class TestReceiver(object): @staticmethod def fail_converted(context, value): """Raises an exception with the value sent in.""" - raise exception.ApiError(message=value, code='500') + raise ApiError(message=value, code='500') @staticmethod def block(context, value): diff --git a/tests/unit/rpc/test_common.py b/tests/unit/rpc/test_common.py index 291f823..f37f4b0 100644 --- a/tests/unit/rpc/test_common.py +++ b/tests/unit/rpc/test_common.py @@ -23,7 +23,6 @@ import sys from oslo.config import cfg import six -from openstack.common import exception from openstack.common import importutils from openstack.common import jsonutils from openstack.common import rpc @@ -35,10 +34,6 @@ FLAGS = cfg.CONF LOG = logging.getLogger(__name__) -def raise_exception(): - raise Exception("test") - - class FakeUserDefinedException(Exception): def __init__(self, *args, **kwargs): super(FakeUserDefinedException, self).__init__(*args) @@ -54,7 +49,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): } try: - raise_exception() + raise Exception("test") except Exception: failure = rpc_common.serialize_remote_exception(sys.exc_info()) @@ -64,17 +59,14 @@ class RpcCommonTestCase(test_utils.BaseTestCase): self.assertEqual(expected['message'], failure['message']) def test_serialize_remote_custom_exception(self): - def raise_custom_exception(): - raise exception.MalformedRequestBody(reason='test') - expected = { - 'class': 'MalformedRequestBody', - 'module': 'openstack.common.exception', - 'message': str(exception.MalformedRequestBody(reason='test')), + 'class': 'FakeUserDefinedException', + 'module': self.__class__.__module__, + 'message': 'test', } try: - raise_custom_exception() + raise FakeUserDefinedException('test') except Exception: failure = rpc_common.serialize_remote_exception(sys.exc_info()) @@ -88,14 +80,14 @@ class RpcCommonTestCase(test_utils.BaseTestCase): # module, when being re-serialized, so that through any amount of cell # hops up, it can pop out with the right type expected = { - 'class': 'OpenstackException', - 'module': 'openstack.common.exception', - 'message': exception.OpenstackException.msg_fmt, + 'class': 'FakeUserDefinedException', + 'module': self.__class__.__module__, + 'message': 'foobar', } def raise_remote_exception(): try: - raise exception.OpenstackException() + raise FakeUserDefinedException('foobar') except Exception as e: ex_type = type(e) message = str(e) @@ -132,26 +124,6 @@ class RpcCommonTestCase(test_utils.BaseTestCase): self.assertTrue('raise NotImplementedError' in six.text_type(after_exc)) - def test_deserialize_remote_custom_exception(self): - failure = { - 'class': 'OpenstackException', - 'module': 'openstack.common.exception', - 'message': exception.OpenstackException.msg_fmt, - 'tb': ['raise OpenstackException'], - } - serialized = jsonutils.dumps(failure) - - after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) - self.assertTrue(isinstance(after_exc, exception.OpenstackException)) - self.assertTrue('An unknown' in six.text_type(after_exc)) - #assure the traceback was added - self.assertTrue('raise OpenstackException' in six.text_type(after_exc)) - self.assertEqual('OpenstackException_Remote', - after_exc.__class__.__name__) - self.assertEqual('openstack.common.exception_Remote', - after_exc.__class__.__module__) - self.assertTrue(isinstance(after_exc, exception.OpenstackException)) - def test_deserialize_remote_exception_bad_module(self): failure = { 'class': 'popen2', @@ -169,6 +141,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): self.config(allowed_rpc_exception_modules=[self.__class__.__module__]) failure = { 'class': 'FakeUserDefinedException', + 'message': 'foobar', 'module': self.__class__.__module__, 'tb': ['raise FakeUserDefinedException'], } @@ -176,6 +149,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): after_exc = rpc_common.deserialize_remote_exception(FLAGS, serialized) self.assertTrue(isinstance(after_exc, FakeUserDefinedException)) + self.assertTrue('foobar' in six.text_type(after_exc)) #assure the traceback was added self.assertTrue('raise FakeUserDefinedException' in six.text_type(after_exc)) @@ -327,8 +301,8 @@ class RpcCommonTestCase(test_utils.BaseTestCase): def test_safe_log_sanitizes_globals(self): def logger_method(msg, data): - self.assertEquals('<SANITIZED>', data['_context_auth_token']) - self.assertEquals('<SANITIZED>', data['auth_token']) + self.assertEqual('<SANITIZED>', data['_context_auth_token']) + self.assertEqual('<SANITIZED>', data['auth_token']) data = {'_context_auth_token': 'banana', 'auth_token': 'cheese'} @@ -336,7 +310,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): def test_safe_log_sanitizes_set_admin_password(self): def logger_method(msg, data): - self.assertEquals('<SANITIZED>', data['args']['new_pass']) + self.assertEqual('<SANITIZED>', data['args']['new_pass']) data = {'_context_auth_token': 'banana', 'auth_token': 'cheese', @@ -346,7 +320,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): def test_safe_log_sanitizes_run_instance(self): def logger_method(msg, data): - self.assertEquals('<SANITIZED>', data['args']['admin_password']) + self.assertEqual('<SANITIZED>', data['args']['admin_password']) data = {'_context_auth_token': 'banana', 'auth_token': 'cheese', @@ -356,8 +330,8 @@ class RpcCommonTestCase(test_utils.BaseTestCase): def test_safe_log_sanitizes_any_password_in_context(self): def logger_method(msg, data): - self.assertEquals('<SANITIZED>', data['_context_password']) - self.assertEquals('<SANITIZED>', data['password']) + self.assertEqual('<SANITIZED>', data['_context_password']) + self.assertEqual('<SANITIZED>', data['password']) data = {'_context_auth_token': 'banana', 'auth_token': 'cheese', @@ -369,7 +343,7 @@ class RpcCommonTestCase(test_utils.BaseTestCase): def test_safe_log_sanitizes_cells_route_message(self): def logger_method(msg, data): vals = data['args']['message']['args']['method_info'] - self.assertEquals('<SANITIZED>', vals['method_kwargs']['password']) + self.assertEqual('<SANITIZED>', vals['method_kwargs']['password']) meth_info = {'method_args': ['aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee'], 'method': 'set_admin_password', diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py index cbe948d..695e701 100644 --- a/tests/unit/rpc/test_kombu.py +++ b/tests/unit/rpc/test_kombu.py @@ -33,10 +33,10 @@ from oslo.config import cfg import six import time -from openstack.common import exception from openstack.common.rpc import amqp as rpc_amqp from openstack.common.rpc import common as rpc_common from tests.unit.rpc import amqp +from tests.unit.rpc import common from tests import utils try: @@ -596,7 +596,8 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): """ value = "This is the exception message" # The use of ApiError is an arbitrary choice here ... - self.assertRaises(exception.ApiError, + self.config(allowed_rpc_exception_modules=[common.__name__]) + self.assertRaises(common.ApiError, self.rpc.call, FLAGS, self.context, @@ -609,10 +610,10 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase): {"method": "fail_converted", "args": {"value": value}}) self.fail("should have thrown Exception") - except exception.ApiError as exc: + except common.ApiError as exc: self.assertTrue(value in six.text_type(exc)) #Traceback should be included in exception message - self.assertTrue('exception.ApiError' in six.text_type(exc)) + self.assertTrue('ApiError' in six.text_type(exc)) def test_create_worker(self): meth = 'declare_topic_consumer' diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py index a4bce1c..910e292 100644 --- a/tests/unit/rpc/test_qpid.py +++ b/tests/unit/rpc/test_qpid.py @@ -438,9 +438,9 @@ class RpcQpidTestCase(utils.BaseTestCase): {"method": "test_method", "args": {}}) if multi: - self.assertEquals(list(res), ["foo", "bar", "baz"]) + self.assertEqual(list(res), ["foo", "bar", "baz"]) else: - self.assertEquals(res, "foo") + self.assertEqual(res, "foo") finally: impl_qpid.cleanup() self.uuid4 = uuid.uuid4() @@ -488,7 +488,7 @@ class RpcQpidTestCase(utils.BaseTestCase): else: res = method(FLAGS, ctx, "impl_qpid_test", {"method": "test_method", "args": {}}, timeout) - self.assertEquals(res, "foo") + self.assertEqual(res, "foo") finally: impl_qpid.cleanup() self.uuid4 = uuid.uuid4() diff --git a/tests/unit/rpc/test_securemessage.py b/tests/unit/rpc/test_securemessage.py new file mode 100644 index 0000000..8c07df1 --- /dev/null +++ b/tests/unit/rpc/test_securemessage.py @@ -0,0 +1,134 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for rpc 'securemessage' functions. +""" + +import logging + +from oslo.config import cfg + +from openstack.common import jsonutils +from openstack.common.rpc import common as rpc_common +from openstack.common.rpc import securemessage as rpc_secmsg +from tests import utils as test_utils + + +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +class RpcCryptoTestCase(test_utils.BaseTestCase): + + def test_KeyStore(self): + store = rpc_secmsg.KeyStore() + + # check empty cache returns noting + keys = store.get_ticket('foo', 'bar') + self.assertIsNone(keys) + + ticket = rpc_secmsg.Ticket('skey', 'ekey', 'esek') + + #add entry in the cache + store.put_ticket('foo', 'bar', 'skey', 'ekey', 'esek', 2000000000) + + #chck it returns the object + keys = store.get_ticket('foo', 'bar') + self.assertEqual(keys, ticket) + + #check inverted source/target returns nothing + keys = store.get_ticket('bar', 'foo') + self.assertIsNone(keys) + + #add expired entry in the cache + store.put_ticket('foo', 'bar', 'skey', 'ekey', 'skey', 1000000000) + + #check expired entries are not returned + keys = store.get_ticket('foo', 'bar') + self.assertIsNone(keys) + + def _test_secure_message(self, data, encrypt): + msg = {'message': 'body'} + + # Use a fresh store for each test + store = rpc_secmsg.KeyStore() + + send = rpc_secmsg.SecureMessage(data['source'][0], data['source'][1], + CONF, data['send_key'], + store, encrypt, enctype=data['cipher'], + hashtype=data['hash']) + recv = rpc_secmsg.SecureMessage(data['target'][0], data['target'][1], + CONF, data['recv_key'], + store, encrypt, enctype=data['cipher'], + hashtype=data['hash']) + + source = '%s.%s' % data['source'] + target = '%s.%s' % data['target'] + # Adds test keys in cache, we do it twice, once for client side use, + # then for server side use as we run both in the same process + store.put_ticket(source, target, + data['skey'], data['ekey'], data['esek'], 2000000000) + + pkt = send.encode(rpc_common._RPC_ENVELOPE_VERSION, + target, jsonutils.dumps(msg)) + + out = recv.decode(rpc_common._RPC_ENVELOPE_VERSION, + pkt[0], pkt[1], pkt[2]) + rmsg = jsonutils.loads(out[1]) + + self.assertEqual(len(msg), + len(set(msg.items()) & set(rmsg.items()))) + + def test_secure_message_sha256_aes(self): + foo_to_bar_sha256_aes = { + 'source': ('foo', 'host.example.com'), + 'target': ('bar', 'host.example.com'), + 'send_key': '\x0b' * 16, + 'recv_key': '\x0b' * 16, + 'hash': 'SHA256', + 'cipher': 'AES', + 'skey': "\xaf\xab\x81\x14'\xdd\x1ck\xd1\xb4[\x84MZ\xf5\r", + 'ekey': '\x98\x06\x1bW\x1e\xc1z\xdd\xe2\xb1h\xa5\xb7;\x14\n', + 'esek': ('IehVCF684xJVN0sHc/zngsCAZWQkKSueK4I+ycRhxDGYsqYaAw+nECnZ' + 'mgA3R+DM8halM5TEwwI/uuPqExu8p+fW4CqSMh8oEtLGGqrx85GromaH' + '/YVqK1GpIfUSIQSZrXhAzITN9MeYfeLhD0w2ENUG6AyAk3D56W6l9zJw' + 'ZsI=') + } + # Test signing only first + self._test_secure_message(foo_to_bar_sha256_aes, False) + # Test encryption too + self._test_secure_message(foo_to_bar_sha256_aes, True) + + def test_secure_message_md5_des(self): + foo_to_baz_md5_des = { + 'source': ('foo', 'host.example.com'), + 'target': ('bar', 'host.example.com'), + 'send_key': '????????', + 'recv_key': '????????', + 'hash': 'MD5', + 'cipher': 'DES', + 'skey': 'N<\xeb\x98\x9f$\xa9\xa8', + 'ekey': '\x8c\xd2\x02\x89\xbb6\xd0\xdd', + 'esek': ('CyVMteHe5LiYWFcRnodPv4t8UJ14QztJCC0p/olib9vq50/wua0LY6sk' + 'WWe0GGcvEdzaoZAuH6eBh00CdAVT2LqlK0nBE3Szj93jmVIJxMM+ydxZ' + '2VCvEZohhKeenMiI') + } + # Test signing only first + self._test_secure_message(foo_to_baz_md5_des, False) + # Test encryption too + self._test_secure_message(foo_to_baz_md5_des, True) + + #TODO(simo): test fetching key from file diff --git a/tests/unit/rpc/test_zmq.py b/tests/unit/rpc/test_zmq.py index 55e3eb1..888dbb0 100644 --- a/tests/unit/rpc/test_zmq.py +++ b/tests/unit/rpc/test_zmq.py @@ -106,9 +106,7 @@ class _RpcZmqBaseTestCase(common.BaseRpcTestCase): for char in badchars: self.topic_nested = char.join(('hello', 'world')) try: - # TODO(ewindisch): Determine which exception is raised. - # pending bug #1121348 - self.assertRaises(Exception, self._test_cast, + self.assertRaises(AssertionError, self._test_cast, common.TestReceiver.echo, 42, {"value": 42}, fanout=False) finally: diff --git a/tests/unit/scheduler/test_weights.py b/tests/unit/scheduler/test_weights.py index 21d3f3e..fcb25c6 100644 --- a/tests/unit/scheduler/test_weights.py +++ b/tests/unit/scheduler/test_weights.py @@ -17,11 +17,11 @@ Tests For Scheduler weights. """ from openstack.common.scheduler import base_weight +from openstack.common import test from tests.unit import fakes -from tests import utils -class TestWeightHandler(utils.BaseTestCase): +class TestWeightHandler(test.BaseTestCase): def test_get_all_classes(self): namespace = "openstack.common.tests.fakes.weights" handler = base_weight.BaseWeightHandler( diff --git a/tests/unit/test_authutils.py b/tests/unit/test_authutils.py index 3596df9..82433ad 100644 --- a/tests/unit/test_authutils.py +++ b/tests/unit/test_authutils.py @@ -16,10 +16,10 @@ # under the License. from openstack.common import authutils -from tests import utils +from openstack.common import test -class AuthUtilsTest(utils.BaseTestCase): +class AuthUtilsTest(test.BaseTestCase): def test_auth_str_equal(self): self.assertTrue(authutils.auth_str_equal('abc123', 'abc123')) diff --git a/tests/unit/test_cfgfilter.py b/tests/unit/test_cfgfilter.py index 58b4e2d..1270758 100644 --- a/tests/unit/test_cfgfilter.py +++ b/tests/unit/test_cfgfilter.py @@ -30,11 +30,11 @@ class ConfigFilterTestCase(utils.BaseTestCase): def test_register_opt_default(self): self.fconf.register_opt(cfg.StrOpt('foo', default='bar')) - self.assertEquals(self.fconf.foo, 'bar') - self.assertEquals(self.fconf['foo'], 'bar') + self.assertEqual(self.fconf.foo, 'bar') + self.assertEqual(self.fconf['foo'], 'bar') self.assertTrue('foo' in self.fconf) - self.assertEquals(list(self.fconf), ['foo']) - self.assertEquals(len(self.fconf), 1) + self.assertEqual(list(self.fconf), ['foo']) + self.assertEqual(len(self.fconf), 1) def test_register_opt_none_default(self): self.fconf.register_opt(cfg.StrOpt('foo')) @@ -42,21 +42,21 @@ class ConfigFilterTestCase(utils.BaseTestCase): self.assertTrue(self.fconf.foo is None) self.assertTrue(self.fconf['foo'] is None) self.assertTrue('foo' in self.fconf) - self.assertEquals(list(self.fconf), ['foo']) - self.assertEquals(len(self.fconf), 1) + self.assertEqual(list(self.fconf), ['foo']) + self.assertEqual(len(self.fconf), 1) def test_register_grouped_opt_default(self): self.fconf.register_opt(cfg.StrOpt('foo', default='bar'), group='blaa') - self.assertEquals(self.fconf.blaa.foo, 'bar') - self.assertEquals(self.fconf['blaa']['foo'], 'bar') + self.assertEqual(self.fconf.blaa.foo, 'bar') + self.assertEqual(self.fconf['blaa']['foo'], 'bar') self.assertTrue('blaa' in self.fconf) self.assertTrue('foo' in self.fconf.blaa) - self.assertEquals(list(self.fconf), ['blaa']) - self.assertEquals(list(self.fconf.blaa), ['foo']) - self.assertEquals(len(self.fconf), 1) - self.assertEquals(len(self.fconf.blaa), 1) + self.assertEqual(list(self.fconf), ['blaa']) + self.assertEqual(list(self.fconf.blaa), ['foo']) + self.assertEqual(len(self.fconf), 1) + self.assertEqual(len(self.fconf.blaa), 1) def test_register_grouped_opt_none_default(self): self.fconf.register_opt(cfg.StrOpt('foo'), group='blaa') @@ -65,10 +65,10 @@ class ConfigFilterTestCase(utils.BaseTestCase): self.assertTrue(self.fconf['blaa']['foo'] is None) self.assertTrue('blaa' in self.fconf) self.assertTrue('foo' in self.fconf.blaa) - self.assertEquals(list(self.fconf), ['blaa']) - self.assertEquals(list(self.fconf.blaa), ['foo']) - self.assertEquals(len(self.fconf), 1) - self.assertEquals(len(self.fconf.blaa), 1) + self.assertEqual(list(self.fconf), ['blaa']) + self.assertEqual(list(self.fconf.blaa), ['foo']) + self.assertEqual(len(self.fconf), 1) + self.assertEqual(len(self.fconf.blaa), 1) def test_register_group(self): group = cfg.OptGroup('blaa') @@ -79,24 +79,24 @@ class ConfigFilterTestCase(utils.BaseTestCase): self.assertTrue(self.fconf['blaa']['foo'] is None) self.assertTrue('blaa' in self.fconf) self.assertTrue('foo' in self.fconf.blaa) - self.assertEquals(list(self.fconf), ['blaa']) - self.assertEquals(list(self.fconf.blaa), ['foo']) - self.assertEquals(len(self.fconf), 1) - self.assertEquals(len(self.fconf.blaa), 1) + self.assertEqual(list(self.fconf), ['blaa']) + self.assertEqual(list(self.fconf.blaa), ['foo']) + self.assertEqual(len(self.fconf), 1) + self.assertEqual(len(self.fconf.blaa), 1) def test_unknown_opt(self): self.assertFalse('foo' in self.fconf) - self.assertEquals(len(self.fconf), 0) + self.assertEqual(len(self.fconf), 0) self.assertRaises(cfg.NoSuchOptError, getattr, self.fconf, 'foo') def test_blocked_opt(self): self.conf.register_opt(cfg.StrOpt('foo')) self.assertTrue('foo' in self.conf) - self.assertEquals(len(self.conf), 1) + self.assertEqual(len(self.conf), 1) self.assertTrue(self.conf.foo is None) self.assertFalse('foo' in self.fconf) - self.assertEquals(len(self.fconf), 0) + self.assertEqual(len(self.fconf), 0) self.assertRaises(cfg.NoSuchOptError, getattr, self.fconf, 'foo') def test_import_opt(self): diff --git a/tests/unit/test_cliutils.py b/tests/unit/test_cliutils.py index 13f954c..fada65f 100644 --- a/tests/unit/test_cliutils.py +++ b/tests/unit/test_cliutils.py @@ -15,10 +15,10 @@ # under the License. from openstack.common import cliutils -from tests import utils +from openstack.common import test -class ValidateArgsTest(utils.BaseTestCase): +class ValidateArgsTest(test.BaseTestCase): def test_lambda_no_args(self): cliutils.validate_args(lambda: None) diff --git a/tests/unit/test_compat.py b/tests/unit/test_compat.py new file mode 100644 index 0000000..f71e130 --- /dev/null +++ b/tests/unit/test_compat.py @@ -0,0 +1,53 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 +# +# Copyright 2013 Canonical Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +from openstack.common.py3kcompat import urlutils +from tests import utils + + +class CompatTestCase(utils.BaseTestCase): + def test_urlencode(self): + fake = 'fake' + result = urlutils.urlencode({'Fake': fake}) + self.assertEqual(result, 'Fake=fake') + + def test_urljoin(self): + root_url = "http://yahoo.com/" + url2 = "faq.html" + result = urlutils.urljoin(root_url, url2) + self.assertEqual(result, "http://yahoo.com/faq.html") + + def test_urlquote(self): + url = "/~fake" + result = urlutils.quote(url) + self.assertEqual(result, '/%7Efake') + + def test_urlparse(self): + url = 'http://www.yahoo.com' + result = urlutils.urlparse(url) + self.assertEqual(result.scheme, 'http') + + def test_urlsplit(self): + url = 'http://www.yahoo.com' + result = urlutils.urlsplit(url) + self.assertEqual(result.scheme, 'http') + + def test_urlunsplit(self): + url = "http://www.yahoo.com" + result = urlutils.urlunsplit(urlutils.urlsplit(url)) + self.assertEqual(result, 'http://www.yahoo.com') diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 2f9a3de..ae79b31 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -16,10 +16,10 @@ # under the License. from openstack.common import context -from tests import utils +from openstack.common import test -class ContextTest(utils.BaseTestCase): +class ContextTest(test.BaseTestCase): def test_context(self): ctx = context.RequestContext() diff --git a/tests/unit/test_exception.py b/tests/unit/test_exception.py deleted file mode 100644 index 407e68d..0000000 --- a/tests/unit/test_exception.py +++ /dev/null @@ -1,99 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2011 OpenStack Foundation. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from openstack.common import exception -from tests import utils - - -def good_function(): - return "Is Bueno!" - - -def bad_function_error(): - raise exception.Error() - - -def bad_function_exception(): - raise Exception() - - -class WrapExceptionTest(utils.BaseTestCase): - - def test_wrap_exception_good_return(self): - wrapped = exception.wrap_exception - self.assertEquals(good_function(), wrapped(good_function)()) - - def test_wrap_exception_throws_error(self): - wrapped = exception.wrap_exception - self.assertRaises(exception.Error, wrapped(bad_function_error)) - - def test_wrap_exception_throws_exception(self): - wrapped = exception.wrap_exception - self.assertRaises(Exception, wrapped(bad_function_exception)) - - -class ApiErrorTest(utils.BaseTestCase): - - def test_without_code(self): - err = exception.ApiError('fake error') - self.assertEqual(str(err), 'Unknown: fake error') - self.assertEqual(err.code, 'Unknown') - self.assertEqual(err.api_message, 'fake error') - - def test_with_code(self): - err = exception.ApiError('fake error', 'blah code') - self.assertEqual(str(err), 'blah code: fake error') - self.assertEqual(err.code, 'blah code') - self.assertEqual(err.api_message, 'fake error') - - -class BadStoreUriTest(utils.BaseTestCase): - - def test(self): - uri = 'http:///etc/passwd' - reason = 'Permission DENIED!' - err = exception.BadStoreUri(uri, reason) - self.assertTrue(uri in str(err)) - self.assertTrue(reason in str(err)) - - -class UnknownSchemeTest(utils.BaseTestCase): - - def test(self): - scheme = 'http' - err = exception.UnknownScheme(scheme) - self.assertTrue(scheme in str(err)) - - -class OpenstackExceptionTest(utils.BaseTestCase): - class TestException(exception.OpenstackException): - msg_fmt = '%(test)s' - - def test_format_error_string(self): - test_message = 'Know Your Meme' - err = self.TestException(test=test_message) - self.assertEqual(err._error_string, test_message) - - def test_error_formating_error_string(self): - self.stubs.Set(exception, '_FATAL_EXCEPTION_FORMAT_ERRORS', False) - err = self.TestException(lol='U mad brah') - self.assertEqual(err._error_string, self.TestException.msg_fmt) - - def test_str(self): - message = 'Y u no fail' - err = self.TestException(test=message) - self.assertEqual(str(err), message) diff --git a/tests/unit/test_fileutils.py b/tests/unit/test_fileutils.py index 4214e83..139d478 100644 --- a/tests/unit/test_fileutils.py +++ b/tests/unit/test_fileutils.py @@ -25,10 +25,10 @@ import mock import mox from openstack.common import fileutils -from tests import utils +from openstack.common import test -class EnsureTree(utils.BaseTestCase): +class EnsureTree(test.BaseTestCase): def test_ensure_tree(self): tmpdir = tempfile.mkdtemp() try: @@ -41,7 +41,7 @@ class EnsureTree(utils.BaseTestCase): shutil.rmtree(tmpdir) -class TestCachedFile(utils.BaseTestCase): +class TestCachedFile(test.BaseTestCase): def setUp(self): super(TestCachedFile, self).setUp() @@ -86,7 +86,7 @@ class TestCachedFile(utils.BaseTestCase): self.assertTrue(fresh) -class DeleteIfExists(utils.BaseTestCase): +class DeleteIfExists(test.BaseTestCase): def test_file_present(self): tmpfile = tempfile.mktemp() @@ -113,7 +113,7 @@ class DeleteIfExists(utils.BaseTestCase): self.assertRaises(OSError, fileutils.delete_if_exists, tmpfile) -class RemovePathOnError(utils.BaseTestCase): +class RemovePathOnError(test.BaseTestCase): def test_error(self): tmpfile = tempfile.mktemp() open(tmpfile, 'w') @@ -134,7 +134,7 @@ class RemovePathOnError(utils.BaseTestCase): os.unlink(tmpfile) -class UtilsTestCase(utils.BaseTestCase): +class UtilsTestCase(test.BaseTestCase): def test_file_open(self): dst_fd, dst_path = tempfile.mkstemp() try: @@ -142,6 +142,6 @@ class UtilsTestCase(utils.BaseTestCase): with open(dst_path, 'w') as f: f.write('hello') with fileutils.file_open(dst_path, 'r') as fp: - self.assertEquals(fp.read(), 'hello') + self.assertEqual(fp.read(), 'hello') finally: os.unlink(dst_path) diff --git a/tests/unit/test_funcutils.py b/tests/unit/test_funcutils.py index 439d825..1cf665b 100644 --- a/tests/unit/test_funcutils.py +++ b/tests/unit/test_funcutils.py @@ -20,11 +20,10 @@ import functools from openstack.common import funcutils +from openstack.common import test -from tests import utils - -class FuncutilsTestCase(utils.BaseTestCase): +class FuncutilsTestCase(test.BaseTestCase): def _test_func(self, instance, red=None, blue=None): pass diff --git a/tests/unit/test_gettext.py b/tests/unit/test_gettext.py index e21297f..ddd8b84 100644 --- a/tests/unit/test_gettext.py +++ b/tests/unit/test_gettext.py @@ -1,8 +1,8 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2012 Red Hat, Inc. -# All Rights Reserved. # Copyright 2013 IBM Corp. +# All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -16,6 +16,7 @@ # License for the specific language governing permissions and limitations # under the License. +from babel import localedata import copy import gettext import logging.handlers @@ -32,9 +33,53 @@ LOG = logging.getLogger(__name__) class GettextTest(utils.BaseTestCase): + def setUp(self): + super(GettextTest, self).setUp() + # remember so we can reset to it later + self._USE_LAZY = gettextutils.USE_LAZY + + def tearDown(self): + # reset to value before test + gettextutils.USE_LAZY = self._USE_LAZY + super(GettextTest, self).tearDown() + + def test_enable_lazy(self): + gettextutils.USE_LAZY = False + + gettextutils.enable_lazy() + # assert now enabled + self.assertTrue(gettextutils.USE_LAZY) + + def test_underscore_non_lazy(self): + # set lazy off + gettextutils.USE_LAZY = False + + self.mox.StubOutWithMock(gettextutils._t, 'ugettext') + gettextutils._t.ugettext('blah').AndReturn('translated blah') + self.mox.ReplayAll() + + result = gettextutils._('blah') + self.assertEqual('translated blah', result) + + def test_underscore_lazy(self): + # set lazy off + gettextutils.USE_LAZY = False + + gettextutils.enable_lazy() + result = gettextutils._('blah') + self.assertIsInstance(result, gettextutils.Message) + def test_gettext_does_not_blow_up(self): LOG.info(gettextutils._('test')) + def test_gettextutils_install(self): + gettextutils.install('blaa') + self.assertTrue(isinstance(_('A String'), unicode)) # noqa + + gettextutils.install('blaa', lazy=True) + self.assertTrue(isinstance(_('A Message'), # noqa + gettextutils.Message)) + def test_gettext_install_looks_up_localedir(self): with mock.patch('os.environ.get') as environ_get: with mock.patch('gettext.install') as gettext_install: @@ -47,13 +92,83 @@ class GettextTest(utils.BaseTestCase): localedir='/foo/bar', unicode=True) + def test_get_localized_message(self): + non_message = 'Non-translatable Message' + en_message = 'A message in the default locale' + es_translation = 'A message in Spanish' + zh_translation = 'A message in Chinese' + message = gettextutils.Message(en_message, 'test_domain') + + # In the Message class the translation ultimately occurs when the + # message is turned into a string, and that is what we mock here + def _mock_translation_and_unicode(self): + if self.locale == 'es': + return es_translation + if self.locale == 'zh': + return zh_translation + return self.data + + self.stubs.Set(gettextutils.Message, + '__unicode__', _mock_translation_and_unicode) + + self.assertEqual(es_translation, + gettextutils.get_localized_message(message, 'es')) + self.assertEqual(zh_translation, + gettextutils.get_localized_message(message, 'zh')) + self.assertEqual(en_message, + gettextutils.get_localized_message(message, 'en')) + self.assertEqual(en_message, + gettextutils.get_localized_message(message, 'XX')) + self.assertEqual(en_message, + gettextutils.get_localized_message(message, None)) + self.assertEqual(non_message, + gettextutils.get_localized_message(non_message, 'A')) + + def test_get_available_languages(self): + # All the available languages for which locale data is available + def _mock_locale_identifiers(): + return ['zh', 'es', 'nl', 'fr'] + + self.stubs.Set(localedata, + 'list' if hasattr(localedata, 'list') + else 'locale_identifiers', + _mock_locale_identifiers) + + # Only the languages available for a specific translation domain + def _mock_gettext_find(domain, localedir=None, languages=[], all=0): + if domain == 'test_domain': + return 'translation-file' if any(x in ['zh', 'es'] + for x in languages) else None + return None + self.stubs.Set(gettext, 'find', _mock_gettext_find) + + domain_languages = gettextutils.get_available_languages('test_domain') + # en_US should always be available no matter the domain + # en_US should also always be the first element since order matters + # finally only the domain languages should be included after en_US + self.assertTrue('en_US', domain_languages) + self.assertEqual(3, len(domain_languages)) + self.assertEqual('en_US', domain_languages[0]) + self.assertTrue('zh' in domain_languages) + self.assertTrue('es' in domain_languages) + + # Clear languages to test an unknown domain + gettextutils._AVAILABLE_LANGUAGES = [] + unknown_domain_languages = gettextutils.get_available_languages('huh') + self.assertEqual(1, len(unknown_domain_languages)) + self.assertTrue('en_US' in unknown_domain_languages) + class MessageTestCase(utils.BaseTestCase): """Unit tests for locale Message class.""" def setUp(self): super(MessageTestCase, self).setUp() - self._lazy_gettext = gettextutils.get_lazy_gettext('oslo') + + def _message_with_domain(msg): + return gettextutils.Message(msg, 'oslo') + + self._lazy_gettext = _message_with_domain def tearDown(self): # need to clean up stubs early since they interfere @@ -127,6 +242,19 @@ class MessageTestCase(utils.BaseTestCase): self.assertEqual(result, msgid % params) + def test_regex_find_named_parameters_no_space(self): + msgid = ("Request: %(method)s http://%(server)s:" + "%(port)s%(url)s with headers %(headers)s") + params = {'method': 'POST', + 'server': 'test1', + 'port': 1234, + 'url': 'test2', + 'headers': {'h1': 'val1'}} + + result = self._lazy_gettext(msgid) % params + + self.assertEqual(result, msgid % params) + def test_regex_dict_is_parameter(self): msgid = ("Test that we can inject a dictionary %s") params = {'description': 'test1', @@ -395,7 +523,11 @@ class LocaleHandlerTestCase(utils.BaseTestCase): def setUp(self): super(LocaleHandlerTestCase, self).setUp() - self._lazy_gettext = gettextutils.get_lazy_gettext('oslo') + + def _message_with_domain(msg): + return gettextutils.Message(msg, 'oslo') + + self._lazy_gettext = _message_with_domain self.buffer_handler = logging.handlers.BufferingHandler(40) self.locale_handler = gettextutils.LocaleHandler( 'zh_CN', self.buffer_handler) diff --git a/tests/unit/test_importutils.py b/tests/unit/test_importutils.py index 372bb6d..d716929 100644 --- a/tests/unit/test_importutils.py +++ b/tests/unit/test_importutils.py @@ -19,10 +19,10 @@ import datetime import sys from openstack.common import importutils -from tests import utils +from openstack.common import test -class ImportUtilsTest(utils.BaseTestCase): +class ImportUtilsTest(test.BaseTestCase): # NOTE(jkoelker) There has GOT to be a way to test this. But mocking # __import__ is the devil. Right now we just make diff --git a/tests/unit/test_jsonutils.py b/tests/unit/test_jsonutils.py index 5dc23f7..0d22e7e 100644 --- a/tests/unit/test_jsonutils.py +++ b/tests/unit/test_jsonutils.py @@ -22,10 +22,10 @@ import netaddr import six from openstack.common import jsonutils -from tests import utils +from openstack.common import test -class JSONUtilsTestCase(utils.BaseTestCase): +class JSONUtilsTestCase(test.BaseTestCase): def test_dumps(self): self.assertEqual(jsonutils.dumps({'a': 'b'}), '{"a": "b"}') @@ -38,37 +38,37 @@ class JSONUtilsTestCase(utils.BaseTestCase): self.assertEqual(jsonutils.load(x), {'a': 'b'}) -class ToPrimitiveTestCase(utils.BaseTestCase): +class ToPrimitiveTestCase(test.BaseTestCase): def test_list(self): - self.assertEquals(jsonutils.to_primitive([1, 2, 3]), [1, 2, 3]) + self.assertEqual(jsonutils.to_primitive([1, 2, 3]), [1, 2, 3]) def test_empty_list(self): - self.assertEquals(jsonutils.to_primitive([]), []) + self.assertEqual(jsonutils.to_primitive([]), []) def test_tuple(self): - self.assertEquals(jsonutils.to_primitive((1, 2, 3)), [1, 2, 3]) + self.assertEqual(jsonutils.to_primitive((1, 2, 3)), [1, 2, 3]) def test_dict(self): - self.assertEquals(jsonutils.to_primitive(dict(a=1, b=2, c=3)), - dict(a=1, b=2, c=3)) + self.assertEqual(jsonutils.to_primitive(dict(a=1, b=2, c=3)), + dict(a=1, b=2, c=3)) def test_empty_dict(self): - self.assertEquals(jsonutils.to_primitive({}), {}) + self.assertEqual(jsonutils.to_primitive({}), {}) def test_datetime(self): x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7) - self.assertEquals(jsonutils.to_primitive(x), - '1920-02-03T04:05:06.000007') + self.assertEqual(jsonutils.to_primitive(x), + '1920-02-03T04:05:06.000007') def test_datetime_preserve(self): x = datetime.datetime(1920, 2, 3, 4, 5, 6, 7) - self.assertEquals(jsonutils.to_primitive(x, convert_datetime=False), x) + self.assertEqual(jsonutils.to_primitive(x, convert_datetime=False), x) def test_DateTime(self): x = xmlrpclib.DateTime() x.decode("19710203T04:05:06") - self.assertEquals(jsonutils.to_primitive(x), - '1971-02-03T04:05:06.000000') + self.assertEqual(jsonutils.to_primitive(x), + '1971-02-03T04:05:06.000000') def test_iter(self): class IterClass(object): @@ -86,7 +86,7 @@ class ToPrimitiveTestCase(utils.BaseTestCase): return self.data[self.index - 1] x = IterClass() - self.assertEquals(jsonutils.to_primitive(x), [1, 2, 3, 4, 5]) + self.assertEqual(jsonutils.to_primitive(x), [1, 2, 3, 4, 5]) def test_iteritems(self): class IterItemsClass(object): @@ -99,7 +99,7 @@ class ToPrimitiveTestCase(utils.BaseTestCase): x = IterItemsClass() p = jsonutils.to_primitive(x) - self.assertEquals(p, {'a': 1, 'b': 2, 'c': 3}) + self.assertEqual(p, {'a': 1, 'b': 2, 'c': 3}) def test_iteritems_with_cycle(self): class IterItemsClass(object): @@ -127,24 +127,24 @@ class ToPrimitiveTestCase(utils.BaseTestCase): self.b = 1 x = MysteryClass() - self.assertEquals(jsonutils.to_primitive(x, convert_instances=True), - dict(b=1)) + self.assertEqual(jsonutils.to_primitive(x, convert_instances=True), + dict(b=1)) - self.assertEquals(jsonutils.to_primitive(x), x) + self.assertEqual(jsonutils.to_primitive(x), x) def test_typeerror(self): x = bytearray # Class, not instance - self.assertEquals(jsonutils.to_primitive(x), u"<type 'bytearray'>") + self.assertEqual(jsonutils.to_primitive(x), u"<type 'bytearray'>") def test_nasties(self): def foo(): pass x = [datetime, foo, dir] ret = jsonutils.to_primitive(x) - self.assertEquals(len(ret), 3) + self.assertEqual(len(ret), 3) self.assertTrue(ret[0].startswith(u"<module 'datetime' from ")) self.assertTrue(ret[1].startswith('<function foo at 0x')) - self.assertEquals(ret[2], '<built-in function dir>') + self.assertEqual(ret[2], '<built-in function dir>') def test_depth(self): class LevelsGenerator(object): @@ -164,15 +164,15 @@ class ToPrimitiveTestCase(utils.BaseTestCase): json_l4 = {0: {0: {0: {0: '?'}}}} ret = jsonutils.to_primitive(l4_obj, max_depth=2) - self.assertEquals(ret, json_l2) + self.assertEqual(ret, json_l2) ret = jsonutils.to_primitive(l4_obj, max_depth=3) - self.assertEquals(ret, json_l3) + self.assertEqual(ret, json_l3) ret = jsonutils.to_primitive(l4_obj, max_depth=4) - self.assertEquals(ret, json_l4) + self.assertEqual(ret, json_l4) def test_ipaddr(self): thing = {'ip_addr': netaddr.IPAddress('1.2.3.4')} ret = jsonutils.to_primitive(thing) - self.assertEquals({'ip_addr': '1.2.3.4'}, ret) + self.assertEqual({'ip_addr': '1.2.3.4'}, ret) diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index 37e5798..a8c9ab6 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -15,10 +15,10 @@ # License for the specific language governing permissions and limitations # under the License. -import eventlet +import threading from openstack.common import local -from tests import utils +from openstack.common import test class Dict(dict): @@ -26,11 +26,20 @@ class Dict(dict): pass -class LocalStoreTestCase(utils.BaseTestCase): +class LocalStoreTestCase(test.BaseTestCase): v1 = Dict(a='1') v2 = Dict(a='2') v3 = Dict(a='3') + def setUp(self): + super(LocalStoreTestCase, self).setUp() + # NOTE(mrodden): we need to make sure that local store + # gets imported in the current python context we are + # testing in (eventlet vs normal python threading) so + # we test the correct type of local store for the current + # threading model + reload(local) + def test_thread_unique_storage(self): """Make sure local store holds thread specific values.""" expected_set = [] @@ -44,8 +53,13 @@ class LocalStoreTestCase(utils.BaseTestCase): local.store.a = self.v3 expected_set.append(getattr(local.store, 'a')) - eventlet.spawn(do_something).wait() - eventlet.spawn(do_something2).wait() + t1 = threading.Thread(target=do_something) + t2 = threading.Thread(target=do_something2) + t1.start() + t2.start() + t1.join() + t2.join() + expected_set.append(getattr(local.store, 'a')) self.assertTrue(self.v1 in expected_set) diff --git a/tests/unit/test_lockutils.py b/tests/unit/test_lockutils.py index b5783b8..670d289 100644 --- a/tests/unit/test_lockutils.py +++ b/tests/unit/test_lockutils.py @@ -73,10 +73,10 @@ class LockTestCase(utils.BaseTestCase): """Bar""" pass - self.assertEquals(foo.__doc__, 'Bar', "Wrapped function's docstring " - "got lost") - self.assertEquals(foo.__name__, 'foo', "Wrapped function's name " - "got mangled") + self.assertEqual(foo.__doc__, 'Bar', "Wrapped function's docstring " + "got lost") + self.assertEqual(foo.__name__, 'foo', "Wrapped function's name " + "got mangled") def test_lock_internally(self): """We can lock across multiple green threads.""" @@ -97,13 +97,13 @@ class LockTestCase(utils.BaseTestCase): for thread in threads: thread.wait() - self.assertEquals(len(seen_threads), 100) + self.assertEqual(len(seen_threads), 100) # Looking at the seen threads, split it into chunks of 10, and verify # that the last 9 match the first in each chunk. for i in range(10): for j in range(9): - self.assertEquals(seen_threads[i * 10], - seen_threads[i * 10 + 1 + j]) + self.assertEqual(seen_threads[i * 10], + seen_threads[i * 10 + 1 + j]) self.assertEqual(saved_sem_num, len(lockutils._semaphores), "Semaphore leak detected") diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py index 8f40f75..39a0cce 100644 --- a/tests/unit/test_log.py +++ b/tests/unit/test_log.py @@ -108,13 +108,13 @@ class LazyLoggerTestCase(CommonLoggerTestsMixIn, test_utils.BaseTestCase): class LogHandlerTestCase(test_utils.BaseTestCase): def test_log_path_logdir(self): self.config(log_dir='/some/path', log_file=None) - self.assertEquals(log._get_log_file_path(binary='foo-bar'), - '/some/path/foo-bar.log') + self.assertEqual(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') def test_log_path_logfile(self): self.config(log_file='/some/path/foo-bar.log') - self.assertEquals(log._get_log_file_path(binary='foo-bar'), - '/some/path/foo-bar.log') + self.assertEqual(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') def test_log_path_none(self): self.config(log_dir=None, log_file=None) @@ -123,8 +123,8 @@ class LogHandlerTestCase(test_utils.BaseTestCase): def test_log_path_logfile_overrides_logdir(self): self.config(log_dir='/some/other/path', log_file='/some/path/foo-bar.log') - self.assertEquals(log._get_log_file_path(binary='foo-bar'), - '/some/path/foo-bar.log') + self.assertEqual(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') class PublishErrorsHandlerTestCase(test_utils.BaseTestCase): @@ -354,7 +354,7 @@ class SetDefaultsTestCase(test_utils.BaseTestCase): def test_default_to_none(self): log.set_defaults(logging_context_format_string=None) self.conf([]) - self.assertEquals(self.conf.logging_context_format_string, None) + self.assertEqual(self.conf.logging_context_format_string, None) def test_change_default(self): my_default = '%(asctime)s %(levelname)s %(name)s [%(request_id)s '\ @@ -362,7 +362,7 @@ class SetDefaultsTestCase(test_utils.BaseTestCase): '%(message)s' log.set_defaults(logging_context_format_string=my_default) self.conf([]) - self.assertEquals(self.conf.logging_context_format_string, my_default) + self.assertEqual(self.conf.logging_context_format_string, my_default) class LogConfigOptsTestCase(test_utils.BaseTestCase): @@ -379,8 +379,8 @@ class LogConfigOptsTestCase(test_utils.BaseTestCase): def test_debug_verbose(self): CONF(['--debug', '--verbose']) - self.assertEquals(CONF.debug, True) - self.assertEquals(CONF.verbose, True) + self.assertEqual(CONF.debug, True) + self.assertEqual(CONF.verbose, True) def test_logging_opts(self): CONF([]) @@ -390,29 +390,29 @@ class LogConfigOptsTestCase(test_utils.BaseTestCase): self.assertTrue(CONF.log_dir is None) self.assertTrue(CONF.log_format is None) - self.assertEquals(CONF.log_date_format, log._DEFAULT_LOG_DATE_FORMAT) + self.assertEqual(CONF.log_date_format, log._DEFAULT_LOG_DATE_FORMAT) - self.assertEquals(CONF.use_syslog, False) + self.assertEqual(CONF.use_syslog, False) def test_log_file(self): log_file = '/some/path/foo-bar.log' CONF(['--log-file', log_file]) - self.assertEquals(CONF.log_file, log_file) + self.assertEqual(CONF.log_file, log_file) def test_logfile_deprecated(self): logfile = '/some/other/path/foo-bar.log' CONF(['--logfile', logfile]) - self.assertEquals(CONF.log_file, logfile) + self.assertEqual(CONF.log_file, logfile) def test_log_dir(self): log_dir = '/some/path/' CONF(['--log-dir', log_dir]) - self.assertEquals(CONF.log_dir, log_dir) + self.assertEqual(CONF.log_dir, log_dir) def test_logdir_deprecated(self): logdir = '/some/other/path/' CONF(['--logdir', logdir]) - self.assertEquals(CONF.log_dir, logdir) + self.assertEqual(CONF.log_dir, logdir) def test_log_format_overrides_formatter(self): CONF(['--log-format', '[Any format]']) diff --git a/tests/unit/test_loopingcall.py b/tests/unit/test_loopingcall.py index 89cf336..9746b88 100644 --- a/tests/unit/test_loopingcall.py +++ b/tests/unit/test_loopingcall.py @@ -20,11 +20,11 @@ from eventlet import greenthread import mox from openstack.common import loopingcall +from openstack.common import test from openstack.common import timeutils -from tests import utils -class LoopingCallTestCase(utils.BaseTestCase): +class LoopingCallTestCase(test.BaseTestCase): def setUp(self): super(LoopingCallTestCase, self).setUp() diff --git a/tests/unit/test_memorycache.py b/tests/unit/test_memorycache.py index 48a36b5..0d176da 100644 --- a/tests/unit/test_memorycache.py +++ b/tests/unit/test_memorycache.py @@ -18,11 +18,11 @@ import datetime from openstack.common import memorycache +from openstack.common import test from openstack.common import timeutils -from tests import utils -class MemorycacheTest(utils.BaseTestCase): +class MemorycacheTest(test.BaseTestCase): def setUp(self): self.client = memorycache.get_client() super(MemorycacheTest, self).setUp() diff --git a/tests/unit/test_network_utils.py b/tests/unit/test_network_utils.py index 4ac0222..a4b9042 100644 --- a/tests/unit/test_network_utils.py +++ b/tests/unit/test_network_utils.py @@ -16,10 +16,10 @@ # under the License. from openstack.common import network_utils -from tests import utils +from openstack.common import test -class NetworkUtilsTest(utils.BaseTestCase): +class NetworkUtilsTest(test.BaseTestCase): def test_parse_host_port(self): self.assertEqual(('server01', 80), diff --git a/tests/unit/test_pastedeploy.py b/tests/unit/test_pastedeploy.py index 3dd02d7..f8a2c57 100644 --- a/tests/unit/test_pastedeploy.py +++ b/tests/unit/test_pastedeploy.py @@ -20,7 +20,7 @@ import tempfile import fixtures from openstack.common import pastedeploy -from tests import utils +from openstack.common import test class App(object): @@ -43,7 +43,7 @@ class Filter(object): self.data = data -class PasteTestCase(utils.BaseTestCase): +class PasteTestCase(test.BaseTestCase): def setUp(self): super(PasteTestCase, self).setUp() @@ -67,7 +67,7 @@ openstack.app_factory = tests.unit.test_pastedeploy:App """) app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data) - self.assertEquals(app.data, data) + self.assertEqual(app.data, data) def test_app_factory_with_local_conf(self): data = 'test_app_factory_with_local_conf' @@ -80,8 +80,8 @@ foo = bar """) app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data) - self.assertEquals(app.data, data) - self.assertEquals(app.foo, 'bar') + self.assertEqual(app.data, data) + self.assertEqual(app.foo, 'bar') def test_filter_factory(self): data = 'test_filter_factory' @@ -100,5 +100,5 @@ openstack.app_factory = tests.unit.test_pastedeploy:App """) app = pastedeploy.paste_deploy_app(paste_conf, 'myfoo', data) - self.assertEquals(app.data, data) - self.assertEquals(app.app.data, data) + self.assertEqual(app.data, data) + self.assertEqual(app.app.data, data) diff --git a/tests/unit/test_periodic.py b/tests/unit/test_periodic.py index d663f8b..500e8f4 100644 --- a/tests/unit/test_periodic.py +++ b/tests/unit/test_periodic.py @@ -28,6 +28,10 @@ from tests import utils from testtools import matchers +class AnException(Exception): + pass + + class AService(periodic_task.PeriodicTasks): def __init__(self): @@ -40,7 +44,7 @@ class AService(periodic_task.PeriodicTasks): @periodic_task.periodic_task def crashit(self, context): self.called['urg'] = True - raise Exception('urg') + raise AnException('urg') @periodic_task.periodic_task(spacing=10) def doit_with_kwargs_odd(self, context): @@ -66,7 +70,7 @@ class PeriodicTasksTestCase(utils.BaseTestCase): def test_raises(self): serv = AService() - self.assertRaises(Exception, + self.assertRaises(AnException, serv.run_periodic_tasks, None, raise_on_error=True) diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py index b7d38a3..143f56c 100644 --- a/tests/unit/test_policy.py +++ b/tests/unit/test_policy.py @@ -170,6 +170,49 @@ class EnforcerTest(PolicyBaseTestCase): creds = {'roles': ''} self.assertEqual(self.enforcer.enforce(action, {}, creds), True) + def test_enforcer_with_default_rule(self): + rules_json = """{ + "deny_stack_user": "not role:stack_user", + "cloudwatch:PutMetricData": "" + }""" + rules = policy.Rules.load_json(rules_json) + default_rule = policy.TrueCheck() + enforcer = policy.Enforcer(default_rule=default_rule) + enforcer.set_rules(rules) + action = "cloudwatch:PutMetricData" + creds = {'roles': ''} + self.assertEqual(enforcer.enforce(action, {}, creds), True) + + def test_enforcer_force_reload_true(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.load_rules(force_reload=True) + self.assertNotIn({'test': 'test'}, self.enforcer.rules) + self.assertIn('default', self.enforcer.rules) + self.assertIn('admin', self.enforcer.rules) + + def test_enforcer_force_reload_false(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.load_rules(force_reload=False) + self.assertIn('test', self.enforcer.rules) + self.assertNotIn('default', self.enforcer.rules) + self.assertNotIn('admin', self.enforcer.rules) + + def test_enforcer_overwrite_rules(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.set_rules({'test': 'test1'}, overwrite=True) + self.assertEqual(self.enforcer.rules, {'test': 'test1'}) + + def test_enforcer_update_rules(self): + self.enforcer.set_rules({'test': 'test'}) + self.enforcer.set_rules({'test1': 'test1'}, overwrite=False) + self.assertEqual(self.enforcer.rules, {'test': 'test', + 'test1': 'test1'}) + + def test_get_policy_path_raises_exc(self): + enforcer = policy.Enforcer(policy_file='raise_error.json') + self.assertRaises(cfg.ConfigFilesNotFoundError, + enforcer._get_policy_path) + class FakeCheck(policy.BaseCheck): def __init__(self, result=None): @@ -187,24 +230,15 @@ class FakeCheck(policy.BaseCheck): class CheckFunctionTestCase(PolicyBaseTestCase): def test_check_explicit(self): - self.enforcer.load_rules() - self.enforcer.rules = None rule = FakeCheck() result = self.enforcer.enforce(rule, "target", "creds") - self.assertEqual(result, ("target", "creds", self.enforcer)) - self.assertEqual(self.enforcer.rules, None) def test_check_no_rules(self): - self.enforcer.load_rules() - self.enforcer.rules = None result = self.enforcer.enforce('rule', "target", "creds") - self.assertEqual(result, False) - self.assertEqual(self.enforcer.rules, None) def test_check_missing_rule(self): - self.enforcer.rules = {} result = self.enforcer.enforce('rule', 'target', 'creds') self.assertEqual(result, False) @@ -298,6 +332,48 @@ class NotCheckTestCase(utils.BaseTestCase): rule.assert_called_once_with('target', 'cred', None) +class AndCheckTestCase(utils.BaseTestCase): + def test_init(self): + check = policy.AndCheck(['rule1', 'rule2']) + + self.assertEqual(check.rules, ['rule1', 'rule2']) + + def test_add_check(self): + check = policy.AndCheck(['rule1', 'rule2']) + check.add_check('rule3') + + self.assertEqual(check.rules, ['rule1', 'rule2', 'rule3']) + + def test_str(self): + check = policy.AndCheck(['rule1', 'rule2']) + + self.assertEqual(str(check), '(rule1 and rule2)') + + def test_call_all_false(self): + rules = [mock.Mock(return_value=False), mock.Mock(return_value=False)] + check = policy.AndCheck(rules) + + self.assertEqual(check('target', 'cred', None), False) + rules[0].assert_called_once_with('target', 'cred', None) + self.assertFalse(rules[1].called) + + def test_call_first_true(self): + rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)] + check = policy.AndCheck(rules) + + self.assertFalse(check('target', 'cred', None)) + rules[0].assert_called_once_with('target', 'cred', None) + rules[1].assert_called_once_with('target', 'cred', None) + + def test_call_second_true(self): + rules = [mock.Mock(return_value=False), mock.Mock(return_value=True)] + check = policy.AndCheck(rules) + + self.assertFalse(check('target', 'cred', None)) + rules[0].assert_called_once_with('target', 'cred', None) + self.assertFalse(rules[1].called) + + class OrCheckTestCase(utils.BaseTestCase): def test_init(self): check = policy.OrCheck(['rule1', 'rule2']) @@ -320,15 +396,15 @@ class OrCheckTestCase(utils.BaseTestCase): check = policy.OrCheck(rules) self.assertEqual(check('target', 'cred', None), False) - rules[0].assert_called_once_with('target', 'cred') - rules[1].assert_called_once_with('target', 'cred') + rules[0].assert_called_once_with('target', 'cred', None) + rules[1].assert_called_once_with('target', 'cred', None) def test_call_first_true(self): rules = [mock.Mock(return_value=True), mock.Mock(return_value=False)] check = policy.OrCheck(rules) self.assertEqual(check('target', 'cred', None), True) - rules[0].assert_called_once_with('target', 'cred') + rules[0].assert_called_once_with('target', 'cred', None) self.assertFalse(rules[1].called) def test_call_second_true(self): @@ -336,8 +412,8 @@ class OrCheckTestCase(utils.BaseTestCase): check = policy.OrCheck(rules) self.assertEqual(check('target', 'cred', None), True) - rules[0].assert_called_once_with('target', 'cred') - rules[1].assert_called_once_with('target', 'cred') + rules[0].assert_called_once_with('target', 'cred', None) + rules[1].assert_called_once_with('target', 'cred', None) class ParseCheckTestCase(utils.BaseTestCase): diff --git a/tests/unit/test_processutils.py b/tests/unit/test_processutils.py index 7c6e11c..0f9289c 100644 --- a/tests/unit/test_processutils.py +++ b/tests/unit/test_processutils.py @@ -24,10 +24,10 @@ import tempfile import six from openstack.common import processutils -from tests import utils +from openstack.common import test -class UtilsTest(utils.BaseTestCase): +class UtilsTest(test.BaseTestCase): # NOTE(jkoelker) Moar tests from nova need to be ported. But they # need to be mock'd out. Currently they requre actually # running code. @@ -37,27 +37,27 @@ class UtilsTest(utils.BaseTestCase): hozer=True) -class ProcessExecutionErrorTest(utils.BaseTestCase): +class ProcessExecutionErrorTest(test.BaseTestCase): def test_defaults(self): err = processutils.ProcessExecutionError() - self.assertTrue('None\n' in err.message) - self.assertTrue('code: -\n' in err.message) + self.assertTrue('None\n' in unicode(err)) + self.assertTrue('code: -\n' in unicode(err)) def test_with_description(self): description = 'The Narwhal Bacons at Midnight' err = processutils.ProcessExecutionError(description=description) - self.assertTrue(description in err.message) + self.assertTrue(description in unicode(err)) def test_with_exit_code(self): exit_code = 0 err = processutils.ProcessExecutionError(exit_code=exit_code) - self.assertTrue(str(exit_code) in err.message) + self.assertTrue(str(exit_code) in unicode(err)) def test_with_cmd(self): cmd = 'telinit' err = processutils.ProcessExecutionError(cmd=cmd) - self.assertTrue(cmd in err.message) + self.assertTrue(cmd in unicode(err)) def test_with_stdout(self): stdout = """ @@ -80,13 +80,13 @@ class ProcessExecutionErrorTest(utils.BaseTestCase): the Wielder of Wonder, with world's renown. """.strip() err = processutils.ProcessExecutionError(stdout=stdout) - print(err.message) - self.assertTrue('people-kings' in err.message) + print(unicode(err)) + self.assertTrue('people-kings' in unicode(err)) def test_with_stderr(self): stderr = 'Cottonian library' err = processutils.ProcessExecutionError(stderr=stderr) - self.assertTrue(stderr in str(err.message)) + self.assertTrue(stderr in unicode(err)) def test_retry_on_failure(self): fd, tmpfilename = tempfile.mkstemp() @@ -127,8 +127,7 @@ exit 1 'always get passed ' 'correctly') runs = int(runs.strip()) - self.assertEquals(runs, 10, - 'Ran %d times instead of 10.' % (runs,)) + self.assertEqual(runs, 10, 'Ran %d times instead of 10.' % (runs,)) finally: os.unlink(tmpfilename) os.unlink(tmpfilename2) @@ -181,7 +180,7 @@ def fake_execute_raises(*cmd, **kwargs): 'command']) -class TryCmdTestCase(utils.BaseTestCase): +class TryCmdTestCase(test.BaseTestCase): def test_keep_warnings(self): self.useFixture(fixtures.MonkeyPatch( 'openstack.common.processutils.execute', fake_execute)) @@ -231,7 +230,7 @@ class FakeSshConnection(object): six.StringIO('stderr')) -class SshExecuteTestCase(utils.BaseTestCase): +class SshExecuteTestCase(test.BaseTestCase): def test_invalid_addl_env(self): self.assertRaises(processutils.InvalidArgumentError, processutils.ssh_execute, diff --git a/tests/unit/test_quota.py b/tests/unit/test_quota.py new file mode 100644 index 0000000..8c26369 --- /dev/null +++ b/tests/unit/test_quota.py @@ -0,0 +1,441 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import datetime +import mock +from openstack.common import quota +from oslo.config import cfg +from tests import utils + +CONF = cfg.CONF + + +class FakeContext(object): + project_id = 'p1' + user_id = 'u1' + quota_class = 'QuotaClass_' + + def elevated(self): + return self + + +class ExceptionTestCase(utils.BaseTestCase): + + def _get_raised_exception(self, exception, *args, **kwargs): + try: + raise exception(*args, **kwargs) + except Exception as e: + return e + + def test_quota_exception_format(self): + + class TestException(quota.QuotaException): + msg_fmt = "Test format %(string)s" + + e = self._get_raised_exception(TestException) + self.assertEqual(e.message, e.msg_fmt) + + e = self._get_raised_exception(TestException, number=42) + self.assertEqual(e.message, e.msg_fmt) + + e = self._get_raised_exception(TestException, string="test") + self.assertEqual(e.message, e.msg_fmt % {"string": "test"}) + + +class DbQuotaDriverTestCase(utils.BaseTestCase): + + def setUp(self): + self.sample_resources = {'r1': quota.BaseResource('r1'), + 'r2': quota.BaseResource('r2')} + + dbapi = mock.Mock() + dbapi.quota_usage_get_all_by_project_and_user = mock.Mock( + return_value={'project_id': 'p1', 'user_id': 'u1', + 'r1': {'reserved': 1, 'in_use': 2}, + 'r2': {'reserved': 2, 'in_use': 3}}) + dbapi.quota_get_all_by_project_and_user = mock.Mock( + return_value={'project_id': 'p1', 'user_id': 'u1', + 'r1': 5, 'r2': 6}) + dbapi.quota_get = mock.Mock(return_value='quota_get') + dbapi.quota_reserve = mock.Mock(return_value='quota_reserve') + dbapi.quota_class_get = mock.Mock(return_value='quota_class_get') + dbapi.quota_class_reserve = mock.Mock( + return_value='quota_class_reserve') + dbapi.quota_class_get_default = mock.Mock( + return_value={'r1': 1, 'r2': 2}) + dbapi.quota_class_get_all_by_name = mock.Mock(return_value={'r1': 9}) + dbapi.quota_get_all_by_project = mock.Mock( + return_value=dict([('r%d' % i, i) for i in range(3)])) + dbapi.quota_get_all = mock.Mock( + return_value=[{'resource': 'r1', 'hard_limit': 3}, + {'resource': 'r2', 'hard_limit': 4}]) + dbapi.quota_usage_get_all_by_project = mock.Mock( + return_value=dict([('r%d' % i, {'in_use': i, 'reserved': i + 1}) + for i in range(3)])) + self.driver = quota.DbQuotaDriver(dbapi) + self.ctxt = FakeContext() + return super(DbQuotaDriverTestCase, self).setUp() + + def test_get_by_project(self): + args = ['p1', 'resource'] + self.assertEqual('quota_get', + self.driver.get_by_project(self.ctxt, *args)) + self.driver.db.quota_get.assert_called_once_with(self.ctxt, *args) + + def test_get_by_project_and_user(self): + args = ['p1', 'u1', 'resource'] + self.assertEqual('quota_get', + self.driver.get_by_project_and_user(self.ctxt, *args)) + self.driver.db.quota_get.assert_called_once_with(self.ctxt, *args) + + def test_get_by_class(self): + args = ['class', 'resource'] + self.assertEqual('quota_class_get', + self.driver.get_by_class(self.ctxt, *args)) + self.driver.db.quota_class_get.assert_called_once_with(self.ctxt, + *args) + + def test_get_defaults(self): + defaults = self.driver.get_defaults(self.ctxt, self.sample_resources) + self.assertEqual(defaults, {'r1': 1, 'r2': 2}) + self.sample_resources.pop('r1') + defaults = self.driver.get_defaults(self.ctxt, self.sample_resources) + self.assertEqual(defaults, {'r2': 2}) + + def test_get_class_quotas(self): + quotas = self.driver.get_class_quotas(self.ctxt, + self.sample_resources, + 'ClassName') + self.assertEqual(quotas, {'r1': 9, 'r2': 2}) + + def test_get_user_quotas(self): + actual = self.driver.get_user_quotas( + self.ctxt, self.sample_resources.copy(), 'p1', 'u1') + expected = {'r1': {'in_use': 2, 'limit': 5, 'reserved': 1}, + 'r2': {'in_use': 3, 'limit': 6, 'reserved': 2}} + self.assertEqual(actual, expected) + + def test_get_settable_quotas(self): + actual = self.driver.get_settable_quotas(self.ctxt, + self.sample_resources, 'p1') + expected = {'r1': {'maximum': -1, 'minimum': 3}, + 'r2': {'maximum': -1, 'minimum': 5}} + self.assertEqual(actual, expected) + + def test_get_settable_quotas_with_user_id(self): + actual = self.driver.get_settable_quotas( + self.ctxt, self.sample_resources, 'p1', user_id='u1') + expected = {'r1': {'maximum': 3, 'minimum': 3}, + 'r2': {'maximum': 4, 'minimum': 5}} + self.assertEqual(actual, expected) + + def test_get_project_quotas(self): + self.ctxt.quota_class = 'ClassName' + expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2}, + 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}} + quotas = self.driver.get_project_quotas(self.ctxt, + self.sample_resources, 'p1') + self.assertEqual(quotas, expected) + + def test_get_project_quotas_project_id_differs(self): + self.ctxt.project_id = 'p2' + expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2}, + 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}} + quotas = self.driver.get_project_quotas(self.ctxt, + self.sample_resources, 'p1') + self.assertEqual(quotas, expected) + + def test_get_project_quotas_omit_default_quota_class(self): + self.sample_resources['r3'] = quota.BaseResource('r3') + quotas = self.driver.get_project_quotas( + self.ctxt, self.sample_resources, 'p1', defaults=False) + expected = {'r1': {'limit': 1, 'in_use': 1, 'reserved': 2}, + 'r2': {'limit': 2, 'in_use': 2, 'reserved': 3}} + self.assertEqual(quotas, expected) + + def test_limit_check_invalid_quota_value(self): + self.assertRaises(quota.InvalidQuotaValue, + self.driver.limit_check, self.ctxt, [], {'r1': -1}) + + def test_limit_check_quota_resource_unknown(self): + self.assertRaises(quota.QuotaResourceUnknown, + self.driver.limit_check, + self.ctxt, + {'r1': quota.ReservableResource('r1', 'r1')}, + {'r1': 42}) + + def test_limit_check_over_quota(self): + self.assertRaises(quota.OverQuota, + self.driver.limit_check, + self.ctxt, + {'r1': quota.BaseResource('r1')}, + {'r1': 2}) + + def test_limit_check(self): + self.assertIsNone(self.driver.limit_check( + self.ctxt, {'r1': quota.BaseResource('r1')}, {'r1': 1})) + + def test_quota_reserve(self): + now = datetime.datetime.utcnow() + + class FakeTimeutils(object): + @staticmethod + def utcnow(): + return now + + self.stubs.Set(quota, "timeutils", FakeTimeutils) + + expected = [self.ctxt, self.sample_resources, {}, {}, {}, None, + CONF.until_refresh, CONF.max_age] + + # expire as None + self.assertEqual('quota_reserve', self.driver.reserve( + self.ctxt, self.sample_resources, {}, None, 'p1')) + expected[5] = now + datetime.timedelta( + seconds=CONF.reservation_expire) + self.driver.db.quota_reserve.assert_called_once_with(*expected, + project_id='p1', + user_id='u1') + self.driver.db.reset_mock() + # expire as seconds + self.assertEqual('quota_reserve', self.driver.reserve( + self.ctxt, self.sample_resources, {}, 42, 'p1')) + expected[5] = now + datetime.timedelta(seconds=42) + self.driver.db.quota_reserve.assert_called_once_with(*expected, + project_id='p1', + user_id='u1') + self.driver.db.reset_mock() + # expire as absolute + expected[5] = now + datetime.timedelta(hours=1) + self.assertEqual('quota_reserve', self.driver.reserve( + self.ctxt, self.sample_resources, {}, + now + datetime.timedelta(hours=1), 'p1')) + self.driver.db.quota_reserve.assert_called_once_with(*expected, + project_id='p1', + user_id='u1') + self.driver.db.reset_mock() + # InvalidReservationExpiration + self.assertRaises(quota.InvalidReservationExpiration, + self.driver.reserve, self.ctxt, + self.sample_resources, {}, (), 'p1') + self.driver.db.reset_mock() + # project_id is None + self.assertEqual('quota_reserve', self.driver.reserve( + self.ctxt, self.sample_resources, {}, + now + datetime.timedelta(hours=1))) + self.driver.db.quota_reserve.assert_called_once_with(*expected, + project_id='p1', + user_id='u1') + + def test_commit(self): + self.assertIsNone(self.driver.commit(self.ctxt, 'reservations', + project_id='p1')) + self.driver.db.reservation_commit.assert_called_once_with( + self.ctxt, 'reservations', project_id='p1', user_id='u1') + + def test_commit_project_id_none(self): + self.assertIsNone(self.driver.commit(self.ctxt, 'reservations')) + self.driver.db.reservation_commit.assert_called_once_with( + self.ctxt, 'reservations', project_id='p1', user_id='u1') + + def test_rollback(self): + self.assertIsNone(self.driver.rollback(self.ctxt, 'reservations', + project_id='p1')) + self.driver.db.reservation_rollback.assert_called_once_with( + self.ctxt, 'reservations', project_id='p1', user_id='u1') + + def test_rollback_project_id_none(self): + self.assertIsNone(self.driver.rollback(self.ctxt, 'reservations')) + self.driver.db.reservation_rollback.assert_called_once_with( + self.ctxt, 'reservations', project_id='p1', user_id='u1') + + def test_usage_reset(self): + resource = self.sample_resources['r1'] + self.assertIsNone(self.driver.usage_reset(self.ctxt, [resource])) + self.driver.db.quota_usage_update.assert_called_once_with( + self.ctxt, 'p1', 'u1', resource, in_use=-1) + + def test_usage_reset_quota_usage_not_found(self): + resource = self.sample_resources['r1'] + self.driver.db.quota_usage_update = mock.Mock( + side_effect=quota.QuotaUsageNotFound) + self.assertIsNone(self.driver.usage_reset(self.ctxt, [resource])) + self.driver.db.quota_usage_update.assert_called_once_with( + self.ctxt, 'p1', 'u1', resource, in_use=-1) + + def test_destroy_all_by_project_and_user(self): + self.assertIsNone(self.driver.destroy_all_by_project_and_user( + self.ctxt, 'p1', 'u1')) + method = self.driver.db.quota_destroy_all_by_project_and_user + method.assert_called_once_with(self.ctxt, 'p1', 'u1') + + def test_destroy_all_by_project(self): + self.assertIsNone(self.driver.destroy_all_by_project(self.ctxt, 'p1')) + self.driver.db.quota_destroy_all_by_project.assert_called_once_with( + self.ctxt, 'p1') + + def test_expire(self): + self.assertIsNone(self.driver.expire(self.ctxt)) + self.driver.db.reservation_expire.assert_called_once_with(self.ctxt) + + +class BaseResourceTestCase(utils.BaseTestCase): + + def setUp(self): + self.ctxt = FakeContext() + self.dbapi = mock.Mock() + self.dbapi.quota_get = mock.Mock(return_value='quota_get') + self.dbapi.quota_class_get = mock.Mock( + return_value='quota_class_get') + self.dbapi.quota_class_get_default = mock.Mock( + return_value={'r1': 1}) + self.driver = quota.DbQuotaDriver(self.dbapi) + super(BaseResourceTestCase, self).setUp() + + def test_quota(self): + resource = quota.BaseResource('r1') + self.assertEqual('quota_get', resource.quota(self.driver, self.ctxt)) + + def test_quota_no_project_id(self): + self.ctxt.project_id = None + resource = quota.BaseResource('r1') + self.assertEqual('quota_class_get', + resource.quota(self.driver, self.ctxt)) + + def test_quota_project_quota_not_found(self): + self.dbapi.quota_get = mock.Mock( + side_effect=quota.ProjectQuotaNotFound()) + resource = quota.BaseResource('r1') + self.assertEqual('quota_class_get', + resource.quota(self.driver, self.ctxt)) + + def test_quota_quota_class_not_found(self): + self.dbapi.quota_get = mock.Mock( + side_effect=quota.ProjectQuotaNotFound(project_id='p1')) + self.dbapi.quota_class_get = mock.Mock( + side_effect=quota.QuotaClassNotFound(class_name='ClassName')) + resource = quota.BaseResource('r1') + self.assertEqual(1, resource.quota(self.driver, self.ctxt)) + + +class CountableResourceTestCase(utils.BaseTestCase): + + def test_init(self): + resource = quota.CountableResource('r1', 42) + self.assertEqual('r1', resource.name) + self.assertEqual(42, resource.count) + + +class QuotaEngineTestCase(utils.BaseTestCase): + + def setUp(self): + self.ctxt = FakeContext() + self.dbapi = mock.Mock() + self.quota_driver = mock.Mock() + self.engine = quota.QuotaEngine(self.dbapi, self.quota_driver) + self.r1 = quota.BaseResource('r1') + self.r2 = quota.BaseResource('r2') + self.engine.register_resources([self.r1, self.r2]) + super(QuotaEngineTestCase, self).setUp() + + def assertProxyMethod(self, method, *args, **kwargs): + if 'retval' in kwargs: + retval = kwargs.pop('retval') + else: + retval = method + setattr(self.quota_driver, method, mock.Mock(return_value=method)) + actual = getattr(self.engine, method)(self.ctxt, *args, **kwargs) + getattr(self.quota_driver, method).assert_called_once_with(self.ctxt, + *args, + **kwargs) + self.assertEqual(actual, retval) + + def assertMethod(self, method, args, kwargs, called_args, + called_kwargs, retval): + setattr(self.quota_driver, method, mock.Mock(return_value=method)) + actual = getattr(self.engine, method)(self.ctxt, *args, **kwargs) + getattr(self.quota_driver, method).assert_called_once_with( + self.ctxt, *called_args, **called_kwargs) + self.assertEqual(actual, retval) + + def test_proxy_methods(self): + self.assertProxyMethod('get_by_project', 'p1', 'resname') + self.assertProxyMethod('get_by_project_and_user', 'p1', 'u1', 'res') + self.assertProxyMethod('get_by_class', 'quota_class', 'resname') + self.assertProxyMethod('get_default', 'resource') + self.assertProxyMethod('expire', retval=None) + self.assertProxyMethod('usage_reset', 'resources', retval=None) + self.assertProxyMethod('destroy_all_by_project', 'p1', retval=None) + self.assertProxyMethod('destroy_all_by_project_and_user', 'p1', + 'u1', retval=None) + self.assertProxyMethod('commit', 'reservations', project_id='p1', + user_id='u1', retval=None) + self.assertProxyMethod('rollback', 'reservations', project_id='p1', + user_id=None, retval=None) + + self.assertMethod('get_settable_quotas', ['p1'], {'user_id': 'u1'}, + [self.engine.resources, 'p1'], {'user_id': 'u1'}, + 'get_settable_quotas') + self.assertMethod('get_defaults', [], {}, + [self.engine.resources], {}, 'get_defaults') + self.assertMethod('get_project_quotas', ['p1', 'quotaclass'], + {'defaults': 'defaults', 'usages': 'usages'}, + [self.engine.resources, 'p1'], + {'quota_class': 'quotaclass', 'defaults': 'defaults', + 'usages': 'usages', 'remains': False}, + 'get_project_quotas') + self.assertMethod('reserve', [], + {'expire': 'expire', 'project_id': 'p1', + 'user_id': 'u1', 'deltas': 'd1'}, + [self.engine.resources, {'deltas': 'd1'}], + {'expire': 'expire', + 'project_id': 'p1', 'user_id': 'u1'}, 'reserve') + self.assertMethod('get_class_quotas', + ['quota_class'], {'defaults': 'defaults'}, + [self.engine.resources, 'quota_class'], + {'defaults': 'defaults'}, 'get_class_quotas') + self.assertMethod('get_user_quotas', ['project_id', 'user_id'], + {'quota_class': 'qc', 'defaults': 'de', + 'usages': 'us'}, + [self.engine.resources, 'project_id', 'user_id'], + {'quota_class': 'qc', 'defaults': 'de', + 'usages': 'us'}, + 'get_user_quotas') + self.assertMethod('limit_check', + [], {'project_id': 'p1', 'user_id': 'u1', + 'val1': 'val1'}, + [self.engine.resources, {'val1': 'val1'}], + {'project_id': 'p1', 'user_id': 'u1'}, 'limit_check') + + def test_resource_names(self): + self.assertEqual(['r1', 'r2'], self.engine.resource_names) + + def test_contains(self): + self.assertTrue(self.r1.name in self.engine) + self.assertTrue(self.r2.name in self.engine) + self.assertFalse('r3' in self.engine) + + def test_count(self): + count = mock.Mock(return_value=42) + r = quota.CountableResource('r1', count) + self.engine.register_resource(r) + actual = self.engine.count(self.ctxt, 'r1') + self.assertEqual(42, actual) + count.assert_called_once_with(self.ctxt) + self.assertRaises(quota.QuotaResourceUnknown, + self.engine.count, self.ctxt, 'r2') + + def test_init(self): + engine = quota.QuotaEngine(self.dbapi) + self.assertIsInstance(engine._driver, quota.DbQuotaDriver) diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index 4f742ce..c7d18f6 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -67,7 +67,41 @@ class ServiceWithTimer(service.Service): self.timer_fired = self.timer_fired + 1 -class ServiceLauncherTest(utils.BaseTestCase): +class ServiceTestBase(utils.BaseTestCase): + """A base class for ServiceLauncherTest and ServiceRestartTest.""" + + def _wait(self, cond, timeout): + start = time.time() + while not cond(): + if time.time() - start > timeout: + break + time.sleep(.1) + + def setUp(self): + super(ServiceTestBase, self).setUp() + # FIXME(markmc): Ugly hack to workaround bug #1073732 + CONF.unregister_opts(notifier_api.notifier_opts) + # NOTE(markmc): ConfigOpts.log_opt_values() uses CONF.config-file + CONF(args=[], default_config_files=[]) + self.addCleanup(CONF.reset) + self.addCleanup(CONF.register_opts, notifier_api.notifier_opts) + self.addCleanup(self._reap_pid) + + def _reap_pid(self): + if self.pid: + # Make sure all processes are stopped + os.kill(self.pid, signal.SIGTERM) + + # Make sure we reap our test process + self._reap_test() + + def _reap_test(self): + pid, status = os.waitpid(self.pid, 0) + self.pid = None + return status + + +class ServiceLauncherTest(ServiceTestBase): """Originally from nova/tests/integrated/test_multiprocess_api.py.""" def _spawn(self): @@ -111,38 +145,6 @@ class ServiceLauncherTest(utils.BaseTestCase): self.assertEqual(len(workers), self.workers) return workers - def _wait(self, cond, timeout): - start = time.time() - while True: - if cond(): - break - if time.time() - start > timeout: - break - time.sleep(.1) - - def setUp(self): - super(ServiceLauncherTest, self).setUp() - # FIXME(markmc): Ugly hack to workaround bug #1073732 - CONF.unregister_opts(notifier_api.notifier_opts) - # NOTE(markmc): ConfigOpts.log_opt_values() uses CONF.config-file - CONF(args=[], default_config_files=[]) - self.addCleanup(CONF.reset) - self.addCleanup(CONF.register_opts, notifier_api.notifier_opts) - self.addCleanup(self._reap_pid) - - def _reap_pid(self): - if self.pid: - # Make sure all processes are stopped - os.kill(self.pid, signal.SIGTERM) - - # Make sure we reap our test process - self._reap_test() - - def _reap_test(self): - pid, status = os.waitpid(self.pid, 0) - self.pid = None - return status - def _get_workers(self): f = os.popen('ps ax -o pid,ppid,command') # Skip ps header @@ -195,6 +197,89 @@ class ServiceLauncherTest(utils.BaseTestCase): self.assertTrue(os.WIFEXITED(status)) self.assertEqual(os.WEXITSTATUS(status), 0) + def test_child_signal_sighup(self): + start_workers = self._spawn() + + os.kill(start_workers[0], signal.SIGHUP) + # Wait at most 5 seconds to respawn a worker + cond = lambda: start_workers == self._get_workers() + timeout = 5 + self._wait(cond, timeout) + + # Make sure worker pids match + end_workers = self._get_workers() + LOG.info('workers: %r' % end_workers) + self.assertEqual(start_workers, end_workers) + + def test_parent_signal_sighup(self): + start_workers = self._spawn() + + os.kill(self.pid, signal.SIGHUP) + # Wait at most 5 seconds to respawn a worker + cond = lambda: start_workers == self._get_workers() + timeout = 5 + self._wait(cond, timeout) + + # Make sure worker pids match + end_workers = self._get_workers() + LOG.info('workers: %r' % end_workers) + self.assertEqual(start_workers, end_workers) + + +class ServiceRestartTest(ServiceTestBase): + + def _check_process_alive(self): + f = os.popen('ps ax -o pid,stat,cmd') + f.readline() + pid_stat = [tuple(p for p in line.strip().split()[:2]) + for line in f.readlines()] + for p, stat in pid_stat: + if int(p) == self.pid: + return stat not in ['Z', 'T', 'Z+'] + return False + + def _spawn_service(self): + pid = os.fork() + status = 0 + if pid == 0: + try: + serv = ServiceWithTimer() + launcher = service.ServiceLauncher() + launcher.launch_service(serv) + launcher.wait() + except SystemExit as exc: + status = exc.code + os._exit(status) + self.pid = pid + + def test_service_restart(self): + self._spawn_service() + + cond = self._check_process_alive + timeout = 5 + self._wait(cond, timeout) + + ret = self._check_process_alive() + self.assertTrue(ret) + + os.kill(self.pid, signal.SIGHUP) + self._wait(cond, timeout) + + ret_restart = self._check_process_alive() + self.assertTrue(ret_restart) + + def test_terminate_sigterm(self): + self._spawn_service() + cond = self._check_process_alive + timeout = 5 + self._wait(cond, timeout) + + os.kill(self.pid, signal.SIGTERM) + + status = self._reap_test() + self.assertTrue(os.WIFEXITED(status)) + self.assertEqual(os.WEXITSTATUS(status), 0) + class _Service(service.Service): def __init__(self): diff --git a/tests/unit/test_sslutils.py b/tests/unit/test_sslutils.py index 4c0646e..2095d3c 100644 --- a/tests/unit/test_sslutils.py +++ b/tests/unit/test_sslutils.py @@ -17,39 +17,38 @@ import ssl from openstack.common import sslutils -from tests import utils +from openstack.common import test -class SSLUtilsTest(utils.BaseTestCase): +class SSLUtilsTest(test.BaseTestCase): def test_valid_versions(self): - self.assertEquals(sslutils.validate_ssl_version("SSLv3"), - ssl.PROTOCOL_SSLv3) - self.assertEquals(sslutils.validate_ssl_version("SSLv23"), - ssl.PROTOCOL_SSLv23) - self.assertEquals(sslutils.validate_ssl_version("TLSv1"), - ssl.PROTOCOL_TLSv1) + self.assertEqual(sslutils.validate_ssl_version("SSLv3"), + ssl.PROTOCOL_SSLv3) + self.assertEqual(sslutils.validate_ssl_version("SSLv23"), + ssl.PROTOCOL_SSLv23) + self.assertEqual(sslutils.validate_ssl_version("TLSv1"), + ssl.PROTOCOL_TLSv1) try: protocol = ssl.PROTOCOL_SSLv2 except AttributeError: pass else: - self.assertEquals(sslutils.validate_ssl_version("SSLv2"), - protocol) + self.assertEqual(sslutils.validate_ssl_version("SSLv2"), protocol) def test_lowercase_valid_versions(self): - self.assertEquals(sslutils.validate_ssl_version("sslv3"), - ssl.PROTOCOL_SSLv3) - self.assertEquals(sslutils.validate_ssl_version("sslv23"), - ssl.PROTOCOL_SSLv23) - self.assertEquals(sslutils.validate_ssl_version("tlsv1"), - ssl.PROTOCOL_TLSv1) + self.assertEqual(sslutils.validate_ssl_version("sslv3"), + ssl.PROTOCOL_SSLv3) + self.assertEqual(sslutils.validate_ssl_version("sslv23"), + ssl.PROTOCOL_SSLv23) + self.assertEqual(sslutils.validate_ssl_version("tlsv1"), + ssl.PROTOCOL_TLSv1) try: protocol = ssl.PROTOCOL_SSLv2 except AttributeError: pass else: - self.assertEquals(sslutils.validate_ssl_version("sslv2"), - protocol) + self.assertEqual(sslutils.validate_ssl_version("sslv2"), + protocol) def test_invalid_version(self): self.assertRaises(RuntimeError, diff --git a/tests/unit/test_strutils.py b/tests/unit/test_strutils.py index a8d8462..22aafd9 100644 --- a/tests/unit/test_strutils.py +++ b/tests/unit/test_strutils.py @@ -20,10 +20,10 @@ import mock import six from openstack.common import strutils -from tests import utils +from openstack.common import test -class StrUtilsTest(utils.BaseTestCase): +class StrUtilsTest(test.BaseTestCase): def test_bool_bool_from_string(self): self.assertTrue(strutils.bool_from_string(True)) @@ -188,11 +188,11 @@ class StrUtilsTest(utils.BaseTestCase): } for (in_value, expected_value) in working_examples.items(): b_value = strutils.to_bytes(in_value) - self.assertEquals(expected_value, b_value) + self.assertEqual(expected_value, b_value) if in_value: in_value = "-" + in_value b_value = strutils.to_bytes(in_value) - self.assertEquals(expected_value * -1, b_value) + self.assertEqual(expected_value * -1, b_value) breaking_examples = [ 'junk1KB', '1023BBBB', ] diff --git a/tests/unit/test_threadgroup.py b/tests/unit/test_threadgroup.py index 5af6653..2273800 100644 --- a/tests/unit/test_threadgroup.py +++ b/tests/unit/test_threadgroup.py @@ -20,13 +20,13 @@ Unit Tests for thread groups """ from openstack.common import log as logging +from openstack.common import test from openstack.common import threadgroup -from tests import utils LOG = logging.getLogger(__name__) -class ThreadGroupTestCase(utils.BaseTestCase): +class ThreadGroupTestCase(test.BaseTestCase): """Test cases for thread group.""" def setUp(self): super(ThreadGroupTestCase, self).setUp() diff --git a/tests/unit/test_timeutils.py b/tests/unit/test_timeutils.py index bfab278..792c0aa 100644 --- a/tests/unit/test_timeutils.py +++ b/tests/unit/test_timeutils.py @@ -21,11 +21,11 @@ import datetime import iso8601 import mock +from openstack.common import test from openstack.common import timeutils -from tests import utils -class TimeUtilsTest(utils.BaseTestCase): +class TimeUtilsTest(test.BaseTestCase): def setUp(self): super(TimeUtilsTest, self).setUp() @@ -181,16 +181,16 @@ class TimeUtilsTest(utils.BaseTestCase): self.assertTrue(timeutils.is_soon(expires, 0)) -class TestIso8601Time(utils.BaseTestCase): +class TestIso8601Time(test.BaseTestCase): def _instaneous(self, timestamp, yr, mon, day, hr, min, sec, micro): - self.assertEquals(timestamp.year, yr) - self.assertEquals(timestamp.month, mon) - self.assertEquals(timestamp.day, day) - self.assertEquals(timestamp.hour, hr) - self.assertEquals(timestamp.minute, min) - self.assertEquals(timestamp.second, sec) - self.assertEquals(timestamp.microsecond, micro) + self.assertEqual(timestamp.year, yr) + self.assertEqual(timestamp.month, mon) + self.assertEqual(timestamp.day, day) + self.assertEqual(timestamp.hour, hr) + self.assertEqual(timestamp.minute, min) + self.assertEqual(timestamp.second, sec) + self.assertEqual(timestamp.microsecond, micro) def _do_test(self, str, yr, mon, day, hr, min, sec, micro, shift): DAY_SECONDS = 24 * 60 * 60 @@ -246,26 +246,26 @@ class TestIso8601Time(utils.BaseTestCase): def test_zulu_roundtrip(self): str = '2012-02-14T20:53:07Z' zulu = timeutils.parse_isotime(str) - self.assertEquals(zulu.tzinfo, iso8601.iso8601.UTC) - self.assertEquals(timeutils.isotime(zulu), str) + self.assertEqual(zulu.tzinfo, iso8601.iso8601.UTC) + self.assertEqual(timeutils.isotime(zulu), str) def test_east_roundtrip(self): str = '2012-02-14T20:53:07-07:00' east = timeutils.parse_isotime(str) - self.assertEquals(east.tzinfo.tzname(None), '-07:00') - self.assertEquals(timeutils.isotime(east), str) + self.assertEqual(east.tzinfo.tzname(None), '-07:00') + self.assertEqual(timeutils.isotime(east), str) def test_west_roundtrip(self): str = '2012-02-14T20:53:07+11:30' west = timeutils.parse_isotime(str) - self.assertEquals(west.tzinfo.tzname(None), '+11:30') - self.assertEquals(timeutils.isotime(west), str) + self.assertEqual(west.tzinfo.tzname(None), '+11:30') + self.assertEqual(timeutils.isotime(west), str) def test_now_roundtrip(self): str = timeutils.isotime() now = timeutils.parse_isotime(str) - self.assertEquals(now.tzinfo, iso8601.iso8601.UTC) - self.assertEquals(timeutils.isotime(now), str) + self.assertEqual(now.tzinfo, iso8601.iso8601.UTC) + self.assertEqual(timeutils.isotime(now), str) def test_zulu_normalize(self): str = '2012-02-14T20:53:07Z' diff --git a/tests/unit/test_uuidutils.py b/tests/unit/test_uuidutils.py index e9348e2..4af5a8e 100644 --- a/tests/unit/test_uuidutils.py +++ b/tests/unit/test_uuidutils.py @@ -17,11 +17,11 @@ import uuid +from openstack.common import test from openstack.common import uuidutils -from tests import utils -class UUIDUtilsTest(utils.BaseTestCase): +class UUIDUtilsTest(test.BaseTestCase): def test_generate_uuid(self): uuid_string = uuidutils.generate_uuid() diff --git a/tests/unit/test_xmlutils.py b/tests/unit/test_xmlutils.py index 5d2bd05..a5bd29f 100644 --- a/tests/unit/test_xmlutils.py +++ b/tests/unit/test_xmlutils.py @@ -16,11 +16,11 @@ from xml.dom import minidom +from openstack.common import test from openstack.common import xmlutils -from tests import utils -class XMLUtilsTestCase(utils.BaseTestCase): +class XMLUtilsTestCase(test.BaseTestCase): def test_safe_parse_xml(self): normal_body = (""" @@ -55,7 +55,7 @@ class XMLUtilsTestCase(utils.BaseTestCase): killer_body()) -class SafeParserTestCase(utils.BaseTestCase): +class SafeParserTestCase(test.BaseTestCase): def test_external_dtd(self): xml_string = ("""<?xml version="1.0" encoding="utf-8"?> <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" diff --git a/tests/utils.py b/tests/utils.py index e93c278..b0770a9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,7 +24,6 @@ import fixtures from oslo.config import cfg import testtools -from openstack.common import exception from openstack.common.fixture import moxstubout @@ -38,8 +37,24 @@ class BaseTestCase(testtools.TestCase): self.conf = conf self.addCleanup(self.conf.reset) self.useFixture(fixtures.FakeLogger('openstack.common')) - self.useFixture(fixtures.Timeout(30, True)) - self.stubs.Set(exception, '_FATAL_EXCEPTION_FORMAT_ERRORS', True) + + test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0) + try: + test_timeout = int(test_timeout) + except ValueError: + # If timeout value is invalid do not set a timeout. + test_timeout = 0 + if test_timeout > 0: + self.useFixture(fixtures.Timeout(test_timeout, gentle=True)) + if (os.environ.get('OS_STDOUT_CAPTURE') == 'True' or + os.environ.get('OS_STDOUT_CAPTURE') == '1'): + stdout = self.useFixture(fixtures.StringStream('stdout')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout)) + if (os.environ.get('OS_STDERR_CAPTURE') == 'True' or + os.environ.get('OS_STDERR_CAPTURE') == '1'): + stderr = self.useFixture(fixtures.StringStream('stderr')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) + self.useFixture(fixtures.NestedTempfile()) self.tempdirs = [] diff --git a/tools/colorizer.py b/tools/colorizer.py new file mode 100755 index 0000000..13364ba --- /dev/null +++ b/tools/colorizer.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2013, Nebula, Inc. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# Colorizer Code is borrowed from Twisted: +# Copyright (c) 2001-2010 Twisted Matrix Laboratories. +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +"""Display a subunit stream through a colorized unittest test runner.""" + +import heapq +import subunit +import sys +import unittest + +import testtools + + +class _AnsiColorizer(object): + """Colorizer allows callers to write text in a particular color. + + A colorizer is an object that loosely wraps around a stream, allowing + callers to write text to the stream in a particular color. + + Colorizer classes must implement C{supported()} and C{write(text, color)}. + """ + _colors = dict(black=30, red=31, green=32, yellow=33, + blue=34, magenta=35, cyan=36, white=37) + + def __init__(self, stream): + self.stream = stream + + def supported(cls, stream=sys.stdout): + """Check is the current platform supports coloring terminal output. + + A class method that returns True if the current platform supports + coloring terminal output using this method. Returns False otherwise. + """ + if not stream.isatty(): + return False # auto color only on TTYs + try: + import curses + except ImportError: + return False + else: + try: + try: + return curses.tigetnum("colors") > 2 + except curses.error: + curses.setupterm() + return curses.tigetnum("colors") > 2 + except Exception: + # guess false in case of error + return False + supported = classmethod(supported) + + def write(self, text, color): + """Write the given text to the stream in the given color. + + @param text: Text to be written to the stream. + + @param color: A string label for a color. e.g. 'red', 'white'. + """ + color = self._colors[color] + self.stream.write('\x1b[%s;1m%s\x1b[0m' % (color, text)) + + +class _Win32Colorizer(object): + """See _AnsiColorizer docstring.""" + def __init__(self, stream): + import win32console + red, green, blue, bold = (win32console.FOREGROUND_RED, + win32console.FOREGROUND_GREEN, + win32console.FOREGROUND_BLUE, + win32console.FOREGROUND_INTENSITY) + self.stream = stream + self.screenBuffer = win32console.GetStdHandle( + win32console.STD_OUT_HANDLE) + self._colors = { + 'normal': red | green | blue, + 'red': red | bold, + 'green': green | bold, + 'blue': blue | bold, + 'yellow': red | green | bold, + 'magenta': red | blue | bold, + 'cyan': green | blue | bold, + 'white': red | green | blue | bold, + } + + def supported(cls, stream=sys.stdout): + try: + import win32console + screenBuffer = win32console.GetStdHandle( + win32console.STD_OUT_HANDLE) + except ImportError: + return False + import pywintypes + try: + screenBuffer.SetConsoleTextAttribute( + win32console.FOREGROUND_RED | + win32console.FOREGROUND_GREEN | + win32console.FOREGROUND_BLUE) + except pywintypes.error: + return False + else: + return True + supported = classmethod(supported) + + def write(self, text, color): + color = self._colors[color] + self.screenBuffer.SetConsoleTextAttribute(color) + self.stream.write(text) + self.screenBuffer.SetConsoleTextAttribute(self._colors['normal']) + + +class _NullColorizer(object): + """See _AnsiColorizer docstring.""" + def __init__(self, stream): + self.stream = stream + + def supported(cls, stream=sys.stdout): + return True + supported = classmethod(supported) + + def write(self, text, color): + self.stream.write(text) + + +def get_elapsed_time_color(elapsed_time): + if elapsed_time > 1.0: + return 'red' + elif elapsed_time > 0.25: + return 'yellow' + else: + return 'green' + + +class OpenStackTestResult(testtools.TestResult): + def __init__(self, stream, descriptions, verbosity): + super(OpenStackTestResult, self).__init__() + self.stream = stream + self.showAll = verbosity > 1 + self.num_slow_tests = 10 + self.slow_tests = [] # this is a fixed-sized heap + self.colorizer = None + # NOTE(vish): reset stdout for the terminal check + stdout = sys.stdout + sys.stdout = sys.__stdout__ + for colorizer in [_Win32Colorizer, _AnsiColorizer, _NullColorizer]: + if colorizer.supported(): + self.colorizer = colorizer(self.stream) + break + sys.stdout = stdout + self.start_time = None + self.last_time = {} + self.results = {} + self.last_written = None + + def _writeElapsedTime(self, elapsed): + color = get_elapsed_time_color(elapsed) + self.colorizer.write(" %.2f" % elapsed, color) + + def _addResult(self, test, *args): + try: + name = test.id() + except AttributeError: + name = 'Unknown.unknown' + test_class, test_name = name.rsplit('.', 1) + + elapsed = (self._now() - self.start_time).total_seconds() + item = (elapsed, test_class, test_name) + if len(self.slow_tests) >= self.num_slow_tests: + heapq.heappushpop(self.slow_tests, item) + else: + heapq.heappush(self.slow_tests, item) + + self.results.setdefault(test_class, []) + self.results[test_class].append((test_name, elapsed) + args) + self.last_time[test_class] = self._now() + self.writeTests() + + def _writeResult(self, test_name, elapsed, long_result, color, + short_result, success): + if self.showAll: + self.stream.write(' %s' % str(test_name).ljust(66)) + self.colorizer.write(long_result, color) + if success: + self._writeElapsedTime(elapsed) + self.stream.writeln() + else: + self.colorizer.write(short_result, color) + + def addSuccess(self, test): + super(OpenStackTestResult, self).addSuccess(test) + self._addResult(test, 'OK', 'green', '.', True) + + def addFailure(self, test, err): + if test.id() == 'process-returncode': + return + super(OpenStackTestResult, self).addFailure(test, err) + self._addResult(test, 'FAIL', 'red', 'F', False) + + def addError(self, test, err): + super(OpenStackTestResult, self).addFailure(test, err) + self._addResult(test, 'ERROR', 'red', 'E', False) + + def addSkip(self, test, reason=None, details=None): + super(OpenStackTestResult, self).addSkip(test, reason, details) + self._addResult(test, 'SKIP', 'blue', 'S', True) + + def startTest(self, test): + self.start_time = self._now() + super(OpenStackTestResult, self).startTest(test) + + def writeTestCase(self, cls): + if not self.results.get(cls): + return + if cls != self.last_written: + self.colorizer.write(cls, 'white') + self.stream.writeln() + for result in self.results[cls]: + self._writeResult(*result) + del self.results[cls] + self.stream.flush() + self.last_written = cls + + def writeTests(self): + time = self.last_time.get(self.last_written, self._now()) + if not self.last_written or (self._now() - time).total_seconds() > 2.0: + diff = 3.0 + while diff > 2.0: + classes = self.results.keys() + oldest = min(classes, key=lambda x: self.last_time[x]) + diff = (self._now() - self.last_time[oldest]).total_seconds() + self.writeTestCase(oldest) + else: + self.writeTestCase(self.last_written) + + def done(self): + self.stopTestRun() + + def stopTestRun(self): + for cls in list(self.results.iterkeys()): + self.writeTestCase(cls) + self.stream.writeln() + self.writeSlowTests() + + def writeSlowTests(self): + # Pare out 'fast' tests + slow_tests = [item for item in self.slow_tests + if get_elapsed_time_color(item[0]) != 'green'] + if slow_tests: + slow_total_time = sum(item[0] for item in slow_tests) + slow = ("Slowest %i tests took %.2f secs:" + % (len(slow_tests), slow_total_time)) + self.colorizer.write(slow, 'yellow') + self.stream.writeln() + last_cls = None + # sort by name + for elapsed, cls, name in sorted(slow_tests, + key=lambda x: x[1] + x[2]): + if cls != last_cls: + self.colorizer.write(cls, 'white') + self.stream.writeln() + last_cls = cls + self.stream.write(' %s' % str(name).ljust(68)) + self._writeElapsedTime(elapsed) + self.stream.writeln() + + def printErrors(self): + if self.showAll: + self.stream.writeln() + self.printErrorList('ERROR', self.errors) + self.printErrorList('FAIL', self.failures) + + def printErrorList(self, flavor, errors): + for test, err in errors: + self.colorizer.write("=" * 70, 'red') + self.stream.writeln() + self.colorizer.write(flavor, 'red') + self.stream.writeln(": %s" % test.id()) + self.colorizer.write("-" * 70, 'red') + self.stream.writeln() + self.stream.writeln("%s" % err) + + +test = subunit.ProtocolTestCase(sys.stdin, passthrough=None) + +if sys.version_info[0:2] <= (2, 6): + runner = unittest.TextTestRunner(verbosity=2) +else: + runner = unittest.TextTestRunner(verbosity=2, + resultclass=OpenStackTestResult) + +if runner.run(test).wasSuccessful(): + exit_code = 0 +else: + exit_code = 1 +sys.exit(exit_code) diff --git a/tools/config/generate_sample.sh b/tools/config/generate_sample.sh index 26f02dd..f5e5a67 100755 --- a/tools/config/generate_sample.sh +++ b/tools/config/generate_sample.sh @@ -64,6 +64,9 @@ FILES=$(find $BASEDIR/$PACKAGENAME -type f -name "*.py" ! -path "*/tests/*" \ export EVENTLET_NO_GREENDNS=yes +OS_VARS=$(set | sed -n '/^OS_/s/=[^=]*$//gp' | xargs) +[ "$OS_VARS" ] && eval "unset \$OS_VARS" + MODULEPATH=openstack.common.config.generator OUTPUTFILE=$OUTPUTDIR/$PACKAGENAME.conf.sample python -m $MODULEPATH $FILES > $OUTPUTFILE diff --git a/tools/install_venv.py b/tools/install_venv.py new file mode 100644 index 0000000..1d4e7e0 --- /dev/null +++ b/tools/install_venv.py @@ -0,0 +1,74 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Copyright 2010 OpenStack Foundation +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import sys + +import install_venv_common as install_venv # noqa + + +def print_help(venv, root): + help = """ + Openstack development environment setup is complete. + + Openstack development uses virtualenv to track and manage Python + dependencies while in development and testing. + + To activate the Openstack virtualenv for the extent of your current shell + session you can run: + + $ source %s/bin/activate + + Or, if you prefer, you can run commands in the virtualenv on a case by case + basis by running: + + $ %s/tools/with_venv.sh <your command> + + Also, make test will automatically use the virtualenv. + """ + print(help % (venv, root)) + + +def main(argv): + root = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + if os.environ.get('tools_path'): + root = os.environ['tools_path'] + venv = os.path.join(root, '.venv') + if os.environ.get('venv'): + venv = os.environ['venv'] + + pip_requires = os.path.join(root, 'requirements.txt') + test_requires = os.path.join(root, 'test-requirements.txt') + py_version = "python%s.%s" % (sys.version_info[0], sys.version_info[1]) + project = 'Openstack' + install = install_venv.InstallVenv(root, venv, pip_requires, test_requires, + py_version, project) + options = install.parse_args(argv) + install.check_python_version() + install.check_dependencies() + install.create_virtualenv(no_site_packages=options.no_site_packages) + install.install_dependencies() + install.post_process() + print_help(venv, root) + +if __name__ == '__main__': + main(sys.argv) diff --git a/tools/install_venv_common.py b/tools/install_venv_common.py index f428c1e..0999e2c 100644 --- a/tools/install_venv_common.py +++ b/tools/install_venv_common.py @@ -114,9 +114,10 @@ class InstallVenv(object): print('Installing dependencies with pip (this can take a while)...') # First things first, make sure our venv has the latest pip and - # setuptools. - self.pip_install('pip>=1.3') + # setuptools and pbr + self.pip_install('pip>=1.4') self.pip_install('setuptools') + self.pip_install('pbr') self.pip_install('-r', self.requirements) self.pip_install('-r', self.test_requirements) @@ -201,12 +202,13 @@ class Fedora(Distro): RHEL: https://bugzilla.redhat.com/958868 """ - # Install "patch" program if it's not there - if not self.check_pkg('patch'): - self.die("Please install 'patch'.") + if os.path.exists('contrib/redhat-eventlet.patch'): + # Install "patch" program if it's not there + if not self.check_pkg('patch'): + self.die("Please install 'patch'.") - # Apply the eventlet patch - self.apply_patch(os.path.join(self.venv, 'lib', self.py_version, - 'site-packages', - 'eventlet/green/subprocess.py'), - 'contrib/redhat-eventlet.patch') + # Apply the eventlet patch + self.apply_patch(os.path.join(self.venv, 'lib', self.py_version, + 'site-packages', + 'eventlet/green/subprocess.py'), + 'contrib/redhat-eventlet.patch') diff --git a/tools/run_tests_common.sh b/tools/run_tests_common.sh new file mode 100755 index 0000000..1cc4e6f --- /dev/null +++ b/tools/run_tests_common.sh @@ -0,0 +1,253 @@ +#!/bin/bash + +set -eu + +function usage { + echo "Usage: $0 [OPTION]..." + echo "Run project's test suite(s)" + echo "" + echo " -V, --virtual-env Always use virtualenv. Install automatically if not present" + echo " -N, --no-virtual-env Don't use virtualenv. Run tests in local environment" + echo " -s, --no-site-packages Isolate the virtualenv from the global Python environment" + echo " -r, --recreate-db Recreate the test database (deprecated, as this is now the default)." + echo " -n, --no-recreate-db Don't recreate the test database." + echo " -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added." + echo " -u, --update Update the virtual environment with any newer package versions" + echo " -p, --pep8 Just run PEP8 and HACKING compliance check" + echo " -P, --no-pep8 Don't run static code checks" + echo " -c, --coverage Generate coverage report" + echo " -d, --debug Run tests with testtools instead of testr. This allows you to use the debugger." + echo " -h, --help Print this usage message" + echo " --hide-elapsed Don't print the elapsed time for each test along with slow test list" + echo " --virtual-env-path <path> Location of the virtualenv directory" + echo " Default: \$(pwd)" + echo " --virtual-env-name <name> Name of the virtualenv directory" + echo " Default: .venv" + echo " --tools-path <dir> Location of the tools directory" + echo " Default: \$(pwd)" + echo "" + echo "Note: with no options specified, the script will try to run the tests in a virtual environment," + echo " If no virtualenv is found, the script will ask if you would like to create one. If you " + echo " prefer to run tests NOT in a virtual environment, simply pass the -N option." + exit +} + +function process_options { + i=1 + while [ $i -le $# ]; do + case "${!i}" in + -h|--help) usage;; + -V|--virtual-env) ALWAYS_VENV=1; NEVER_VENV=0;; + -N|--no-virtual-env) ALWAYS_VENV=0; NEVER_VENV=1;; + -s|--no-site-packages) NO_SITE_PACKAGES=1;; + -r|--recreate-db) RECREATE_DB=1;; + -n|--no-recreate-db) RECREATE_DB=0;; + -f|--force) FORCE=1;; + -u|--update) UPDATE=1;; + -p|--pep8) JUST_PEP8=1;; + -P|--no-pep8) NO_PEP8=1;; + -c|--coverage) COVERAGE=1;; + -d|--debug) DEBUG=1;; + --virtual-env-path) + (( i++ )) + VENV_PATH=${!i} + ;; + --virtual-env-name) + (( i++ )) + VENV_DIR=${!i} + ;; + --tools-path) + (( i++ )) + TOOLS_PATH=${!i} + ;; + -*) TESTOPTS="$TESTOPTS ${!i}";; + *) TESTRARGS="$TESTRARGS ${!i}" + esac + (( i++ )) + done +} + + +TOOLS_PATH=${TOOLS_PATH:-${PWD}} +VENV_PATH=${VENV_PATH:-${PWD}} +VENV_DIR=${VENV_NAME:-.venv} +WITH_VENV=${TOOLS_PATH}/tools/with_venv.sh + +ALWAYS_VENV=0 +NEVER_VENV=0 +FORCE=0 +NO_SITE_PACKAGES=1 +INSTALLVENVOPTS= +TESTRARGS= +TESTOPTS= +WRAPPER="" +JUST_PEP8=0 +NO_PEP8=0 +COVERAGE=0 +DEBUG=0 +RECREATE_DB=1 +UPDATE=0 + +LANG=en_US.UTF-8 +LANGUAGE=en_US:en +LC_ALL=C + +process_options $@ +# Make our paths available to other scripts we call +export VENV_PATH +export TOOLS_PATH +export VENV_DIR +export VENV_NAME +export WITH_VENV +export VENV=${VENV_PATH}/${VENV_DIR} + +function init_testr { + if [ ! -d .testrepository ]; then + ${WRAPPER} testr init + fi +} + +function run_tests { + # Cleanup *pyc + ${WRAPPER} find . -type f -name "*.pyc" -delete + + if [ ${DEBUG} -eq 1 ]; then + if [ "${TESTOPTS}" = "" ] && [ "${TESTRARGS}" = "" ]; then + # Default to running all tests if specific test is not + # provided. + TESTRARGS="discover ./${TESTS_DIR}" + fi + ${WRAPPER} python -m testtools.run ${TESTOPTS} ${TESTRARGS} + + # Short circuit because all of the testr and coverage stuff + # below does not make sense when running testtools.run for + # debugging purposes. + return $? + fi + + if [ ${COVERAGE} -eq 1 ]; then + TESTRTESTS="${TESTRTESTS} --coverage" + else + TESTRTESTS="${TESTRTESTS}" + fi + + # Just run the test suites in current environment + set +e + TESTRARGS=`echo "${TESTRARGS}" | sed -e's/^\s*\(.*\)\s*$/\1/'` + + if [ ${WORKERS_COUNT} -ne 0 ]; then + TESTRTESTS="${TESTRTESTS} --testr-args='--concurrency=${WORKERS_COUNT} --subunit ${TESTOPTS} ${TESTRARGS}'" + else + TESTRTESTS="${TESTRTESTS} --testr-args='--subunit ${TESTOPTS} ${TESTRARGS}'" + fi + + if [ setup.cfg -nt ${EGG_INFO_FILE} ]; then + ${WRAPPER} python setup.py egg_info + fi + + echo "Running \`${WRAPPER} ${TESTRTESTS}\`" + if ${WRAPPER} which subunit-2to1 2>&1 > /dev/null; then + # subunit-2to1 is present, testr subunit stream should be in version 2 + # format. Convert to version one before colorizing. + bash -c "${WRAPPER} ${TESTRTESTS} | ${WRAPPER} subunit-2to1 | ${WRAPPER} ${TOOLS_PATH}/tools/colorizer.py" + else + bash -c "${WRAPPER} ${TESTRTESTS} | ${WRAPPER} ${TOOLS_PATH}/tools/colorizer.py" + fi + RESULT=$? + set -e + + copy_subunit_log + + if [ $COVERAGE -eq 1 ]; then + echo "Generating coverage report in covhtml/" + ${WRAPPER} coverage combine + # Don't compute coverage for common code, which is tested elsewhere + # if we are not in `oslo-incubator` project + if [ ${OMIT_OSLO_FROM_COVERAGE} -eq 0 ]; then + OMIT_OSLO="" + else + OMIT_OSLO="--omit='${PROJECT_NAME}/openstack/common/*'" + fi + ${WRAPPER} coverage html --include='${PROJECT_NAME}/*' ${OMIT_OSLO} -d covhtml -i + fi + + return ${RESULT} +} + +function copy_subunit_log { + LOGNAME=`cat .testrepository/next-stream` + LOGNAME=$((${LOGNAME} - 1)) + LOGNAME=".testrepository/${LOGNAME}" + cp ${LOGNAME} subunit.log +} + +function run_pep8 { + echo "Running flake8 ..." + bash -c "${WRAPPER} flake8" +} + + +TESTRTESTS="python setup.py testr" + +if [ ${NO_SITE_PACKAGES} -eq 1 ]; then + INSTALLVENVOPTS="--no-site-packages" +fi + +if [ ${NEVER_VENV} -eq 0 ]; then + # Remove the virtual environment if -f or --force used + if [ ${FORCE} -eq 1 ]; then + echo "Cleaning virtualenv..." + rm -rf ${VENV} + fi + + # Update the virtual environment if -u or --update used + if [ ${UPDATE} -eq 1 ]; then + echo "Updating virtualenv..." + python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS} + fi + + if [ -e ${VENV} ]; then + WRAPPER="${WITH_VENV}" + else + if [ ${ALWAYS_VENV} -eq 1 ]; then + # Automatically install the virtualenv + python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS} + WRAPPER="${WITH_VENV}" + else + echo -e "No virtual environment found...create one? (Y/n) \c" + read USE_VENV + if [ "x${USE_VENV}" = "xY" -o "x${USE_VENV}" = "x" -o "x${USE_VENV}" = "xy" ]; then + # Install the virtualenv and run the test suite in it + python ${TOOLS_PATH}/tools/install_venv.py ${INSTALLVENVOPTS} + WRAPPER=${WITH_VENV} + fi + fi + fi +fi + +# Delete old coverage data from previous runs +if [ ${COVERAGE} -eq 1 ]; then + ${WRAPPER} coverage erase +fi + +if [ ${JUST_PEP8} -eq 1 ]; then + run_pep8 + exit +fi + +if [ ${RECREATE_DB} -eq 1 ]; then + rm -f tests.sqlite +fi + +init_testr +run_tests + +# NOTE(sirp): we only want to run pep8 when we're running the full-test suite, +# not when we're running tests individually. To handle this, we need to +# distinguish between options (testropts), which begin with a '-', and +# arguments (testrargs). +if [ -z "${TESTRARGS}" ]; then + if [ ${NO_PEP8} -eq 0 ]; then + run_pep8 + fi +fi diff --git a/tools/with_venv.sh b/tools/with_venv.sh new file mode 100755 index 0000000..7303990 --- /dev/null +++ b/tools/with_venv.sh @@ -0,0 +1,7 @@ +#!/bin/bash +TOOLS_PATH=${TOOLS_PATH:-$(dirname $0)} +VENV_PATH=${VENV_PATH:-${TOOLS_PATH}} +VENV_DIR=${VENV_NAME:-/../.venv} +TOOLS=${TOOLS_PATH} +VENV=${VENV:-${VENV_PATH}/${VENV_DIR}} +source ${VENV}/bin/activate && "$@" |