Diffstat:
-rw-r--r--  .gitignore | 3
-rw-r--r--  MAINTAINERS | 4
-rw-r--r--  README.rst | 2
-rwxr-xr-x  openstack/common/config/generator.py | 2
-rw-r--r--  openstack/common/db/sqlalchemy/models.py | 3
-rw-r--r--  openstack/common/db/sqlalchemy/session.py | 80
-rw-r--r--  openstack/common/deprecated/__init__.py | 0
-rw-r--r--  openstack/common/deprecated/wsgi.py (renamed from openstack/common/wsgi.py) | 101
-rw-r--r--  openstack/common/fileutils.py | 75
-rw-r--r--  openstack/common/gettextutils.py | 176
-rw-r--r--  openstack/common/importutils.py | 2
-rw-r--r--  openstack/common/lockutils.py | 23
-rw-r--r--  openstack/common/memorycache.py | 3
-rw-r--r--  openstack/common/middleware/base.py | 62
-rw-r--r--  openstack/common/middleware/context.py | 4
-rw-r--r--  openstack/common/middleware/correlation_id.py | 4
-rw-r--r--  openstack/common/middleware/debug.py | 60
-rw-r--r--  openstack/common/middleware/sizelimit.py | 5
-rw-r--r--  openstack/common/notifier/api.py | 2
-rw-r--r--  openstack/common/notifier/log_notifier.py | 4
-rw-r--r--  openstack/common/notifier/no_op_notifier.py | 2
-rw-r--r--  openstack/common/notifier/rpc_notifier.py | 2
-rw-r--r--  openstack/common/notifier/rpc_notifier2.py | 2
-rw-r--r--  openstack/common/plugin/callbackplugin.py | 2
-rw-r--r--  openstack/common/plugin/pluginmanager.py | 2
-rw-r--r--  openstack/common/policy.py | 207
-rw-r--r--  openstack/common/processutils.py | 2
-rw-r--r--  openstack/common/rpc/amqp.py | 14
-rw-r--r--  openstack/common/rpc/common.py | 6
-rw-r--r--  openstack/common/rpc/impl_fake.py | 4
-rw-r--r--  openstack/common/rpc/impl_kombu.py | 61
-rw-r--r--  openstack/common/rpc/impl_qpid.py | 111
-rw-r--r--  openstack/common/rpc/impl_zmq.py | 44
-rw-r--r--  openstack/common/rpc/serializer.py | 4
-rw-r--r--  openstack/common/rpc/service.py | 3
-rw-r--r--  openstack/common/scheduler/base_filter.py | 2
-rw-r--r--  openstack/common/scheduler/base_handler.py | 3
-rw-r--r--  openstack/common/scheduler/filters/capabilities_filter.py | 5
-rw-r--r--  openstack/common/scheduler/filters/json_filter.py | 2
-rw-r--r--  openstack/common/service.py | 2
-rw-r--r--  openstack/common/strutils.py | 39
-rw-r--r--  openstack/common/threadgroup.py | 6
-rw-r--r--  openstack/common/timeutils.py | 11
-rw-r--r--  requirements.txt (renamed from tools/pip-requires) | 0
-rw-r--r--  test-requirements.txt (renamed from tools/test-requires) | 0
-rw-r--r--  tests/unit/db/sqlalchemy/test_sqlalchemy.py | 20
-rw-r--r--  tests/unit/deprecated/__init__.py | 0
-rw-r--r--  tests/unit/deprecated/test_wsgi.py (renamed from tests/unit/test_wsgi.py) | 5
-rw-r--r--  tests/unit/extension_stubs.py | 2
-rw-r--r--  tests/unit/middleware/test_correlation_id.py | 5
-rw-r--r--  tests/unit/middleware/test_sizelimit.py | 10
-rw-r--r--  tests/unit/plugin/test_callback_plugin.py | 2
-rw-r--r--  tests/unit/rpc/amqp.py | 2
-rw-r--r--  tests/unit/rpc/test_kombu.py | 19
-rw-r--r--  tests/unit/rpc/test_qpid.py | 86
-rw-r--r--  tests/unit/rpc/test_service.py | 4
-rw-r--r--  tests/unit/scheduler/fake_hosts.py | 10
-rw-r--r--  tests/unit/scheduler/test_base_filter.py | 3
-rw-r--r--  tests/unit/scheduler/test_host_filters.py | 2
-rw-r--r--  tests/unit/test_excutils.py | 4
-rw-r--r--  tests/unit/test_fileutils.py | 111
-rw-r--r--  tests/unit/test_gettext.py | 372
-rw-r--r--  tests/unit/test_jsonutils.py | 5
-rw-r--r--  tests/unit/test_lockutils.py | 6
-rw-r--r--  tests/unit/test_log.py | 9
-rw-r--r--  tests/unit/test_loopingcall.py | 2
-rw-r--r--  tests/unit/test_notifier.py | 7
-rw-r--r--  tests/unit/test_periodic.py | 2
-rw-r--r--  tests/unit/test_plugin.py | 4
-rw-r--r--  tests/unit/test_policy.py | 163
-rw-r--r--  tests/unit/test_processutils.py | 13
-rw-r--r--  tests/unit/test_rootwrap.py | 6
-rw-r--r--  tests/unit/test_service.py | 2
-rw-r--r--  tests/unit/test_strutils.py | 32
-rw-r--r--  tests/unit/test_threadgroup.py | 2
-rw-r--r--  tests/utils.py | 4
-rw-r--r--  tests/var/policy.json | 4
-rw-r--r--  tox.ini | 8
78 files changed, 1643 insertions(+), 434 deletions(-)
diff --git a/.gitignore b/.gitignore
index 383605d..8bef1b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+*~
*.swp
*.pyc
*.log
@@ -17,4 +18,4 @@ AUTHORS
ChangeLog
openstack/versioninfo
*.egg
-openstack/common/db/*.sqlite
\ No newline at end of file
+openstack/common/db/*.sqlite
diff --git a/MAINTAINERS b/MAINTAINERS
index e73a25d..11d0670 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -111,8 +111,8 @@ F: excutils.py
== fileutils ==
-M:
-S: Orphan
+M: Zhongyue Luo <zhongyue.nah@intel.com>
+S: Maintained
F: fileutils.py
== fixtures ==
diff --git a/README.rst b/README.rst
index 33b88ab..c115054 100644
--- a/README.rst
+++ b/README.rst
@@ -27,5 +27,5 @@ To run tests in virtualenvs (preferred):
To run tests in the current environment:
- sudo pip install -r tools/pip-requires
+ sudo pip install -r requirements.txt
nosetests
diff --git a/openstack/common/config/generator.py b/openstack/common/config/generator.py
index 4251c87..3f66f74 100755
--- a/openstack/common/config/generator.py
+++ b/openstack/common/config/generator.py
@@ -207,7 +207,7 @@ def _print_opt(opt):
opt_type = None
try:
opt_type = OPTION_REGEX.search(str(type(opt))).group(0)
- except (ValueError, AttributeError), err:
+ except (ValueError, AttributeError) as err:
sys.stderr.write("%s\n" % str(err))
sys.exit(1)
opt_help += ' (' + OPT_TYPES[opt_type] + ')'
diff --git a/openstack/common/db/sqlalchemy/models.py b/openstack/common/db/sqlalchemy/models.py
index 6638c83..f61bb39 100644
--- a/openstack/common/db/sqlalchemy/models.py
+++ b/openstack/common/db/sqlalchemy/models.py
@@ -81,7 +81,8 @@ class ModelBase(object):
def iteritems(self):
"""Make the model object behave like a dict.
- Includes attributes from joins."""
+ Includes attributes from joins.
+ """
local = dict(self)
joined = dict([(k, v) for k, v in self.__dict__.iteritems()
if not k[0] == '_'])
diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py
index b96123a..4394846 100644
--- a/openstack/common/db/sqlalchemy/session.py
+++ b/openstack/common/db/sqlalchemy/session.py
@@ -256,8 +256,8 @@ from sqlalchemy.pool import NullPool, StaticPool
from sqlalchemy.sql.expression import literal_column
from openstack.common.db import exception
-from openstack.common import log as logging
from openstack.common.gettextutils import _
+from openstack.common import log as logging
from openstack.common import timeutils
DEFAULT = 'DEFAULT'
@@ -281,6 +281,11 @@ database_opts = [
deprecated_name='sql_connection',
deprecated_group=DEFAULT,
secret=True),
+ cfg.StrOpt('slave_connection',
+ default='',
+ help='The SQLAlchemy connection string used to connect to the '
+ 'slave database',
+ secret=True),
cfg.IntOpt('idle_timeout',
default=3600,
deprecated_name='sql_idle_timeout',
@@ -334,6 +339,8 @@ LOG = logging.getLogger(__name__)
_ENGINE = None
_MAKER = None
+_SLAVE_ENGINE = None
+_SLAVE_MAKER = None
def set_defaults(sql_connection, sqlite_db):
@@ -346,6 +353,7 @@ def set_defaults(sql_connection, sqlite_db):
def cleanup():
global _ENGINE, _MAKER
+ global _SLAVE_ENGINE, _SLAVE_MAKER
if _MAKER:
_MAKER.close_all()
@@ -353,6 +361,12 @@ def cleanup():
if _ENGINE:
_ENGINE.dispose()
_ENGINE = None
+ if _SLAVE_MAKER:
+ _SLAVE_MAKER.close_all()
+ _SLAVE_MAKER = None
+ if _SLAVE_ENGINE:
+ _SLAVE_ENGINE.dispose()
+ _SLAVE_ENGINE = None
class SqliteForeignKeysListener(PoolListener):
@@ -368,15 +382,25 @@ class SqliteForeignKeysListener(PoolListener):
def get_session(autocommit=True, expire_on_commit=False,
- sqlite_fk=False):
+ sqlite_fk=False, slave_session=False):
"""Return a SQLAlchemy session."""
global _MAKER
+ global _SLAVE_MAKER
+ maker = _MAKER
- if _MAKER is None:
- engine = get_engine(sqlite_fk=sqlite_fk)
- _MAKER = get_maker(engine, autocommit, expire_on_commit)
+ if slave_session:
+ maker = _SLAVE_MAKER
+
+ if maker is None:
+ engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session)
+ maker = get_maker(engine, autocommit, expire_on_commit)
+
+ if slave_session:
+ _SLAVE_MAKER = maker
+ else:
+ _MAKER = maker
- session = _MAKER()
+ session = maker()
return session
@@ -412,7 +436,7 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
"""
def get_columns_from_uniq_cons_or_name(columns):
- # note(vsergeyev): UniqueConstraint name convention: "uniq_t$c1$c2"
+ # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2"
# where `t` it is table name and columns `c1`, `c2`
# are in UniqueConstraint.
uniqbase = "uniq_"
@@ -420,7 +444,7 @@ def _raise_if_duplicate_entry_error(integrity_error, engine_name):
if engine_name == "postgresql":
return [columns[columns.index("_") + 1:columns.rindex("_")]]
return [columns]
- return columns[len(uniqbase):].split("$")[1:]
+ return columns[len(uniqbase):].split("0")[1:]
if engine_name not in ["mysql", "sqlite", "postgresql"]:
return
@@ -491,13 +515,26 @@ def _wrap_db_error(f):
return _wrap
-def get_engine(sqlite_fk=False):
+def get_engine(sqlite_fk=False, slave_engine=False):
"""Return a SQLAlchemy engine."""
global _ENGINE
- if _ENGINE is None:
- _ENGINE = create_engine(CONF.database.connection,
- sqlite_fk=sqlite_fk)
- return _ENGINE
+ global _SLAVE_ENGINE
+ engine = _ENGINE
+ db_uri = CONF.database.connection
+
+ if slave_engine:
+ engine = _SLAVE_ENGINE
+ db_uri = CONF.database.slave_connection
+
+ if engine is None:
+ engine = create_engine(db_uri,
+ sqlite_fk=sqlite_fk)
+ if slave_engine:
+ _SLAVE_ENGINE = engine
+ else:
+ _ENGINE = engine
+
+ return engine
def _synchronous_switch_listener(dbapi_conn, connection_rec):
@@ -555,6 +592,11 @@ def _is_db_connection_error(args):
def create_engine(sql_connection, sqlite_fk=False):
"""Return a new SQLAlchemy engine."""
+ # NOTE(geekinutah): At this point we could be connecting to the normal
+ # db handle or the slave db handle. Things like
+ # _wrap_db_error aren't going to work well if their
+ # backends don't match. Let's check.
+ _assert_matching_drivers()
connection_dict = sqlalchemy.engine.url.make_url(sql_connection)
engine_args = {
@@ -696,3 +738,15 @@ def _patch_mysqldb_with_stacktrace_comments():
old_mysql_do_query(self, qq)
setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query)
+
+
+def _assert_matching_drivers():
+ """Make sure slave handle and normal handle have the same driver."""
+ # NOTE(geekinutah): There's no use case for writing to one backend and
+ # reading from another. Who knows what the future holds?
+ if CONF.database.slave_connection == '':
+ return
+
+ normal = sqlalchemy.engine.url.make_url(CONF.database.connection)
+ slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection)
+ assert normal.drivername == slave.drivername
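For context, a minimal consumer-side sketch of the new read/write split (the model class and query bodies are hypothetical; only get_session() and its slave_session flag come from the patch above):

    from openstack.common.db.sqlalchemy import session as db_session

    def list_instances(model):
        # Reads may be routed to the replica; the slave engine and maker
        # are built lazily from CONF.database.slave_connection on first use.
        read_session = db_session.get_session(slave_session=True)
        return read_session.query(model).all()

    def update_instance(model, instance_id, values):
        # Writes keep going through the primary engine (the default maker).
        write_session = db_session.get_session()
        with write_session.begin():
            write_session.query(model).filter_by(
                id=instance_id).update(values)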
diff --git a/openstack/common/deprecated/__init__.py b/openstack/common/deprecated/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/openstack/common/deprecated/__init__.py
diff --git a/openstack/common/wsgi.py b/openstack/common/deprecated/wsgi.py
index 80d4b9b..f9be97e 100644
--- a/openstack/common/wsgi.py
+++ b/openstack/common/deprecated/wsgi.py
@@ -17,15 +17,12 @@
"""Utility methods for working with WSGI servers."""
-from __future__ import print_function
-
import eventlet
eventlet.patcher.monkey_patch(all=False, socket=True)
import datetime
import errno
import socket
-import sys
import time
import eventlet.wsgi
@@ -164,88 +161,6 @@ class Service(service.Service):
log=logging.WritableLogger(logger))
-class Middleware(object):
- """
- Base WSGI middleware wrapper. These classes require an application to be
- initialized that will be called next. By default the middleware will
- simply call its wrapped app, or you can override __call__ to customize its
- behavior.
- """
-
- @classmethod
- def factory(cls, global_conf, **local_conf):
- """
- Factory method for paste.deploy
- """
-
- def filter(app):
- return cls(app)
-
- return filter
-
- def __init__(self, application):
- self.application = application
-
- def process_request(self, req):
- """
- Called on each request.
-
- If this returns None, the next application down the stack will be
- executed. If it returns a response then that response will be returned
- and execution will stop here.
- """
- return None
-
- def process_response(self, response):
- """Do whatever you'd like to the response."""
- return response
-
- @webob.dec.wsgify
- def __call__(self, req):
- response = self.process_request(req)
- if response:
- return response
- response = req.get_response(self.application)
- return self.process_response(response)
-
-
-class Debug(Middleware):
- """
- Helper class that can be inserted into any WSGI application chain
- to get information about the request and response.
- """
-
- @webob.dec.wsgify
- def __call__(self, req):
- print(("*" * 40) + " REQUEST ENVIRON")
- for key, value in req.environ.items():
- print(key, "=", value)
- print()
- resp = req.get_response(self.application)
-
- print(("*" * 40) + " RESPONSE HEADERS")
- for (key, value) in resp.headers.iteritems():
- print(key, "=", value)
- print()
-
- resp.app_iter = self.print_generator(resp.app_iter)
-
- return resp
-
- @staticmethod
- def print_generator(app_iter):
- """
- Iterator that prints the contents of a wrapper string iterator
- when iterated.
- """
- print(("*" * 40) + " BODY")
- for part in app_iter:
- sys.stdout.write(part)
- sys.stdout.flush()
- yield part
- print()
-
-
class Router(object):
"""
@@ -448,7 +363,7 @@ class ActionDispatcher(object):
class DictSerializer(ActionDispatcher):
- """Default request body serialization"""
+ """Default request body serialization."""
def serialize(self, data, action='default'):
return self.dispatch(data, action=action)
@@ -458,7 +373,7 @@ class DictSerializer(ActionDispatcher):
class JSONDictSerializer(DictSerializer):
- """Default JSON request body serialization"""
+ """Default JSON request body serialization."""
def default(self, data):
def sanitizer(obj):
@@ -570,7 +485,7 @@ class XMLDictSerializer(DictSerializer):
class ResponseHeadersSerializer(ActionDispatcher):
- """Default response headers serialization"""
+ """Default response headers serialization."""
def serialize(self, response, data, action):
self.dispatch(response, data, action=action)
@@ -580,7 +495,7 @@ class ResponseHeadersSerializer(ActionDispatcher):
class ResponseSerializer(object):
- """Encode the necessary pieces into a response object"""
+ """Encode the necessary pieces into a response object."""
def __init__(self, body_serializers=None, headers_serializer=None):
self.body_serializers = {
@@ -722,7 +637,7 @@ class RequestDeserializer(object):
class TextDeserializer(ActionDispatcher):
- """Default request body deserialization"""
+ """Default request body deserialization."""
def deserialize(self, datastring, action='default'):
return self.dispatch(datastring, action=action)
@@ -787,20 +702,20 @@ class XMLDeserializer(TextDeserializer):
return result
def find_first_child_named(self, parent, name):
- """Search a nodes children for the first child with a given name"""
+ """Search a nodes children for the first child with a given name."""
for node in parent.childNodes:
if node.nodeName == name:
return node
return None
def find_children_named(self, parent, name):
- """Return all of a nodes children who have the given name"""
+ """Return all of a nodes children who have the given name."""
for node in parent.childNodes:
if node.nodeName == name:
yield node
def extract_text(self, node):
- """Get the text field contained by the given node"""
+ """Get the text field contained by the given node."""
if len(node.childNodes) == 1:
child = node.childNodes[0]
if child.nodeType == child.TEXT_NODE:
diff --git a/openstack/common/fileutils.py b/openstack/common/fileutils.py
index b988ad0..9f8807f 100644
--- a/openstack/common/fileutils.py
+++ b/openstack/common/fileutils.py
@@ -16,9 +16,18 @@
# under the License.
+import contextlib
import errno
import os
+from openstack.common import excutils
+from openstack.common.gettextutils import _
+from openstack.common import log as logging
+
+LOG = logging.getLogger(__name__)
+
+_FILE_CACHE = {}
+
def ensure_tree(path):
"""Create a directory (and any ancestor directories required)
@@ -33,3 +42,69 @@ def ensure_tree(path):
raise
else:
raise
+
+
+def read_cached_file(filename, force_reload=False):
+ """Read from a file if it has been modified.
+
+ :param force_reload: Whether to reload the file.
+ :returns: A tuple of (reloaded, data): a boolean telling whether the
+ file was re-read, and the file's current contents.
+ """
+ global _FILE_CACHE
+
+ if force_reload and filename in _FILE_CACHE:
+ del _FILE_CACHE[filename]
+
+ reloaded = False
+ mtime = os.path.getmtime(filename)
+ cache_info = _FILE_CACHE.setdefault(filename, {})
+
+ if not cache_info or mtime > cache_info.get('mtime', 0):
+ LOG.debug(_("Reloading cached file %s") % filename)
+ with open(filename) as fap:
+ cache_info['data'] = fap.read()
+ cache_info['mtime'] = mtime
+ reloaded = True
+ return (reloaded, cache_info['data'])
+
+
+def delete_if_exists(path):
+ """Delete a file, but ignore file not found error.
+
+ :param path: File to delete
+ """
+
+ try:
+ os.unlink(path)
+ except OSError as e:
+ if e.errno == errno.ENOENT:
+ return
+ else:
+ raise
+
+
+@contextlib.contextmanager
+def remove_path_on_error(path):
+ """Protect code that wants to operate on PATH atomically.
+ Any exception will cause PATH to be removed.
+
+ :param path: File to work with
+ """
+ try:
+ yield
+ except Exception:
+ with excutils.save_and_reraise_exception():
+ delete_if_exists(path)
+
+
+def file_open(*args, **kwargs):
+ """Open file
+
+ see built-in file() documentation for more details
+
+ Note: The reason this is kept in a separate module is to easily
+ be able to provide a stub module that doesn't alter system
+ state at all (for unit tests)
+ """
+ return file(*args, **kwargs)
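A usage sketch for the new helpers (the paths and the fetch helper are illustrative, not part of the patch):

    from openstack.common import fileutils

    # remove_path_on_error() deletes the partial file if the block raises.
    with fileutils.remove_path_on_error('/tmp/image.part'):
        with open('/tmp/image.part', 'w') as f:
            f.write(fetch_image_data())  # hypothetical download helper

    # read_cached_file() re-reads only when the file's mtime has advanced.
    reloaded, data = fileutils.read_cached_file('/etc/service/policy.json')
    if reloaded:
        print('policy file changed on disk')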
diff --git a/openstack/common/gettextutils.py b/openstack/common/gettextutils.py
index e816f14..d6b5f10 100644
--- a/openstack/common/gettextutils.py
+++ b/openstack/common/gettextutils.py
@@ -2,6 +2,7 @@
# Copyright 2012 Red Hat, Inc.
# All Rights Reserved.
+# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -23,8 +24,11 @@ Usual usage in an openstack.common module:
from openstack.common.gettextutils import _
"""
+import copy
import gettext
+import logging.handlers
import os
+import UserString
_localedir = os.environ.get('oslo'.upper() + '_LOCALEDIR')
_t = gettext.translation('oslo', localedir=_localedir, fallback=True)
@@ -48,3 +52,175 @@ def install(domain):
gettext.install(domain,
localedir=os.environ.get(domain.upper() + '_LOCALEDIR'),
unicode=True)
+
+
+"""
+Lazy gettext functionality.
+
+The following is an attempt to introduce a deferred way
+to do translations on messages in OpenStack. We attempt to
+override the standard _() function and % (format string) operation
+to build Message objects that can later be translated when we have
+more information. Also included is an example LogHandler that
+translates Messages to an associated locale, effectively allowing
+many logs, each with their own locale.
+"""
+
+
+def get_lazy_gettext(domain):
+ """Assemble and return a lazy gettext function for a given domain.
+
+ Factory method for a project/module to get a lazy gettext function
+ for its own translation domain (i.e. nova, glance, cinder, etc.)
+ """
+
+ def _lazy_gettext(msg):
+ """
+ Create and return a Message object encapsulating a string
+ so that we can translate it later when needed.
+ """
+ return Message(msg, domain)
+
+ return _lazy_gettext
+
+
+class Message(UserString.UserString, object):
+ """Class used to encapsulate translatable messages."""
+ def __init__(self, msg, domain):
+ # _msg is the gettext msgid and should never change
+ self._msg = msg
+ self._left_extra_msg = ''
+ self._right_extra_msg = ''
+ self.params = None
+ self.locale = None
+ self.domain = domain
+
+ @property
+ def data(self):
+ # NOTE(mrodden): this should always resolve to a unicode string
+ # that best represents the state of the message currently
+
+ localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR')
+ if self.locale:
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ languages=[self.locale],
+ fallback=True)
+ else:
+ # use system locale for translations
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ fallback=True)
+
+ full_msg = (self._left_extra_msg +
+ lang.ugettext(self._msg) +
+ self._right_extra_msg)
+
+ if self.params is not None:
+ full_msg = full_msg % self.params
+
+ return unicode(full_msg)
+
+ def _save_parameters(self, other):
+ # we check for None later to see if
+ # we actually have parameters to inject,
+ # so encapsulate if our parameter is actually None
+ if other is None:
+ self.params = (other, )
+ else:
+ self.params = copy.deepcopy(other)
+
+ return self
+
+ # overrides to be more string-like
+ def __unicode__(self):
+ return self.data
+
+ def __str__(self):
+ return self.data.encode('utf-8')
+
+ def __getstate__(self):
+ to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg',
+ 'domain', 'params', 'locale']
+ new_dict = self.__dict__.fromkeys(to_copy)
+ for attr in to_copy:
+ new_dict[attr] = copy.deepcopy(self.__dict__[attr])
+
+ return new_dict
+
+ def __setstate__(self, state):
+ for (k, v) in state.items():
+ setattr(self, k, v)
+
+ # operator overloads
+ def __add__(self, other):
+ copied = copy.deepcopy(self)
+ copied._right_extra_msg += other.__str__()
+ return copied
+
+ def __radd__(self, other):
+ copied = copy.deepcopy(self)
+ copied._left_extra_msg += other.__str__()
+ return copied
+
+ def __mod__(self, other):
+ # do a format string to catch and raise
+ # any possible KeyErrors from missing parameters
+ self.data % other
+ copied = copy.deepcopy(self)
+ return copied._save_parameters(other)
+
+ def __mul__(self, other):
+ return self.data * other
+
+ def __rmul__(self, other):
+ return other * self.data
+
+ def __getitem__(self, key):
+ return self.data[key]
+
+ def __getslice__(self, start, end):
+ return self.data.__getslice__(start, end)
+
+ def __getattribute__(self, name):
+ # NOTE(mrodden): handle lossy operations that we can't deal with yet
+ # These override the UserString implementation, since UserString
+ # uses our __class__ attribute to try and build a new message
+ # after running the inner data string through the operation.
+ # At that point, we have lost the gettext message id and can just
+ # safely resolve to a string instead.
+ ops = ['capitalize', 'center', 'decode', 'encode',
+ 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip',
+ 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill']
+ if name in ops:
+ return getattr(self.data, name)
+ else:
+ return UserString.UserString.__getattribute__(self, name)
+
+
+class LocaleHandler(logging.Handler):
+ """Handler that can have a locale associated to translate Messages.
+
+ A quick example of how to utilize the Message class above.
+ LocaleHandler takes a locale and a target logging.Handler object
+ to forward LogRecord objects to after translating the internal Message.
+ """
+
+ def __init__(self, locale, target):
+ """
+ Initialize a LocaleHandler
+
+ :param locale: locale to use for translating messages
+ :param target: logging.Handler object to forward
+ LogRecord objects to after translation
+ """
+ logging.Handler.__init__(self)
+ self.locale = locale
+ self.target = target
+
+ def emit(self, record):
+ if isinstance(record.msg, Message):
+ # set the locale and resolve to a string
+ record.msg.locale = self.locale
+
+ self.target.emit(record)
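A sketch of how the lazy gettext pieces fit together (the 'nova' domain and Spanish locale are illustrative; translation only happens if a matching catalog exists):

    import logging

    from openstack.common import gettextutils

    _ = gettextutils.get_lazy_gettext('nova')

    LOG = logging.getLogger('nova.example')
    # Wrap a plain handler so Message objects are translated at emit time.
    LOG.addHandler(gettextutils.LocaleHandler('es', logging.StreamHandler()))

    msg = _("Instance %s spawned") % 'uuid-1234'  # still a Message object
    LOG.error(msg)  # resolved to the handler's locale when emitted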
diff --git a/openstack/common/importutils.py b/openstack/common/importutils.py
index 3bd277f..dbee325 100644
--- a/openstack/common/importutils.py
+++ b/openstack/common/importutils.py
@@ -24,7 +24,7 @@ import traceback
def import_class(import_str):
- """Returns a class from a string including module and class"""
+ """Returns a class from a string including module and class."""
mod_str, _sep, class_str = import_str.rpartition('.')
try:
__import__(mod_str)
diff --git a/openstack/common/lockutils.py b/openstack/common/lockutils.py
index 79d1905..27525ce 100644
--- a/openstack/common/lockutils.py
+++ b/openstack/common/lockutils.py
@@ -158,17 +158,18 @@ def synchronized(name, lock_file_prefix, external=False, lock_path=None):
This way only one of either foo or bar can be executing at a time.
- The lock_file_prefix argument is used to provide lock files on disk with a
- meaningful prefix. The prefix should end with a hyphen ('-') if specified.
-
- The external keyword argument denotes whether this lock should work across
- multiple processes. This means that if two different workers both run a
- a method decorated with @synchronized('mylock', external=True), only one
- of them will execute at a time.
-
- The lock_path keyword argument is used to specify a special location for
- external lock files to live. If nothing is set, then CONF.lock_path is
- used as a default.
+ :param lock_file_prefix: The lock_file_prefix argument is used to provide
+ lock files on disk with a meaningful prefix. The prefix should end with a
+ hyphen ('-') if specified.
+
+ :param external: The external keyword argument denotes whether this lock
+ should work across multiple processes. This means that if two different
+ workers both run a method decorated with @synchronized('mylock',
+ external=True), only one of them will execute at a time.
+
+ :param lock_path: The lock_path keyword argument is used to specify a
+ special location for external lock files to live. If nothing is set, then
+ CONF.lock_path is used as a default.
"""
def wrap(f):
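A usage sketch matching the documented keywords (the prefix and lock_path values are illustrative):

    from openstack.common import lockutils

    @lockutils.synchronized('mylock', 'myservice-', external=True,
                            lock_path='/var/lock/myservice')
    def resize_disk(disk_id):
        # external=True backs the lock with a file under lock_path, so only
        # one process on this host runs this function at a time.
        pass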
diff --git a/openstack/common/memorycache.py b/openstack/common/memorycache.py
index 23847e6..f60143a 100644
--- a/openstack/common/memorycache.py
+++ b/openstack/common/memorycache.py
@@ -57,7 +57,8 @@ class Client(object):
def get(self, key):
"""Retrieves the value for a key or None.
- this expunges expired keys during each get"""
+ This expunges expired keys during each get.
+ """
now = timeutils.utcnow_ts()
for k in self.cache.keys():
diff --git a/openstack/common/middleware/base.py b/openstack/common/middleware/base.py
new file mode 100644
index 0000000..624a391
--- /dev/null
+++ b/openstack/common/middleware/base.py
@@ -0,0 +1,62 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""Base class(es) for WSGI Middleware."""
+
+import webob.dec
+
+
+class Middleware(object):
+ """
+ Base WSGI middleware wrapper. These classes require an application to be
+ initialized that will be called next. By default the middleware will
+ simply call its wrapped app, or you can override __call__ to customize its
+ behavior.
+ """
+
+ @classmethod
+ def factory(cls, global_conf, **local_conf):
+ """
+ Factory method for paste.deploy
+ """
+
+ def filter(app):
+ return cls(app)
+
+ return filter
+
+ def __init__(self, application):
+ self.application = application
+
+ def process_request(self, req):
+ """
+ Called on each request.
+
+ If this returns None, the next application down the stack will be
+ executed. If it returns a response then that response will be returned
+ and execution will stop here.
+ """
+ return None
+
+ def process_response(self, response):
+ """Do whatever you'd like to the response."""
+ return response
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ response = self.process_request(req)
+ if response:
+ return response
+ response = req.get_response(self.application)
+ return self.process_response(response)
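A minimal subclass sketch of the relocated base class (the header name is made up for illustration):

    from openstack.common.middleware import base

    class ServerHeaderMiddleware(base.Middleware):
        def process_response(self, response):
            # Stamp every response on its way back down the stack.
            response.headers['X-Served-By'] = 'demo'
            return response

    # paste.deploy builds the filter via the inherited factory():
    # filter_factory = ServerHeaderMiddleware.factory(global_conf)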
diff --git a/openstack/common/middleware/context.py b/openstack/common/middleware/context.py
index ac94190..2636e8e 100644
--- a/openstack/common/middleware/context.py
+++ b/openstack/common/middleware/context.py
@@ -21,10 +21,10 @@ Middleware that attaches a context to the WSGI request
from openstack.common import context
from openstack.common import importutils
-from openstack.common import wsgi
+from openstack.common.middleware import base
-class ContextMiddleware(wsgi.Middleware):
+class ContextMiddleware(base.Middleware):
def __init__(self, app, options):
self.options = options
super(ContextMiddleware, self).__init__(app)
diff --git a/openstack/common/middleware/correlation_id.py b/openstack/common/middleware/correlation_id.py
index a3efe34..bffa0d7 100644
--- a/openstack/common/middleware/correlation_id.py
+++ b/openstack/common/middleware/correlation_id.py
@@ -17,11 +17,11 @@
"""Middleware that attaches a correlation id to WSGI request"""
+from openstack.common.middleware import base
from openstack.common import uuidutils
-from openstack.common import wsgi
-class CorrelationIdMiddleware(wsgi.Middleware):
+class CorrelationIdMiddleware(base.Middleware):
def process_request(self, req):
correlation_id = (req.headers.get("X_CORRELATION_ID") or
diff --git a/openstack/common/middleware/debug.py b/openstack/common/middleware/debug.py
new file mode 100644
index 0000000..b92af11
--- /dev/null
+++ b/openstack/common/middleware/debug.py
@@ -0,0 +1,60 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""Debug middleware"""
+
+from __future__ import print_function
+
+import sys
+
+import webob.dec
+
+from openstack.common.middleware import base
+
+
+class Debug(base.Middleware):
+ """
+ Helper class that can be inserted into any WSGI application chain
+ to get information about the request and response.
+ """
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ print(("*" * 40) + " REQUEST ENVIRON")
+ for key, value in req.environ.items():
+ print(key, "=", value)
+ print()
+ resp = req.get_response(self.application)
+
+ print(("*" * 40) + " RESPONSE HEADERS")
+ for (key, value) in resp.headers.iteritems():
+ print(key, "=", value)
+ print()
+
+ resp.app_iter = self.print_generator(resp.app_iter)
+
+ return resp
+
+ @staticmethod
+ def print_generator(app_iter):
+ """
+ Iterator that prints the contents of a wrapper string iterator
+ when iterated.
+ """
+ print(("*" * 40) + " BODY")
+ for part in app_iter:
+ sys.stdout.write(part)
+ sys.stdout.flush()
+ yield part
+ print()
diff --git a/openstack/common/middleware/sizelimit.py b/openstack/common/middleware/sizelimit.py
index 45de527..1128b8a 100644
--- a/openstack/common/middleware/sizelimit.py
+++ b/openstack/common/middleware/sizelimit.py
@@ -22,8 +22,9 @@ from oslo.config import cfg
import webob.dec
import webob.exc
+from openstack.common.deprecated import wsgi
from openstack.common.gettextutils import _
-from openstack.common import wsgi
+from openstack.common.middleware import base
#default request size is 112k
@@ -66,7 +67,7 @@ class LimitingReader(object):
return result
-class RequestBodySizeLimiter(wsgi.Middleware):
+class RequestBodySizeLimiter(base.Middleware):
"""Limit the size of incoming requests."""
def __init__(self, *args, **kwargs):
diff --git a/openstack/common/notifier/api.py b/openstack/common/notifier/api.py
index 13ac394..7c4dbd1 100644
--- a/openstack/common/notifier/api.py
+++ b/openstack/common/notifier/api.py
@@ -56,7 +56,7 @@ class BadPriorityException(Exception):
def notify_decorator(name, fn):
- """ decorator for notify which is used from utils.monkey_patch()
+ """Decorator for notify which is used from utils.monkey_patch().
:param name: name of the function
:param function: - object of the function
diff --git a/openstack/common/notifier/log_notifier.py b/openstack/common/notifier/log_notifier.py
index aa3bc0a..d3ef0ae 100644
--- a/openstack/common/notifier/log_notifier.py
+++ b/openstack/common/notifier/log_notifier.py
@@ -24,7 +24,9 @@ CONF = cfg.CONF
def notify(_context, message):
"""Notifies the recipient of the desired event given the model.
- Log notifications using openstack's default logging system"""
+
+ Log notifications using openstack's default logging system.
+ """
priority = message.get('priority',
CONF.default_notification_level)
diff --git a/openstack/common/notifier/no_op_notifier.py b/openstack/common/notifier/no_op_notifier.py
index bc7a56c..13d946e 100644
--- a/openstack/common/notifier/no_op_notifier.py
+++ b/openstack/common/notifier/no_op_notifier.py
@@ -15,5 +15,5 @@
def notify(_context, message):
- """Notifies the recipient of the desired event given the model"""
+ """Notifies the recipient of the desired event given the model."""
pass
diff --git a/openstack/common/notifier/rpc_notifier.py b/openstack/common/notifier/rpc_notifier.py
index 52677fe..17bbc9a 100644
--- a/openstack/common/notifier/rpc_notifier.py
+++ b/openstack/common/notifier/rpc_notifier.py
@@ -31,7 +31,7 @@ CONF.register_opt(notification_topic_opt)
def notify(context, message):
- """Sends a notification via RPC"""
+ """Sends a notification via RPC."""
if not context:
context = req_context.get_admin_context()
priority = message.get('priority',
diff --git a/openstack/common/notifier/rpc_notifier2.py b/openstack/common/notifier/rpc_notifier2.py
index 6ccc9c5..38fe33b 100644
--- a/openstack/common/notifier/rpc_notifier2.py
+++ b/openstack/common/notifier/rpc_notifier2.py
@@ -37,7 +37,7 @@ CONF.register_opt(notification_topic_opt, opt_group)
def notify(context, message):
- """Sends a notification via RPC"""
+ """Sends a notification via RPC."""
if not context:
context = req_context.get_admin_context()
priority = message.get('priority',
diff --git a/openstack/common/plugin/callbackplugin.py b/openstack/common/plugin/callbackplugin.py
index fead44c..2de7fb0 100644
--- a/openstack/common/plugin/callbackplugin.py
+++ b/openstack/common/plugin/callbackplugin.py
@@ -58,7 +58,7 @@ class _CallbackNotifier(object):
class CallbackPlugin(plugin.Plugin):
- """ Plugin with a simple callback interface.
+ """Plugin with a simple callback interface.
This class is provided as a convenience for producing a simple
plugin that only watches a couple of events. For example, here's
diff --git a/openstack/common/plugin/pluginmanager.py b/openstack/common/plugin/pluginmanager.py
index 56064dd..3962447 100644
--- a/openstack/common/plugin/pluginmanager.py
+++ b/openstack/common/plugin/pluginmanager.py
@@ -41,7 +41,7 @@ class PluginManager(object):
"""
def __init__(self, project_name, service_name):
- """ Construct Plugin Manager; load and initialize plugins.
+ """Construct Plugin Manager; load and initialize plugins.
project_name (e.g. 'nova' or 'glance') is used
to construct the entry point that identifies plugins.
diff --git a/openstack/common/policy.py b/openstack/common/policy.py
index cd6dcfc..f3e62ba 100644
--- a/openstack/common/policy.py
+++ b/openstack/common/policy.py
@@ -59,22 +59,40 @@ as it allows particular rules to be explicitly disabled.
import abc
import re
import urllib
+import urllib2
+from oslo.config import cfg
import six
-import urllib2
+from openstack.common import fileutils
from openstack.common.gettextutils import _
from openstack.common import jsonutils
from openstack.common import log as logging
+policy_opts = [
+ cfg.StrOpt('policy_file',
+ default='policy.json',
+ help=_('JSON file containing policy')),
+ cfg.StrOpt('policy_default_rule',
+ default='default',
+ help=_('Rule enforced when requested rule is not found')),
+]
-LOG = logging.getLogger(__name__)
+CONF = cfg.CONF
+CONF.register_opts(policy_opts)
+LOG = logging.getLogger(__name__)
-_rules = None
_checks = {}
+class PolicyNotAuthorized(Exception):
+
+ def __init__(self, rule):
+ msg = _("Policy doesn't allow %s to be performed.") % rule
+ super(PolicyNotAuthorized, self).__init__(msg)
+
+
class Rules(dict):
"""
A store for rules. Handles the default_rule setting directly.
@@ -124,65 +142,146 @@ class Rules(dict):
return jsonutils.dumps(out_rules, indent=4)
-# Really have to figure out a way to deprecate this
-def set_rules(rules):
- """Set the rules in use for policy checks."""
+class Enforcer(object):
+ """
+ Responsible for loading and enforcing rules
- global _rules
+ :param policy_file: Custom policy file to use, if none is
+ specified, `CONF.policy_file` will be
+ used.
+ :param rules: Default dictionary / Rules to use. It is only
+ used on the first instantiation. If `load_rules(True)`,
+ `clear()` or `set_rules(rules, overwrite=True)` is
+ called, this will be overwritten.
+ :param default_rule: Default rule to use; CONF.policy_default_rule
+ will be used if none is specified.
+ """
- _rules = rules
+ def __init__(self, policy_file=None, rules=None, default_rule=None):
+ self.rules = Rules(rules)
+ self.default_rule = default_rule or CONF.policy_default_rule
+ self.policy_path = None
+ self.policy_file = policy_file or CONF.policy_file
-# Ditto
-def reset():
- """Clear the rules used for policy checks."""
+ def set_rules(self, rules, overwrite=True):
+ """
+ Create a new Rules object based on the provided dict of rules
- global _rules
+ :param rules: New rules to use. It should be an instance of dict.
+ :param overwrite: Whether to overwrite current rules or update them
+ with the new rules.
+ """
- _rules = None
+ if not isinstance(rules, dict):
+ raise TypeError(_("Rules must be an instance of dict or Rules, "
+ "got %s instead") % type(rules))
+ if overwrite:
+ self.rules = Rules(rules)
+ else:
+ self.rules.update(rules)
-def check(rule, target, creds, exc=None, *args, **kwargs):
- """
- Checks authorization of a rule against the target and credentials.
+ def clear(self):
+ """
+ Clears Enforcer rules, policy's cache
+ and policy's path.
+ """
+ self.set_rules({})
+ self.policy_path = None
- :param rule: The rule to evaluate.
- :param target: As much information about the object being operated
- on as possible, as a dictionary.
- :param creds: As much information about the user performing the
- action as possible, as a dictionary.
- :param exc: Class of the exception to raise if the check fails.
- Any remaining arguments passed to check() (both
- positional and keyword arguments) will be passed to
- the exception class. If exc is not provided, returns
- False.
+ def load_rules(self, force_reload=False):
+ """
+ Loads policy_path's rules. Policy file is cached
+ and will be reloaded if modified.
- :return: Returns False if the policy does not allow the action and
- exc is not provided; otherwise, returns a value that
- evaluates to True. Note: for rules using the "case"
- expression, this True value will be the specified string
- from the expression.
- """
+ :param force_reload: Whether to re-read the policy file even if
+ its mtime is unchanged.
+ """
- # Allow the rule to be a Check tree
- if isinstance(rule, BaseCheck):
- result = rule(target, creds)
- elif not _rules:
- # No rules to reference means we're going to fail closed
- result = False
- else:
- try:
- # Evaluate the rule
- result = _rules[rule](target, creds)
- except KeyError:
- # If the rule doesn't exist, fail closed
+ if not self.policy_path:
+ self.policy_path = self._get_policy_path()
+
+ reloaded, data = fileutils.read_cached_file(self.policy_path,
+ force_reload=force_reload)
+
+ if reloaded:
+ rules = Rules.load_json(data, self.default_rule)
+ self.set_rules(rules)
+ LOG.debug(_("Rules successfully reloaded"))
+
+ def _get_policy_path(self):
+ """
+ Locate the policy json data file
+
+ :param policy_file: Custom policy file to locate.
+
+ :returns: The policy path
+
+ :raises: ConfigFilesNotFoundError if the file couldn't
+ be located.
+ """
+ policy_file = CONF.find_file(self.policy_file)
+
+ if policy_file:
+ return policy_file
+
+ raise cfg.ConfigFilesNotFoundError(path=CONF.policy_file)
+
+ def enforce(self, rule, target, creds, do_raise=False,
+ exc=None, *args, **kwargs):
+ """
+ Checks authorization of a rule against the target and credentials.
+
+ :param rule: A string or BaseCheck instance specifying the rule
+ to evaluate.
+ :param target: As much information about the object being operated
+ on as possible, as a dictionary.
+ :param creds: As much information about the user performing the
+ action as possible, as a dictionary.
+ :param do_raise: Whether to raise an exception or not if check
+ fails.
+ :param exc: Class of the exception to raise if the check fails.
+ Any remaining arguments passed to check() (both
+ positional and keyword arguments) will be passed to
+ the exception class. If not specified, PolicyNotAuthorized
+ will be used.
+
+ :return: Returns False if the policy does not allow the action and
+ exc is not provided; otherwise, returns a value that
+ evaluates to True. Note: for rules using the "case"
+ expression, this True value will be the specified string
+ from the expression.
+ """
+
+ # NOTE(flaper87): Not logging target or creds to avoid
+ # potential security issues.
+ LOG.debug(_("Rule %s will be now enforced") % rule)
+
+ self.load_rules()
+
+ # Allow the rule to be a Check tree
+ if isinstance(rule, BaseCheck):
+ result = rule(target, creds, self)
+ elif not self.rules:
+ # No rules to reference means we're going to fail closed
result = False
+ else:
+ try:
+ # Evaluate the rule
+ result = self.rules[rule](target, creds, self)
+ except KeyError:
+ LOG.debug(_("Rule [%s] doesn't exist") % rule)
+ # If the rule doesn't exist, fail closed
+ result = False
+
+ # If it is False, raise the exception if requested
+ if do_raise and not result:
+ if exc:
+ raise exc(*args, **kwargs)
- # If it is False, raise the exception if requested
- if exc and result is False:
- raise exc(*args, **kwargs)
+ raise PolicyNotAuthorized(rule)
- return result
+ return result
class BaseCheck(object):
@@ -392,7 +491,7 @@ def _parse_check(rule):
try:
kind, match = rule.split(':', 1)
except Exception:
- LOG.exception(_("Failed to understand rule %(rule)s") % locals())
+ LOG.exception(_("Failed to understand rule %s") % rule)
# If the rule is invalid, we'll fail closed
return FalseCheck()
@@ -723,13 +822,13 @@ def register(name, func=None):
@register("rule")
class RuleCheck(Check):
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
"""
Recursively checks credentials based on the defined rules.
"""
try:
- return _rules[self.match](target, creds)
+ return enforcer.rules[self.match](target, creds, enforcer)
except KeyError:
# We don't have any matching rule; fail closed
return False
@@ -737,7 +836,7 @@ class RuleCheck(Check):
@register("role")
class RoleCheck(Check):
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
"""Check that there is a matching role in the cred dict."""
return self.match.lower() in [x.lower() for x in creds['roles']]
@@ -745,7 +844,7 @@ class RoleCheck(Check):
@register('http')
class HttpCheck(Check):
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
"""
Check http: rules by calling to a remote server.
@@ -763,7 +862,7 @@ class HttpCheck(Check):
@register(None)
class GenericCheck(Check):
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
"""
Check an individual match.
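A sketch of the new Enforcer flow (the rule name and creds are illustrative; note that enforce() also calls load_rules(), so a policy file must be findable via CONF):

    from openstack.common import policy

    enforcer = policy.Enforcer()  # policy file defaults to CONF.policy_file

    creds = {'roles': ['admin']}
    # Returns False on denial by default...
    if enforcer.enforce('compute:start', {}, creds):
        start_instance()  # hypothetical helper

    # ...or raises PolicyNotAuthorized (or a custom exc) when asked to:
    enforcer.enforce('compute:start', {}, creds, do_raise=True)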
diff --git a/openstack/common/processutils.py b/openstack/common/processutils.py
index 1aa1335..02cfada 100644
--- a/openstack/common/processutils.py
+++ b/openstack/common/processutils.py
@@ -203,7 +203,7 @@ def trycmd(*args, **kwargs):
try:
out, err = execute(*args, **kwargs)
failed = False
- except ProcessExecutionError, exn:
+ except ProcessExecutionError as exn:
out, err = '', str(exn)
failed = True
diff --git a/openstack/common/rpc/amqp.py b/openstack/common/rpc/amqp.py
index 946501b..f5b7cab 100644
--- a/openstack/common/rpc/amqp.py
+++ b/openstack/common/rpc/amqp.py
@@ -114,7 +114,7 @@ class ConnectionContext(rpc_common.Connection):
"""
def __init__(self, conf, connection_pool, pooled=True, server_params=None):
- """Create a new connection, or get one from the pool"""
+ """Create a new connection, or get one from the pool."""
self.connection = None
self.conf = conf
self.connection_pool = connection_pool
@@ -127,7 +127,7 @@ class ConnectionContext(rpc_common.Connection):
self.pooled = pooled
def __enter__(self):
- """When with ConnectionContext() is used, return self"""
+ """When with ConnectionContext() is used, return self."""
return self
def _done(self):
@@ -175,7 +175,7 @@ class ConnectionContext(rpc_common.Connection):
self.connection.consume_in_thread()
def __getattr__(self, key):
- """Proxy all other calls to the Connection instance"""
+ """Proxy all other calls to the Connection instance."""
if self.connection:
return getattr(self.connection, key)
else:
@@ -183,7 +183,7 @@ class ConnectionContext(rpc_common.Connection):
class ReplyProxy(ConnectionContext):
- """ Connection class for RPC replies / callbacks """
+ """Connection class for RPC replies / callbacks."""
def __init__(self, conf, connection_pool):
self._call_waiters = {}
self._num_call_waiters = 0
@@ -252,7 +252,7 @@ def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None,
class RpcContext(rpc_common.CommonRpcContext):
- """Context that supports replying to a rpc.call"""
+ """Context that supports replying to a rpc.call."""
def __init__(self, **kwargs):
self.msg_id = kwargs.pop('msg_id', None)
self.reply_q = kwargs.pop('reply_q', None)
@@ -491,7 +491,7 @@ class MulticallProxyWaiter(object):
return result
def __iter__(self):
- """Return a result until we get a reply with an 'ending" flag"""
+ """Return a result until we get a reply with an 'ending' flag."""
if self._done:
raise StopIteration
while True:
@@ -567,7 +567,7 @@ class MulticallWaiter(object):
def create_connection(conf, new, connection_pool):
- """Create a connection"""
+ """Create a connection."""
return ConnectionContext(conf, connection_pool, pooled=not new)
diff --git a/openstack/common/rpc/common.py b/openstack/common/rpc/common.py
index 5a7e525..28dcacd 100644
--- a/openstack/common/rpc/common.py
+++ b/openstack/common/rpc/common.py
@@ -417,7 +417,8 @@ class ClientException(Exception):
"""This encapsulates some actual exception that is expected to be
hit by an RPC proxy object. Merely instantiating it records the
current exception information, which will be passed back to the
- RPC client without exceptional logging."""
+ RPC client without exceptional logging.
+ """
def __init__(self):
self._exc_info = sys.exc_info()
@@ -438,7 +439,8 @@ def client_exceptions(*exceptions):
of expected exceptions that the RPC layer should not consider fatal,
and not log as if they were generated in a real error scenario. Note
that this will cause listed exceptions to be wrapped in a
- ClientException, which is used internally by the RPC layer."""
+ ClientException, which is used internally by the RPC layer.
+ """
def outer(func):
def inner(*args, **kwargs):
return catch_client_exception(exceptions, func, *args, **kwargs)
diff --git a/openstack/common/rpc/impl_fake.py b/openstack/common/rpc/impl_fake.py
index 815570d..7719697 100644
--- a/openstack/common/rpc/impl_fake.py
+++ b/openstack/common/rpc/impl_fake.py
@@ -122,7 +122,7 @@ class Connection(object):
def create_connection(conf, new=True):
- """Create a connection"""
+ """Create a connection."""
return Connection()
@@ -179,7 +179,7 @@ def cleanup():
def fanout_cast(conf, context, topic, msg):
- """Cast to all consumers of a topic"""
+ """Cast to all consumers of a topic."""
check_serialize(msg)
method = msg.get('method')
if not method:
diff --git a/openstack/common/rpc/impl_kombu.py b/openstack/common/rpc/impl_kombu.py
index 0648e4b..c062d9a 100644
--- a/openstack/common/rpc/impl_kombu.py
+++ b/openstack/common/rpc/impl_kombu.py
@@ -132,7 +132,7 @@ class ConsumerBase(object):
self.reconnect(channel)
def reconnect(self, channel):
- """Re-declare the queue after a rabbit reconnect"""
+ """Re-declare the queue after a rabbit reconnect."""
self.channel = channel
self.kwargs['channel'] = channel
self.queue = kombu.entity.Queue(**self.kwargs)
@@ -173,7 +173,7 @@ class ConsumerBase(object):
self.queue.consume(*args, callback=_callback, **options)
def cancel(self):
- """Cancel the consuming from the queue, if it has started"""
+ """Cancel the consuming from the queue, if it has started."""
try:
self.queue.cancel(self.tag)
except KeyError as e:
@@ -184,7 +184,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase):
- """Queue/consumer class for 'direct'"""
+ """Queue/consumer class for 'direct'."""
def __init__(self, conf, channel, msg_id, callback, tag, **kwargs):
"""Init a 'direct' queue.
@@ -216,7 +216,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase):
- """Consumer class for 'topic'"""
+ """Consumer class for 'topic'."""
def __init__(self, conf, channel, topic, callback, tag, name=None,
exchange_name=None, **kwargs):
@@ -253,7 +253,7 @@ class TopicConsumer(ConsumerBase):
class FanoutConsumer(ConsumerBase):
- """Consumer class for 'fanout'"""
+ """Consumer class for 'fanout'."""
def __init__(self, conf, channel, topic, callback, tag, **kwargs):
"""Init a 'fanout' queue.
@@ -286,7 +286,7 @@ class FanoutConsumer(ConsumerBase):
class Publisher(object):
- """Base Publisher class"""
+ """Base Publisher class."""
def __init__(self, channel, exchange_name, routing_key, **kwargs):
"""Init the Publisher class with the exchange_name, routing_key,
@@ -298,7 +298,7 @@ class Publisher(object):
self.reconnect(channel)
def reconnect(self, channel):
- """Re-establish the Producer after a rabbit reconnection"""
+ """Re-establish the Producer after a rabbit reconnection."""
self.exchange = kombu.entity.Exchange(name=self.exchange_name,
**self.kwargs)
self.producer = kombu.messaging.Producer(exchange=self.exchange,
@@ -306,7 +306,7 @@ class Publisher(object):
routing_key=self.routing_key)
def send(self, msg, timeout=None):
- """Send a message"""
+ """Send a message."""
if timeout:
#
# AMQP TTL is in milliseconds when set in the header.
@@ -317,7 +317,7 @@ class Publisher(object):
class DirectPublisher(Publisher):
- """Publisher class for 'direct'"""
+ """Publisher class for 'direct'."""
def __init__(self, conf, channel, msg_id, **kwargs):
"""init a 'direct' publisher.
@@ -333,7 +333,7 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher):
- """Publisher class for 'topic'"""
+ """Publisher class for 'topic'."""
def __init__(self, conf, channel, topic, **kwargs):
"""init a 'topic' publisher.
@@ -352,7 +352,7 @@ class TopicPublisher(Publisher):
class FanoutPublisher(Publisher):
- """Publisher class for 'fanout'"""
+ """Publisher class for 'fanout'."""
def __init__(self, conf, channel, topic, **kwargs):
"""init a 'fanout' publisher.
@@ -367,7 +367,7 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(TopicPublisher):
- """Publisher class for 'notify'"""
+ """Publisher class for 'notify'."""
def __init__(self, conf, channel, topic, **kwargs):
self.durable = kwargs.pop('durable', conf.rabbit_durable_queues)
@@ -447,8 +447,9 @@ class Connection(object):
self.reconnect()
def _fetch_ssl_params(self):
- """Handles fetching what ssl params
- should be used for the connection (if any)"""
+ """Handles fetching what ssl params should be used for the connection
+ (if any).
+ """
ssl_params = dict()
# http://docs.python.org/library/ssl.html - ssl.wrap_socket
@@ -578,18 +579,18 @@ class Connection(object):
self.reconnect()
def get_channel(self):
- """Convenience call for bin/clear_rabbit_queues"""
+ """Convenience call for bin/clear_rabbit_queues."""
return self.channel
def close(self):
- """Close/release this connection"""
+ """Close/release this connection."""
self.cancel_consumer_thread()
self.wait_on_proxy_callbacks()
self.connection.release()
self.connection = None
def reset(self):
- """Reset a connection so it can be used again"""
+ """Reset a connection so it can be used again."""
self.cancel_consumer_thread()
self.wait_on_proxy_callbacks()
self.channel.close()
@@ -618,7 +619,7 @@ class Connection(object):
return self.ensure(_connect_error, _declare_consumer)
def iterconsume(self, limit=None, timeout=None):
- """Return an iterator that will consume from all queues/consumers"""
+ """Return an iterator that will consume from all queues/consumers."""
info = {'do_consume': True}
@@ -648,7 +649,7 @@ class Connection(object):
yield self.ensure(_error_callback, _consume)
def cancel_consumer_thread(self):
- """Cancel a consumer thread"""
+ """Cancel a consumer thread."""
if self.consumer_thread is not None:
self.consumer_thread.kill()
try:
@@ -663,7 +664,7 @@ class Connection(object):
proxy_cb.wait()
def publisher_send(self, cls, topic, msg, timeout=None, **kwargs):
- """Send to a publisher based on the publisher class"""
+ """Send to a publisher based on the publisher class."""
def _error_callback(exc):
log_info = {'topic': topic, 'err_str': str(exc)}
@@ -693,27 +694,27 @@ class Connection(object):
topic, callback)
def declare_fanout_consumer(self, topic, callback):
- """Create a 'fanout' consumer"""
+ """Create a 'fanout' consumer."""
self.declare_consumer(FanoutConsumer, topic, callback)
def direct_send(self, msg_id, msg):
- """Send a 'direct' message"""
+ """Send a 'direct' message."""
self.publisher_send(DirectPublisher, msg_id, msg)
def topic_send(self, topic, msg, timeout=None):
- """Send a 'topic' message"""
+ """Send a 'topic' message."""
self.publisher_send(TopicPublisher, topic, msg, timeout)
def fanout_send(self, topic, msg):
- """Send a 'fanout' message"""
+ """Send a 'fanout' message."""
self.publisher_send(FanoutPublisher, topic, msg)
def notify_send(self, topic, msg, **kwargs):
- """Send a notify message on a topic"""
+ """Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
def consume(self, limit=None):
- """Consume from all queues/consumers"""
+ """Consume from all queues/consumers."""
it = self.iterconsume(limit=limit)
while True:
try:
@@ -722,7 +723,7 @@ class Connection(object):
return
def consume_in_thread(self):
- """Consumer from all queues/consumers in a greenthread"""
+ """Consumer from all queues/consumers in a greenthread."""
def _consumer_thread():
try:
self.consume()
@@ -733,7 +734,7 @@ class Connection(object):
return self.consumer_thread
def create_consumer(self, topic, proxy, fanout=False):
- """Create a consumer that calls a method in a proxy object"""
+ """Create a consumer that calls a method in a proxy object."""
proxy_cb = rpc_amqp.ProxyCallback(
self.conf, proxy,
rpc_amqp.get_connection_pool(self.conf, Connection))
@@ -745,7 +746,7 @@ class Connection(object):
self.declare_topic_consumer(topic, proxy_cb)
def create_worker(self, topic, proxy, pool_name):
- """Create a worker that calls a method in a proxy object"""
+ """Create a worker that calls a method in a proxy object."""
proxy_cb = rpc_amqp.ProxyCallback(
self.conf, proxy,
rpc_amqp.get_connection_pool(self.conf, Connection))
@@ -778,7 +779,7 @@ class Connection(object):
def create_connection(conf, new=True):
- """Create a connection"""
+ """Create a connection."""
return rpc_amqp.create_connection(
conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
diff --git a/openstack/common/rpc/impl_qpid.py b/openstack/common/rpc/impl_qpid.py
index a03ebb2..7352517 100644
--- a/openstack/common/rpc/impl_qpid.py
+++ b/openstack/common/rpc/impl_qpid.py
@@ -31,6 +31,7 @@ from openstack.common import log as logging
from openstack.common.rpc import amqp as rpc_amqp
from openstack.common.rpc import common as rpc_common
+qpid_codec = importutils.try_import("qpid.codec010")
qpid_messaging = importutils.try_import("qpid.messaging")
qpid_exceptions = importutils.try_import("qpid.messaging.exceptions")
@@ -69,6 +70,8 @@ qpid_opts = [
cfg.CONF.register_opts(qpid_opts)
+JSON_CONTENT_TYPE = 'application/json; charset=utf8'
+
class ConsumerBase(object):
"""Consumer base class."""
@@ -118,15 +121,32 @@ class ConsumerBase(object):
self.reconnect(session)
def reconnect(self, session):
- """Re-declare the receiver after a qpid reconnect"""
+ """Re-declare the receiver after a qpid reconnect."""
self.session = session
self.receiver = session.receiver(self.address)
self.receiver.capacity = 1
+ def _unpack_json_msg(self, msg):
+ """Load the JSON data in msg if msg.content_type indicates that it
+ is necessary. Put the loaded data back into msg.content and
+ update msg.content_type appropriately.
+
+ A Qpid Message containing a dict will have a content_type of
+ 'amqp/map', whereas one containing a string that needs to be converted
+ back from JSON will have a content_type of JSON_CONTENT_TYPE.
+
+ :param msg: a Qpid Message object
+ :returns: None
+ """
+ if msg.content_type == JSON_CONTENT_TYPE:
+ msg.content = jsonutils.loads(msg.content)
+ msg.content_type = 'amqp/map'
+
def consume(self):
- """Fetch the message and pass it to the callback object"""
+ """Fetch the message and pass it to the callback object."""
message = self.receiver.fetch()
try:
+ self._unpack_json_msg(message)
msg = rpc_common.deserialize_msg(message.content)
self.callback(msg)
except Exception:
@@ -139,7 +159,7 @@ class ConsumerBase(object):
class DirectConsumer(ConsumerBase):
- """Queue/consumer class for 'direct'"""
+ """Queue/consumer class for 'direct'."""
def __init__(self, conf, session, msg_id, callback):
"""Init a 'direct' queue.
@@ -157,7 +177,7 @@ class DirectConsumer(ConsumerBase):
class TopicConsumer(ConsumerBase):
- """Consumer class for 'topic'"""
+ """Consumer class for 'topic'."""
def __init__(self, conf, session, topic, callback, name=None,
exchange_name=None):
@@ -177,7 +197,7 @@ class TopicConsumer(ConsumerBase):
class FanoutConsumer(ConsumerBase):
- """Consumer class for 'fanout'"""
+ """Consumer class for 'fanout'."""
def __init__(self, conf, session, topic, callback):
"""Init a 'fanout' queue.
@@ -196,7 +216,7 @@ class FanoutConsumer(ConsumerBase):
class Publisher(object):
- """Base Publisher class"""
+ """Base Publisher class."""
def __init__(self, session, node_name, node_opts=None):
"""Init the Publisher class with the exchange_name, routing_key,
@@ -225,16 +245,43 @@ class Publisher(object):
self.reconnect(session)
def reconnect(self, session):
- """Re-establish the Sender after a reconnection"""
+ """Re-establish the Sender after a reconnection."""
self.sender = session.sender(self.address)
+ def _pack_json_msg(self, msg):
+ """Qpid cannot serialize dicts containing strings longer than 65535
+ characters. This function dumps the message content to a JSON
+ string, which Qpid is able to handle.
+
+ :param msg: May be either a Qpid Message object or a bare dict.
+ :returns: A Qpid Message with its content field JSON encoded.
+ """
+ try:
+ msg.content = jsonutils.dumps(msg.content)
+ except AttributeError:
+ # Need to have a Qpid message so we can set the content_type.
+ msg = qpid_messaging.Message(jsonutils.dumps(msg))
+ msg.content_type = JSON_CONTENT_TYPE
+ return msg
+
def send(self, msg):
- """Send a message"""
+ """Send a message."""
+ try:
+ # Check if Qpid can encode the message
+ check_msg = msg
+ if not hasattr(check_msg, 'content_type'):
+ check_msg = qpid_messaging.Message(msg)
+ content_type = check_msg.content_type
+ enc, dec = qpid_messaging.message.get_codec(content_type)
+ enc(check_msg.content)
+ except qpid_codec.CodecException:
+ # This means the message couldn't be serialized as a dict.
+ msg = self._pack_json_msg(msg)
self.sender.send(msg)
class DirectPublisher(Publisher):
- """Publisher class for 'direct'"""
+ """Publisher class for 'direct'."""
def __init__(self, conf, session, msg_id):
"""Init a 'direct' publisher."""
super(DirectPublisher, self).__init__(session, msg_id,
@@ -242,7 +289,7 @@ class DirectPublisher(Publisher):
class TopicPublisher(Publisher):
- """Publisher class for 'topic'"""
+ """Publisher class for 'topic'."""
def __init__(self, conf, session, topic):
"""init a 'topic' publisher.
"""
@@ -252,7 +299,7 @@ class TopicPublisher(Publisher):
class FanoutPublisher(Publisher):
- """Publisher class for 'fanout'"""
+ """Publisher class for 'fanout'."""
def __init__(self, conf, session, topic):
"""init a 'fanout' publisher.
"""
@@ -262,7 +309,7 @@ class FanoutPublisher(Publisher):
class NotifyPublisher(Publisher):
- """Publisher class for notifications"""
+ """Publisher class for notifications."""
def __init__(self, conf, session, topic):
"""init a 'topic' publisher.
"""
@@ -330,7 +377,7 @@ class Connection(object):
return self.consumers[str(receiver)]
def reconnect(self):
- """Handles reconnecting and re-establishing sessions and queues"""
+ """Handles reconnecting and re-establishing sessions and queues."""
attempt = 0
delay = 1
while True:
@@ -381,14 +428,20 @@ class Connection(object):
self.reconnect()
def close(self):
- """Close/release this connection"""
+ """Close/release this connection."""
self.cancel_consumer_thread()
self.wait_on_proxy_callbacks()
- self.connection.close()
+ try:
+ self.connection.close()
+ except Exception:
+ # NOTE(dripton) Logging exceptions that happen during cleanup just
+ # causes confusion; there's really nothing useful we can do with
+ # them.
+ pass
self.connection = None
def reset(self):
- """Reset a connection so it can be used again"""
+ """Reset a connection so it can be used again."""
self.cancel_consumer_thread()
self.wait_on_proxy_callbacks()
self.session.close()
@@ -412,7 +465,7 @@ class Connection(object):
return self.ensure(_connect_error, _declare_consumer)
def iterconsume(self, limit=None, timeout=None):
- """Return an iterator that will consume from all queues/consumers"""
+ """Return an iterator that will consume from all queues/consumers."""
def _error_callback(exc):
if isinstance(exc, qpid_exceptions.Empty):
@@ -436,7 +489,7 @@ class Connection(object):
yield self.ensure(_error_callback, _consume)
def cancel_consumer_thread(self):
- """Cancel a consumer thread"""
+ """Cancel a consumer thread."""
if self.consumer_thread is not None:
self.consumer_thread.kill()
try:
@@ -451,7 +504,7 @@ class Connection(object):
proxy_cb.wait()
def publisher_send(self, cls, topic, msg):
- """Send to a publisher based on the publisher class"""
+ """Send to a publisher based on the publisher class."""
def _connect_error(exc):
log_info = {'topic': topic, 'err_str': str(exc)}
@@ -481,15 +534,15 @@ class Connection(object):
topic, callback)
def declare_fanout_consumer(self, topic, callback):
- """Create a 'fanout' consumer"""
+ """Create a 'fanout' consumer."""
self.declare_consumer(FanoutConsumer, topic, callback)
def direct_send(self, msg_id, msg):
- """Send a 'direct' message"""
+ """Send a 'direct' message."""
self.publisher_send(DirectPublisher, msg_id, msg)
def topic_send(self, topic, msg, timeout=None):
- """Send a 'topic' message"""
+ """Send a 'topic' message."""
#
# We want to create a message with attributes, e.g. a TTL. We
# don't really need to keep 'msg' in its JSON format any longer
@@ -504,15 +557,15 @@ class Connection(object):
self.publisher_send(TopicPublisher, topic, qpid_message)
def fanout_send(self, topic, msg):
- """Send a 'fanout' message"""
+ """Send a 'fanout' message."""
self.publisher_send(FanoutPublisher, topic, msg)
def notify_send(self, topic, msg, **kwargs):
- """Send a notify message on a topic"""
+ """Send a notify message on a topic."""
self.publisher_send(NotifyPublisher, topic, msg)
def consume(self, limit=None):
- """Consume from all queues/consumers"""
+ """Consume from all queues/consumers."""
it = self.iterconsume(limit=limit)
while True:
try:
@@ -521,7 +574,7 @@ class Connection(object):
return
def consume_in_thread(self):
- """Consumer from all queues/consumers in a greenthread"""
+ """Consume from all queues/consumers in a greenthread."""
def _consumer_thread():
try:
self.consume()
@@ -532,7 +585,7 @@ class Connection(object):
return self.consumer_thread
def create_consumer(self, topic, proxy, fanout=False):
- """Create a consumer that calls a method in a proxy object"""
+ """Create a consumer that calls a method in a proxy object."""
proxy_cb = rpc_amqp.ProxyCallback(
self.conf, proxy,
rpc_amqp.get_connection_pool(self.conf, Connection))
@@ -548,7 +601,7 @@ class Connection(object):
return consumer
def create_worker(self, topic, proxy, pool_name):
- """Create a worker that calls a method in a proxy object"""
+ """Create a worker that calls a method in a proxy object."""
proxy_cb = rpc_amqp.ProxyCallback(
self.conf, proxy,
rpc_amqp.get_connection_pool(self.conf, Connection))
@@ -591,7 +644,7 @@ class Connection(object):
def create_connection(conf, new=True):
- """Create a connection"""
+ """Create a connection."""
return rpc_amqp.create_connection(
conf, new,
rpc_amqp.get_connection_pool(conf, Connection))
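
The impl_qpid changes above work around a Qpid limitation: the 'amqp/map' codec cannot encode map values longer than 65535 characters, so oversized payloads are round-tripped through JSON and tagged with JSON_CONTENT_TYPE. A minimal sketch of that round-trip, using a stand-in class instead of the real qpid.messaging.Message (illustration only, not part of the patch):

    import json

    JSON_CONTENT_TYPE = 'application/json; charset=utf8'

    class StubMessage(object):
        # Stand-in for qpid.messaging.Message (hypothetical).
        def __init__(self, content):
            self.content = content
            self.content_type = 'amqp/map'

    # Pack (publisher side): dump the over-long dict to a JSON string.
    msg = StubMessage({'test': 'a' * 65536})
    msg.content = json.dumps(msg.content)
    msg.content_type = JSON_CONTENT_TYPE

    # Unpack (consumer side): reverse the transformation before dispatch.
    if msg.content_type == JSON_CONTENT_TYPE:
        msg.content = json.loads(msg.content)
        msg.content_type = 'amqp/map'

    assert msg.content == {'test': 'a' * 65536}
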
diff --git a/openstack/common/rpc/impl_zmq.py b/openstack/common/rpc/impl_zmq.py
index add3973..07b7b41 100644
--- a/openstack/common/rpc/impl_zmq.py
+++ b/openstack/common/rpc/impl_zmq.py
@@ -30,7 +30,6 @@ from openstack.common import excutils
from openstack.common.gettextutils import _
from openstack.common import importutils
from openstack.common import jsonutils
-from openstack.common import processutils as utils
from openstack.common.rpc import common as rpc_common
zmq = importutils.try_import('eventlet.green.zmq')
@@ -199,15 +198,15 @@ class ZmqSocket(object):
LOG.error("ZeroMQ socket could not be closed.")
self.sock = None
- def recv(self):
+ def recv(self, **kwargs):
if not self.can_recv:
raise RPCException(_("You cannot recv on this socket."))
- return self.sock.recv_multipart()
+ return self.sock.recv_multipart(**kwargs)
- def send(self, data):
+ def send(self, data, **kwargs):
if not self.can_send:
raise RPCException(_("You cannot send on this socket."))
- self.sock.send_multipart(data)
+ self.sock.send_multipart(data, **kwargs)
class ZmqClient(object):
@@ -446,11 +445,8 @@ class ZmqProxy(ZmqBaseReactor):
def consume(self, sock):
ipc_dir = CONF.rpc_zmq_ipc_dir
- #TODO(ewindisch): use zero-copy (i.e. references, not copying)
- data = sock.recv()
- topic = data[1]
-
- LOG.debug(_("CONSUMER GOT %s"), ' '.join(map(pformat, data)))
+ data = sock.recv(copy=False)
+ topic = data[1].bytes
if topic.startswith('fanout~'):
sock_type = zmq.PUB
@@ -492,9 +488,7 @@ class ZmqProxy(ZmqBaseReactor):
while(True):
data = self.topic_proxy[topic].get()
- out_sock.send(data)
- LOG.debug(_("ROUTER RELAY-OUT SUCCEEDED %(data)s") %
- {'data': data})
+ out_sock.send(data, copy=False)
wait_sock_creation = eventlet.event.Event()
eventlet.spawn(publisher, wait_sock_creation)
@@ -507,37 +501,35 @@ class ZmqProxy(ZmqBaseReactor):
try:
self.topic_proxy[topic].put_nowait(data)
- LOG.debug(_("ROUTER RELAY-OUT QUEUED %(data)s") %
- {'data': data})
except eventlet.queue.Full:
LOG.error(_("Local per-topic backlog buffer full for topic "
"%(topic)s. Dropping message.") % {'topic': topic})
def consume_in_thread(self):
- """Runs the ZmqProxy service"""
+ """Runs the ZmqProxy service."""
ipc_dir = CONF.rpc_zmq_ipc_dir
consume_in = "tcp://%s:%s" % \
(CONF.rpc_zmq_bind_address,
CONF.rpc_zmq_port)
consumption_proxy = InternalContext(None)
- if not os.path.isdir(ipc_dir):
- try:
- utils.execute('mkdir', '-p', ipc_dir, run_as_root=True)
- utils.execute('chown', "%s:%s" % (os.getuid(), os.getgid()),
- ipc_dir, run_as_root=True)
- utils.execute('chmod', '750', ipc_dir, run_as_root=True)
- except utils.ProcessExecutionError:
+ try:
+ os.makedirs(ipc_dir)
+ except os.error:
+ if not os.path.isdir(ipc_dir):
with excutils.save_and_reraise_exception():
- LOG.error(_("Could not create IPC directory %s") %
- (ipc_dir, ))
-
+ LOG.error(_("Required IPC directory does not exist at"
+ " %s") % (ipc_dir, ))
try:
self.register(consumption_proxy,
consume_in,
zmq.PULL,
out_bind=True)
except zmq.ZMQError:
+ if not os.access(ipc_dir, os.X_OK):
+ with excutils.save_and_reraise_exception():
+ LOG.error(_("Permission denied to IPC directory at"
+ " %s") % (ipc_dir, ))
with excutils.save_and_reraise_exception():
LOG.error(_("Could not create ZeroMQ receiver daemon. "
"Socket may already be in use."))
diff --git a/openstack/common/rpc/serializer.py b/openstack/common/rpc/serializer.py
index 0a2c9c4..76c6831 100644
--- a/openstack/common/rpc/serializer.py
+++ b/openstack/common/rpc/serializer.py
@@ -18,7 +18,7 @@ import abc
class Serializer(object):
- """Generic (de-)serialization definition base class"""
+ """Generic (de-)serialization definition base class."""
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
@@ -43,7 +43,7 @@ class Serializer(object):
class NoOpSerializer(Serializer):
- """A serializer that does nothing"""
+ """A serializer that does nothing."""
def serialize_entity(self, context, entity):
return entity
diff --git a/openstack/common/rpc/service.py b/openstack/common/rpc/service.py
index 6b56ebb..3f51d0b 100644
--- a/openstack/common/rpc/service.py
+++ b/openstack/common/rpc/service.py
@@ -30,7 +30,8 @@ LOG = logging.getLogger(__name__)
class Service(service.Service):
"""Service object for binaries running on hosts.
- A service enables rpc by listening to queues based on topic and host."""
+ A service enables rpc by listening to queues based on topic and host.
+ """
def __init__(self, host, topic, manager=None):
super(Service, self).__init__()
self.host = host
diff --git a/openstack/common/scheduler/base_filter.py b/openstack/common/scheduler/base_filter.py
index 5f2fc9c..44b26ae 100644
--- a/openstack/common/scheduler/base_filter.py
+++ b/openstack/common/scheduler/base_filter.py
@@ -41,7 +41,7 @@ class BaseFilter(object):
class BaseFilterHandler(base_handler.BaseHandler):
- """ Base class to handle loading filter classes.
+ """Base class to handle loading filter classes.
This class should be subclassed where one needs to use filters.
"""
diff --git a/openstack/common/scheduler/base_handler.py b/openstack/common/scheduler/base_handler.py
index 147d90d..1808d2c 100644
--- a/openstack/common/scheduler/base_handler.py
+++ b/openstack/common/scheduler/base_handler.py
@@ -24,8 +24,7 @@ from stevedore import extension
class BaseHandler(object):
- """ Base class to handle loading filter and weight classes.
- """
+ """Base class to handle loading filter and weight classes."""
def __init__(self, modifier_class_type, modifier_namespace):
self.namespace = modifier_namespace
self.modifier_class_type = modifier_class_type
diff --git a/openstack/common/scheduler/filters/capabilities_filter.py b/openstack/common/scheduler/filters/capabilities_filter.py
index df69955..89e2bdb 100644
--- a/openstack/common/scheduler/filters/capabilities_filter.py
+++ b/openstack/common/scheduler/filters/capabilities_filter.py
@@ -25,8 +25,9 @@ class CapabilitiesFilter(filters.BaseHostFilter):
"""HostFilter to work with resource (instance & volume) type records."""
def _satisfies_extra_specs(self, capabilities, resource_type):
- """Check that the capabilities provided by the services
- satisfy the extra specs associated with the instance type"""
+ """Check that the capabilities provided by the services satisfy
+ the extra specs associated with the instance type.
+ """
extra_specs = resource_type.get('extra_specs', [])
if not extra_specs:
return True
diff --git a/openstack/common/scheduler/filters/json_filter.py b/openstack/common/scheduler/filters/json_filter.py
index 7035947..bc4b4fd 100644
--- a/openstack/common/scheduler/filters/json_filter.py
+++ b/openstack/common/scheduler/filters/json_filter.py
@@ -51,7 +51,7 @@ class JsonFilter(filters.BaseHostFilter):
return self._op_compare(args, operator.gt)
def _in(self, args):
- """First term is in set of remaining terms"""
+ """First term is in set of remaining terms."""
return self._op_compare(args, operator.contains)
def _less_than_equal(self, args):
diff --git a/openstack/common/service.py b/openstack/common/service.py
index eb46164..55e23ed 100644
--- a/openstack/common/service.py
+++ b/openstack/common/service.py
@@ -271,7 +271,7 @@ class ProcessLauncher(object):
return wrap
def wait(self):
- """Loop waiting on children to die and respawning as necessary"""
+ """Loop waiting on children to die and respawning as necessary."""
LOG.debug(_('Full set of CONF:'))
CONF.log_opt_values(LOG, std_logging.DEBUG)
diff --git a/openstack/common/strutils.py b/openstack/common/strutils.py
index 6d227c6..8a5367b 100644
--- a/openstack/common/strutils.py
+++ b/openstack/common/strutils.py
@@ -24,6 +24,17 @@ import sys
from openstack.common.gettextutils import _
+# Maps a size suffix (e.g. 'k', 'm') to its byte multiplier
+BYTE_MULTIPLIERS = {
+ '': 1,
+ 't': 1024 ** 4,
+ 'g': 1024 ** 3,
+ 'm': 1024 ** 2,
+ 'k': 1024,
+}
+
+
TRUE_STRINGS = ('1', 't', 'true', 'on', 'y', 'yes')
FALSE_STRINGS = ('0', 'f', 'false', 'off', 'n', 'no')
@@ -148,3 +159,31 @@ def safe_encode(text, incoming=None,
return text.encode(encoding, errors)
return text
+
+
+def to_bytes(text, default=0):
+ """Try to turn a string into a number of bytes. Looks at the last
+ characters of the text to determine the multiplier needed to
+ convert the input into a byte count.
+
+ Supports: B/b, K/k, M/m, G/g, T/t (each optionally followed by
+ b/B, e.g. 'kb' or 'MB')
+ """
+ # Strip everything number-like, which should leave only
+ # the byte 'identifier' suffix
+ mult_key_org = text.lstrip('-1234567890')
+ mult_key = mult_key_org.lower()
+ mult_key_len = len(mult_key)
+ if mult_key.endswith("b"):
+ mult_key = mult_key[0:-1]
+ try:
+ multiplier = BYTE_MULTIPLIERS[mult_key]
+ if mult_key_len:
+ # Only slice when a suffix exists; text[0:-0] would be ''
+ text = text[0:-mult_key_len]
+ return int(text) * multiplier
+ except KeyError:
+ msg = _('Unknown byte multiplier: %s') % mult_key_org
+ raise TypeError(msg)
+ except ValueError:
+ return default
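
For reference, the conversions the new helper performs, assuming openstack.common.strutils is on the import path (a usage sketch, not part of the patch):

    from openstack.common import strutils

    assert strutils.to_bytes('1024') == 1024       # no suffix: plain bytes
    assert strutils.to_bytes('1k') == 1024         # 'k'/'K' -> KiB
    assert strutils.to_bytes('2KB') == 2048        # trailing 'b'/'B' allowed
    assert strutils.to_bytes('1m') == 1024 ** 2
    assert strutils.to_bytes('1G') == 1024 ** 3
    assert strutils.to_bytes('', default=7) == 7   # no digits: default
    # An unknown suffix such as '10x' raises TypeError.
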
diff --git a/openstack/common/threadgroup.py b/openstack/common/threadgroup.py
index 6cafbaf..877059c 100644
--- a/openstack/common/threadgroup.py
+++ b/openstack/common/threadgroup.py
@@ -26,7 +26,7 @@ LOG = logging.getLogger(__name__)
def _thread_done(gt, *args, **kwargs):
- """ Callback function to be passed to GreenThread.link() when we spawn()
+ """Callback function to be passed to GreenThread.link() when we spawn().
Calls the :class:`ThreadGroup` to notify it.
"""
@@ -34,7 +34,7 @@ def _thread_done(gt, *args, **kwargs):
class Thread(object):
- """ Wrapper around a greenthread, that holds a reference to the
+ """Wrapper around a greenthread that holds a reference to the
:class:`ThreadGroup`. The Thread will notify the :class:`ThreadGroup` when
it is done so it can be removed from the threads list.
"""
@@ -50,7 +50,7 @@ class Thread(object):
class ThreadGroup(object):
- """ The point of the ThreadGroup classis to:
+ """The point of the ThreadGroup class is to:
* keep track of timers and greenthreads (making it easier to stop them
when need be).
diff --git a/openstack/common/timeutils.py b/openstack/common/timeutils.py
index 6094365..008e9c8 100644
--- a/openstack/common/timeutils.py
+++ b/openstack/common/timeutils.py
@@ -32,7 +32,7 @@ PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND
def isotime(at=None, subsecond=False):
- """Stringify time in ISO 8601 format"""
+ """Stringify time in ISO 8601 format."""
if not at:
at = utcnow()
st = at.strftime(_ISO8601_TIME_FORMAT
@@ -44,7 +44,7 @@ def isotime(at=None, subsecond=False):
def parse_isotime(timestr):
- """Parse time from ISO 8601 format"""
+ """Parse time from ISO 8601 format."""
try:
return iso8601.parse_date(timestr)
except iso8601.ParseError as e:
@@ -66,7 +66,7 @@ def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT):
def normalize_time(timestamp):
- """Normalize time in arbitrary timezone to UTC naive object"""
+ """Normalize time in arbitrary timezone to UTC naive object."""
offset = timestamp.utcoffset()
if offset is None:
return timestamp
@@ -103,7 +103,7 @@ def utcnow():
def iso8601_from_timestamp(timestamp):
- """Returns a iso8601 formated date from timestamp"""
+ """Returns an iso8601 formatted date from timestamp."""
return isotime(datetime.datetime.utcfromtimestamp(timestamp))
@@ -141,7 +141,8 @@ def clear_time_override():
def marshall_now(now=None):
"""Make an rpc-safe datetime with microseconds.
- Note: tzinfo is stripped, but not required for relative times."""
+ Note: tzinfo is stripped, but not required for relative times.
+ """
if not now:
now = utcnow()
return dict(day=now.day, month=now.month, year=now.year, hour=now.hour,
diff --git a/tools/pip-requires b/requirements.txt
index 067af58..067af58 100644
--- a/tools/pip-requires
+++ b/requirements.txt
diff --git a/tools/test-requires b/test-requirements.txt
index 62c0eea..62c0eea 100644
--- a/tools/test-requires
+++ b/test-requirements.txt
diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
index b18825a..e548a3b 100644
--- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py
+++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
@@ -182,3 +182,23 @@ class RegexpFilterTestCase(test_utils.BaseTestCase):
def test_regexp_filter_unicode_nomatch(self):
self._test_regexp_filter(u'♦', [])
+
+
+class SlaveBackendTestCase(test_utils.BaseTestCase):
+
+ def test_slave_engine_nomatch(self):
+ default = session.CONF.database.connection
+ session.CONF.database.slave_connection = default
+
+ e = session.get_engine()
+ slave_e = session.get_engine(slave_engine=True)
+ self.assertNotEqual(slave_e, e)
+
+ def test_no_slave_engine_match(self):
+ slave_e = session.get_engine()
+ e = session.get_engine()
+ self.assertEqual(slave_e, e)
+
+ def test_slave_backend_nomatch(self):
+ session.CONF.database.slave_connection = "mysql:///localhost"
+ self.assertRaises(AssertionError, session._assert_matching_drivers)
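
In outline, the behaviour the new SlaveBackendTestCase pins down (a sketch; the option and argument names are exactly those exercised above):

    from openstack.common.db.sqlalchemy import session

    # With database.slave_connection configured, get_engine() can hand
    # back a distinct engine for read-only traffic.
    engine = session.get_engine()
    slave = session.get_engine(slave_engine=True)
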
diff --git a/tests/unit/deprecated/__init__.py b/tests/unit/deprecated/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/unit/deprecated/__init__.py
diff --git a/tests/unit/test_wsgi.py b/tests/unit/deprecated/test_wsgi.py
index a3a4d32..72aeae7 100644
--- a/tests/unit/test_wsgi.py
+++ b/tests/unit/deprecated/test_wsgi.py
@@ -24,12 +24,13 @@ import routes
import six
import webob
+from openstack.common.deprecated import wsgi
from openstack.common import exception
-from openstack.common import wsgi
+
from tests import utils
TEST_VAR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
- '..', 'var'))
+ '..', '..', 'var'))
class RequestTest(utils.BaseTestCase):
diff --git a/tests/unit/extension_stubs.py b/tests/unit/extension_stubs.py
index 03b9702..4e44c9c 100644
--- a/tests/unit/extension_stubs.py
+++ b/tests/unit/extension_stubs.py
@@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
-from openstack.common import wsgi
+from openstack.common.deprecated import wsgi
class StubExtension(object):
diff --git a/tests/unit/middleware/test_correlation_id.py b/tests/unit/middleware/test_correlation_id.py
index 070c23e..dc83cc7 100644
--- a/tests/unit/middleware/test_correlation_id.py
+++ b/tests/unit/middleware/test_correlation_id.py
@@ -28,14 +28,13 @@ class CorrelationIdMiddlewareTest(utils.BaseTestCase):
app = mock.Mock()
req = mock.Mock()
req.headers = {}
- original_method = uuidutils.generate_uuid
+
mock_generate_uuid = mock.Mock()
mock_generate_uuid.return_value = "fake_uuid"
- uuidutils.generate_uuid = mock_generate_uuid
+ self.stubs.Set(uuidutils, 'generate_uuid', mock_generate_uuid)
middleware = correlation_id.CorrelationIdMiddleware(app)
middleware(req)
- uuidutils.generate_uuid = original_method
self.assertEquals(req.headers.get("X_CORRELATION_ID"), "fake_uuid")
diff --git a/tests/unit/middleware/test_sizelimit.py b/tests/unit/middleware/test_sizelimit.py
index 500f29d..22098d0 100644
--- a/tests/unit/middleware/test_sizelimit.py
+++ b/tests/unit/middleware/test_sizelimit.py
@@ -13,7 +13,7 @@
# under the License.
from oslo.config import cfg
-import StringIO
+from six import StringIO
import webob
from openstack.common.middleware import sizelimit
@@ -28,14 +28,14 @@ class TestLimitingReader(utils.BaseTestCase):
def test_limiting_reader(self):
BYTES = 1024
bytes_read = 0
- data = StringIO.StringIO("*" * BYTES)
+ data = StringIO("*" * BYTES)
for chunk in sizelimit.LimitingReader(data, BYTES):
bytes_read += len(chunk)
self.assertEquals(bytes_read, BYTES)
bytes_read = 0
- data = StringIO.StringIO("*" * BYTES)
+ data = StringIO("*" * BYTES)
reader = sizelimit.LimitingReader(data, BYTES)
byte = reader.read(1)
while len(byte) != 0:
@@ -49,7 +49,7 @@ class TestLimitingReader(utils.BaseTestCase):
def _consume_all_iter():
bytes_read = 0
- data = StringIO.StringIO("*" * BYTES)
+ data = StringIO("*" * BYTES)
for chunk in sizelimit.LimitingReader(data, BYTES - 1):
bytes_read += len(chunk)
@@ -58,7 +58,7 @@ class TestLimitingReader(utils.BaseTestCase):
def _consume_all_read():
bytes_read = 0
- data = StringIO.StringIO("*" * BYTES)
+ data = StringIO("*" * BYTES)
reader = sizelimit.LimitingReader(data, BYTES - 1)
byte = reader.read(1)
while len(byte) != 0:
diff --git a/tests/unit/plugin/test_callback_plugin.py b/tests/unit/plugin/test_callback_plugin.py
index cbe2601..3f3fd63 100644
--- a/tests/unit/plugin/test_callback_plugin.py
+++ b/tests/unit/plugin/test_callback_plugin.py
@@ -47,7 +47,7 @@ class TestCBP(callbackplugin.CallbackPlugin):
class CallbackTestCase(test_utils.BaseTestCase):
- """Tests for the callback plugin convenience class"""
+ """Tests for the callback plugin convenience class."""
def test_callback_plugin_subclass(self):
diff --git a/tests/unit/rpc/amqp.py b/tests/unit/rpc/amqp.py
index 69d647a..432dd35 100644
--- a/tests/unit/rpc/amqp.py
+++ b/tests/unit/rpc/amqp.py
@@ -223,7 +223,7 @@ class BaseRpcAMQPTestCase(common.BaseRpcTestCase):
self.config(amqp_rpc_single_reply_queue=False)
def test_duplicate_message_check(self):
- """Test sending *not-dict* to a topic exchange/queue"""
+ """Test that a message with a duplicate _unique_id is detected."""
conn = self.rpc.create_connection(FLAGS)
message = {'args': 'topic test message', '_unique_id': 'aaaabbbbcccc'}
diff --git a/tests/unit/rpc/test_kombu.py b/tests/unit/rpc/test_kombu.py
index ebc29ea..159fefb 100644
--- a/tests/unit/rpc/test_kombu.py
+++ b/tests/unit/rpc/test_kombu.py
@@ -95,7 +95,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
self.assertEqual(conn1, conn2)
def test_topic_send_receive(self):
- """Test sending to a topic exchange/queue"""
+ """Test sending to a topic exchange/queue."""
conn = self.rpc.create_connection(FLAGS)
message = 'topic test message'
@@ -114,7 +114,8 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
def test_message_ttl_on_timeout(self):
"""Test message ttl being set by request timeout. The message
- should die on the vine and never arrive."""
+ should die on the vine and never arrive.
+ """
conn = self.rpc.create_connection(FLAGS)
message = 'topic test message'
@@ -131,7 +132,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
conn.close()
def test_topic_send_receive_exchange_name(self):
- """Test sending to a topic exchange/queue with an exchange name"""
+ """Test sending to a topic exchange/queue with an exchange name."""
conn = self.rpc.create_connection(FLAGS)
message = 'topic test message'
@@ -150,7 +151,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
self.assertEqual(self.received_message, message)
def test_topic_multiple_queues(self):
- """Test sending to a topic exchange with multiple queues"""
+ """Test sending to a topic exchange with multiple queues."""
conn = self.rpc.create_connection(FLAGS)
message = 'topic test message'
@@ -232,7 +233,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
self.assertEqual(self.received_message_2, message)
def test_direct_send_receive(self):
- """Test sending to a direct exchange/queue"""
+ """Test sending to a direct exchange/queue."""
conn = self.rpc.create_connection(FLAGS)
message = 'direct test message'
@@ -249,7 +250,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
self.assertEqual(self.received_message, message)
def test_cast_interface_uses_default_options(self):
- """Test kombu rpc.cast"""
+ """Test kombu rpc.cast."""
ctxt = rpc_common.CommonRpcContext(user='fake_user',
project='fake_project')
@@ -275,7 +276,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
impl_kombu.cast(FLAGS, ctxt, 'fake_topic', {'msg': 'fake'})
def test_cast_to_server_uses_server_params(self):
- """Test kombu rpc.cast"""
+ """Test kombu rpc.cast."""
ctxt = rpc_common.CommonRpcContext(user='fake_user',
project='fake_project')
@@ -308,7 +309,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
'fake_topic', {'msg': 'fake'})
def test_fanout_send_receive(self):
- """Test sending to a fanout exchange and consuming from 2 queues"""
+ """Test sending to a fanout exchange and consuming from 2 queues."""
self.skipTest("kombu memory transport seems buggy with "
"fanout queues as this test passes when "
@@ -365,7 +366,7 @@ class RpcKombuTestCase(amqp.BaseRpcAMQPTestCase):
self.assertTrue(isinstance(result, self.rpc.DirectConsumer))
def test_declare_consumer_ioerrors_will_reconnect(self):
- """Test that an IOError exception causes a reconnection"""
+ """Test that an IOError exception causes a reconnection."""
info = _raise_exc_stub(self.stubs, 2, self.rpc.DirectConsumer,
'__init__', 'Socket closed', exc_class=IOError)
diff --git a/tests/unit/rpc/test_qpid.py b/tests/unit/rpc/test_qpid.py
index 82cc119..02e8e20 100644
--- a/tests/unit/rpc/test_qpid.py
+++ b/tests/unit/rpc/test_qpid.py
@@ -28,12 +28,14 @@ import mox
from oslo.config import cfg
from openstack.common import context
+from openstack.common import jsonutils
from openstack.common.rpc import amqp as rpc_amqp
from openstack.common.rpc import common as rpc_common
from tests import utils
try:
import qpid
+
from openstack.common.rpc import impl_qpid
except ImportError:
qpid = None
@@ -69,7 +71,6 @@ class RpcQpidTestCase(utils.BaseTestCase):
self.mock_session = None
self.mock_sender = None
self.mock_receiver = None
- self.mox = mox.Mox()
self.orig_connection = qpid.messaging.Connection
self.orig_session = qpid.messaging.Session
@@ -437,7 +438,6 @@ class RpcQpidTestCase(utils.BaseTestCase):
if expect_failure:
self.mock_session.next_receiver(timeout=mox.IsA(int)).AndRaise(
qpid.messaging.exceptions.Empty())
- self.mock_receiver.fetch()
self.mock_session.close()
self.mock_connection.session().AndReturn(self.mock_session)
else:
@@ -493,6 +493,88 @@ class RpcQpidTestCase(utils.BaseTestCase):
def test_multicall(self):
self._test_call(multi=True)
+ def _test_publisher(self, message=True):
+ """Test that messages containing long strings are correctly serialized
+ in a way that Qpid can handle.
+
+ :param message: The publisher may be passed either a Qpid Message
+ object or a bare dict. This parameter controls which of those the test
+ will send.
+ """
+ self.sent_msg = None
+
+ def send_stub(msg):
+ self.sent_msg = msg
+
+ # Qpid cannot serialize a dict containing a string > 65535 chars.
+ raw_msg = {'test': 'a' * 65536}
+ if message:
+ base_msg = qpid.messaging.Message(raw_msg)
+ else:
+ base_msg = raw_msg
+ expected_msg = qpid.messaging.Message(jsonutils.dumps(raw_msg))
+ expected_msg.content_type = impl_qpid.JSON_CONTENT_TYPE
+ mock_session = self.mox.CreateMock(self.orig_session)
+ mock_sender = self.mox.CreateMock(self.orig_sender)
+ mock_session.sender(mox.IgnoreArg()).AndReturn(mock_sender)
+ self.stubs.Set(mock_sender, 'send', send_stub)
+ self.mox.ReplayAll()
+
+ publisher = impl_qpid.Publisher(mock_session, 'test_node')
+ publisher.send(base_msg)
+
+ self.assertEqual(self.sent_msg.content, expected_msg.content)
+ self.assertEqual(self.sent_msg.content_type, expected_msg.content_type)
+
+ def test_publisher_long_message(self):
+ self._test_publisher(message=True)
+
+ def test_publisher_long_dict(self):
+ self._test_publisher(message=False)
+
+ def _test_consumer_long_message(self, json=True):
+ """Verify that the Qpid implementation correctly deserializes
+ message content.
+
+ :param json: For compatibility, this code needs to support both
+ messages that are and are not JSON encoded. This param
+ specifies which is being tested.
+ """
+ def fake_callback(msg):
+ self.received_msg = msg
+
+ # The longest string Qpid can handle itself
+ chars = 65535
+ if json:
+ # The first length that requires JSON encoding
+ chars = 65536
+ raw_msg = {'test': 'a' * chars}
+ if json:
+ fake_message = qpid.messaging.Message(jsonutils.dumps(raw_msg))
+ fake_message.content_type = impl_qpid.JSON_CONTENT_TYPE
+ else:
+ fake_message = qpid.messaging.Message(raw_msg)
+ mock_session = self.mox.CreateMock(self.orig_session)
+ mock_receiver = self.mox.CreateMock(self.orig_receiver)
+ mock_session.receiver(mox.IgnoreArg()).AndReturn(mock_receiver)
+ mock_receiver.fetch().AndReturn(fake_message)
+ mock_session.acknowledge(mox.IgnoreArg())
+ self.mox.ReplayAll()
+
+ consumer = impl_qpid.DirectConsumer(None,
+ mock_session,
+ 'bogus_msg_id',
+ fake_callback)
+ consumer.consume()
+
+ self.assertEqual(self.received_msg, raw_msg)
+
+ def test_consumer_long_message(self):
+ self._test_consumer_long_message(json=True)
+
+ def test_consumer_long_message_no_json(self):
+ self._test_consumer_long_message(json=False)
+
#
#from nova.tests.rpc import common
diff --git a/tests/unit/rpc/test_service.py b/tests/unit/rpc/test_service.py
index 9293d3e..e9f8313 100644
--- a/tests/unit/rpc/test_service.py
+++ b/tests/unit/rpc/test_service.py
@@ -21,7 +21,7 @@ from tests import utils
class FakeService(service.Service):
- """Fake manager for tests"""
+ """Fake manager for tests."""
def __init__(self, host, topic):
super(FakeService, self).__init__(host, topic, None)
self.method_result = 'manager'
@@ -43,7 +43,7 @@ class FakeHookService(FakeService):
class RpcServiceManagerTestCase(utils.BaseTestCase):
- """Test cases for Services"""
+ """Test cases for Services."""
def setUp(self):
super(RpcServiceManagerTestCase, self).setUp()
self.config(fake_rabbit=True)
diff --git a/tests/unit/scheduler/fake_hosts.py b/tests/unit/scheduler/fake_hosts.py
index 248eb3d..b02aca4 100644
--- a/tests/unit/scheduler/fake_hosts.py
+++ b/tests/unit/scheduler/fake_hosts.py
@@ -18,10 +18,12 @@ Fakes For filters tests.
class FakeHostManager(object):
- """host1: free_ram_mb=1024-512-512=0, free_disk_gb=1024-512-512=0
- host2: free_ram_mb=2048-512=1536 free_disk_gb=2048-512=1536
- host3: free_ram_mb=4096-1024=3072 free_disk_gb=4096-1024=3072
- host4: free_ram_mb=8192 free_disk_gb=8192"""
+ """
+ host1: free_ram_mb=1024-512-512=0, free_disk_gb=1024-512-512=0
+ host2: free_ram_mb=2048-512=1536 free_disk_gb=2048-512=1536
+ host3: free_ram_mb=4096-1024=3072 free_disk_gb=4096-1024=3072
+ host4: free_ram_mb=8192 free_disk_gb=8192
+ """
def __init__(self):
self.service_states = {
diff --git a/tests/unit/scheduler/test_base_filter.py b/tests/unit/scheduler/test_base_filter.py
index 65809ea..d66d84c 100644
--- a/tests/unit/scheduler/test_base_filter.py
+++ b/tests/unit/scheduler/test_base_filter.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mox
-
from openstack.common.scheduler import base_filter
from tests import utils
@@ -25,7 +23,6 @@ class TestBaseFilter(utils.BaseTestCase):
def setUp(self):
super(TestBaseFilter, self).setUp()
self.filter = base_filter.BaseFilter()
- self.mox = mox.Mox()
def test_filter_one_is_called(self):
filters = [1, 2, 3, 4]
diff --git a/tests/unit/scheduler/test_host_filters.py b/tests/unit/scheduler/test_host_filters.py
index c36021c..b452955 100644
--- a/tests/unit/scheduler/test_host_filters.py
+++ b/tests/unit/scheduler/test_host_filters.py
@@ -425,7 +425,7 @@ class HostFiltersTestCase(utils.BaseTestCase):
self.assertFalse(filt_cls.host_passes(host, filter_properties))
def test_json_filter_happy_day(self):
- """Test json filter more thoroughly"""
+ """Test json filter more thoroughly."""
filt_cls = self.class_map['JsonFilter']()
raw = ['and',
'$capabilities.enabled',
diff --git a/tests/unit/test_excutils.py b/tests/unit/test_excutils.py
index 9d9050e..8c8137a 100644
--- a/tests/unit/test_excutils.py
+++ b/tests/unit/test_excutils.py
@@ -26,7 +26,7 @@ class SaveAndReraiseTest(utils.BaseTestCase):
try:
try:
raise Exception(msg)
- except:
+ except Exception:
with excutils.save_and_reraise_exception():
pass
except Exception as _e:
@@ -40,7 +40,7 @@ class SaveAndReraiseTest(utils.BaseTestCase):
try:
try:
raise Exception('dropped')
- except:
+ except Exception:
with excutils.save_and_reraise_exception():
raise Exception(msg)
except Exception as _e:
diff --git a/tests/unit/test_fileutils.py b/tests/unit/test_fileutils.py
index c0bf0ac..4214e83 100644
--- a/tests/unit/test_fileutils.py
+++ b/tests/unit/test_fileutils.py
@@ -15,10 +15,15 @@
# License for the specific language governing permissions and limitations
# under the License.
+import __builtin__
+import errno
import os
import shutil
import tempfile
+import mock
+import mox
+
from openstack.common import fileutils
from tests import utils
@@ -34,3 +39,109 @@ class EnsureTree(utils.BaseTestCase):
finally:
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
+
+
+class TestCachedFile(utils.BaseTestCase):
+
+ def setUp(self):
+ super(TestCachedFile, self).setUp()
+ self.mox = mox.Mox()
+ self.addCleanup(self.mox.UnsetStubs)
+
+ def test_read_cached_file(self):
+ self.mox.StubOutWithMock(os.path, "getmtime")
+ os.path.getmtime(mox.IgnoreArg()).AndReturn(1)
+ self.mox.ReplayAll()
+
+ fileutils._FILE_CACHE = {
+ '/this/is/a/fake': {"data": 1123, "mtime": 1}
+ }
+ fresh, data = fileutils.read_cached_file("/this/is/a/fake")
+ fdata = fileutils._FILE_CACHE['/this/is/a/fake']["data"]
+ self.assertEqual(fdata, data)
+
+ def test_read_modified_cached_file(self):
+ self.mox.StubOutWithMock(os.path, "getmtime")
+ self.mox.StubOutWithMock(__builtin__, 'open')
+ os.path.getmtime(mox.IgnoreArg()).AndReturn(2)
+
+ fake_contents = "lorem ipsum"
+ fake_file = self.mox.CreateMockAnything()
+ fake_file.read().AndReturn(fake_contents)
+ fake_context_manager = self.mox.CreateMockAnything()
+ fake_context_manager.__enter__().AndReturn(fake_file)
+ fake_context_manager.__exit__(mox.IgnoreArg(),
+ mox.IgnoreArg(),
+ mox.IgnoreArg())
+
+ __builtin__.open(mox.IgnoreArg()).AndReturn(fake_context_manager)
+
+ self.mox.ReplayAll()
+ fileutils._FILE_CACHE = {
+ '/this/is/a/fake': {"data": 1123, "mtime": 1}
+ }
+
+ fresh, data = fileutils.read_cached_file("/this/is/a/fake")
+ self.assertEqual(data, fake_contents)
+ self.assertTrue(fresh)
+
+
+class DeleteIfExists(utils.BaseTestCase):
+ def test_file_present(self):
+ tmpfile = tempfile.mktemp()
+
+ open(tmpfile, 'w')
+ fileutils.delete_if_exists(tmpfile)
+ self.assertFalse(os.path.exists(tmpfile))
+
+ def test_file_absent(self):
+ tmpfile = tempfile.mktemp()
+
+ fileutils.delete_if_exists(tmpfile)
+ self.assertFalse(os.path.exists(tmpfile))
+
+ @mock.patch('os.unlink')
+ def test_file_error(self, osunlink):
+ tmpfile = tempfile.mktemp()
+
+ open(tmpfile, 'w')
+
+ error = OSError()
+ error.errno = errno.EINVAL
+ osunlink.side_effect = error
+
+ self.assertRaises(OSError, fileutils.delete_if_exists, tmpfile)
+
+
+class RemovePathOnError(utils.BaseTestCase):
+ def test_error(self):
+ tmpfile = tempfile.mktemp()
+ open(tmpfile, 'w')
+
+ try:
+ with fileutils.remove_path_on_error(tmpfile):
+ raise Exception
+ except Exception:
+ self.assertFalse(os.path.exists(tmpfile))
+
+ def test_no_error(self):
+ tmpfile = tempfile.mktemp()
+ open(tmpfile, 'w')
+
+ with fileutils.remove_path_on_error(tmpfile):
+ pass
+ self.assertTrue(os.path.exists(tmpfile))
+ os.unlink(tmpfile)
+
+
+class UtilsTestCase(utils.BaseTestCase):
+ def test_file_open(self):
+ dst_fd, dst_path = tempfile.mkstemp()
+ try:
+ os.close(dst_fd)
+ with open(dst_path, 'w') as f:
+ f.write('hello')
+ with fileutils.file_open(dst_path, 'r') as fp:
+ self.assertEquals(fp.read(), 'hello')
+ finally:
+ os.unlink(dst_path)
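
The cache contract TestCachedFile relies on, in miniature (a usage sketch; any readable path will do):

    from openstack.common import fileutils

    fresh, data = fileutils.read_cached_file('/etc/hosts')
    fresh2, data2 = fileutils.read_cached_file('/etc/hosts')
    # The second call is served from fileutils._FILE_CACHE as long as
    # the file's mtime has not changed; 'fresh' flags an actual re-read.
    assert data == data2
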
diff --git a/tests/unit/test_gettext.py b/tests/unit/test_gettext.py
index 3a86782..cd139ee 100644
--- a/tests/unit/test_gettext.py
+++ b/tests/unit/test_gettext.py
@@ -2,6 +2,7 @@
# Copyright 2012 Red Hat, Inc.
# All Rights Reserved.
+# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
@@ -15,7 +16,10 @@
# License for the specific language governing permissions and limitations
# under the License.
-import logging
+import copy
+import gettext
+import logging.handlers
+import os
import mock
@@ -42,3 +46,369 @@ class GettextTest(utils.BaseTestCase):
gettext_install.assert_called_once_with('blaa',
localedir='/foo/bar',
unicode=True)
+
+
+class MessageTestCase(utils.BaseTestCase):
+ """Unit tests for locale Message class."""
+
+ def setUp(self):
+ super(MessageTestCase, self).setUp()
+ self._lazy_gettext = gettextutils.get_lazy_gettext('oslo')
+
+ def tearDown(self):
+ # need to clean up stubs early since they interfere
+ # with super class clean up operations
+ self.mox.UnsetStubs()
+ super(MessageTestCase, self).tearDown()
+
+ def test_message_equal_to_string(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ self.assertEqual(result, msgid)
+
+ def test_message_not_equal(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ self.assertNotEqual(result, "Other string %s" % msgid)
+
+ def test_message_equal_with_param(self):
+ msgid = "Some string with params: %s"
+ params = (0, )
+
+ message = msgid % params
+
+ result = self._lazy_gettext(msgid) % params
+
+ self.assertEqual(result, message)
+
+ result_str = '%s' % result
+ self.assertEqual(result_str, message)
+
+ def test_message_injects_nonetype(self):
+ msgid = "Some string with param: %s"
+ params = None
+
+ message = msgid % params
+
+ result = self._lazy_gettext(msgid) % params
+
+ self.assertEqual(result, message)
+
+ result_str = '%s' % result
+ self.assertIn('None', result_str)
+ self.assertEqual(result_str, message)
+
+ def test_message_iterate(self):
+ msgid = "Some string with params: %s"
+ params = 'blah'
+
+ message = msgid % params
+
+ result = self._lazy_gettext(msgid) % params
+
+ # compare using iterators
+ for (c1, c2) in zip(result, message):
+ self.assertEqual(c1, c2)
+
+ def test_message_equal_with_dec_param(self):
+ """Verify we can inject numbers into Messages."""
+ msgid = "Some string with params: %d"
+ params = [0, 1, 10, 24124]
+
+ messages = []
+ results = []
+ for param in params:
+ messages.append(msgid % param)
+ results.append(self._lazy_gettext(msgid) % param)
+
+ for message, result in zip(messages, results):
+ self.assertEqual(type(result), gettextutils.Message)
+ self.assertEqual(result, message)
+
+ # simulate writing out as string
+ result_str = '%s' % result
+ self.assertEqual(result_str, message)
+
+ def test_message_equal_with_extra_params(self):
+ msgid = "Some string with params: %(param1)s %(param2)s"
+ params = {'param1': 'test',
+ 'param2': 'test2',
+ 'param3': 'notinstring'}
+
+ result = self._lazy_gettext(msgid) % params
+
+ self.assertEqual(result, msgid % params)
+
+ def test_message_object_param_copied(self):
+ """Verify that injected parameters get copied."""
+ some_obj = SomeObject()
+ some_obj.tag = 'stub_object'
+ msgid = "Found object: %(some_obj)s"
+
+ result = self._lazy_gettext(msgid) % {'some_obj': some_obj}
+
+ old_some_obj = copy.copy(some_obj)
+ some_obj.tag = 'switched_tag'
+
+ self.assertEqual(result, msgid % {'some_obj': old_some_obj})
+
+ def test_interpolation_with_missing_param(self):
+ msgid = ("Some string with params: %(param1)s %(param2)s"
+ " and a missing one %(missing)s")
+ params = {'param1': 'test',
+ 'param2': 'test2'}
+
+ test_me = lambda: self._lazy_gettext(msgid) % params
+
+ self.assertRaises(KeyError, test_me)
+
+ def test_operator_add(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ additional = " with more added"
+ expected = msgid + additional
+ result = result + additional
+
+ self.assertEqual(type(result), gettextutils.Message)
+ self.assertEqual(result, expected)
+
+ def test_operator_radd(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ additional = " with more added"
+ expected = additional + msgid
+ result = additional + result
+
+ self.assertEqual(type(result), gettextutils.Message)
+ self.assertEqual(result, expected)
+
+ def test_get_index(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ expected = 'm'
+ result = result[2]
+
+ self.assertEqual(result, expected)
+
+ def test_getitem_string(self):
+ """Verify using string indexes on Message does not work."""
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ test_me = lambda: result['blah']
+
+ self.assertRaises(TypeError, test_me)
+
+ def test_contains(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+
+ self.assertIn('msgid', result)
+ self.assertNotIn('blah', result)
+
+ def test_locale_set_does_translation(self):
+ msgid = "Some msgid string"
+ result = self._lazy_gettext(msgid)
+ result.domain = 'test_domain'
+ result.locale = 'test_locale'
+ os.environ['TEST_DOMAIN_LOCALEDIR'] = '/tmp/blah'
+
+ self.mox.StubOutWithMock(gettext, 'translation')
+ fake_lang = self.mox.CreateMock(gettext.GNUTranslations)
+
+ gettext.translation('test_domain',
+ languages=['test_locale'],
+ fallback=True,
+ localedir='/tmp/blah').AndReturn(fake_lang)
+ fake_lang.ugettext(msgid).AndReturn(msgid)
+
+ self.mox.ReplayAll()
+ result = result.data
+ os.environ.pop('TEST_DOMAIN_LOCALEDIR')
+ self.assertEqual(msgid, result)
+
+ def _get_testmsg_inner_params(self):
+ return {'params': {'test1': 'blah1',
+ 'test2': 'blah2',
+ 'test3': SomeObject()},
+ 'domain': 'test_domain',
+ 'locale': 'en_US',
+ '_left_extra_msg': 'Extra. ',
+ '_right_extra_msg': '. More Extra.'}
+
+ def _get_full_test_message(self):
+ msgid = "Some msgid string: %(test1)s %(test2)s %(test3)s"
+ message = self._lazy_gettext(msgid)
+ attrs = self._get_testmsg_inner_params()
+ for (k, v) in attrs.items():
+ setattr(message, k, v)
+
+ return copy.deepcopy(message)
+
+ def test_message_copyable(self):
+ message = self._get_full_test_message()
+ copied_msg = copy.copy(message)
+
+ self.assertIsNot(message, copied_msg)
+
+ for k in self._get_testmsg_inner_params():
+ self.assertEqual(getattr(message, k),
+ getattr(copied_msg, k))
+
+ self.assertEqual(message, copied_msg)
+
+ message._msg = 'Some other msgid string'
+
+ self.assertNotEqual(message, copied_msg)
+
+ def test_message_copy_deepcopied(self):
+ message = self._get_full_test_message()
+ inner_obj = SomeObject()
+ message.params['test3'] = inner_obj
+
+ copied_msg = copy.copy(message)
+
+ self.assertIsNot(message, copied_msg)
+
+ inner_obj.tag = 'different'
+ self.assertNotEqual(message, copied_msg)
+
+ def test_add_returns_copy(self):
+ msgid = "Some msgid string: %(test1)s %(test2)s"
+ message = self._lazy_gettext(msgid)
+ m1 = '10 ' + message + ' 10'
+ m2 = '20 ' + message + ' 20'
+
+ self.assertIsNot(message, m1)
+ self.assertIsNot(message, m2)
+ self.assertIsNot(m1, m2)
+ self.assertEqual(m1, '10 %s 10' % msgid)
+ self.assertEqual(m2, '20 %s 20' % msgid)
+
+ def test_mod_returns_copy(self):
+ msgid = "Some msgid string: %(test1)s %(test2)s"
+ message = self._lazy_gettext(msgid)
+ m1 = message % {'test1': 'foo', 'test2': 'bar'}
+ m2 = message % {'test1': 'foo2', 'test2': 'bar2'}
+
+ self.assertIsNot(message, m1)
+ self.assertIsNot(message, m2)
+ self.assertIsNot(m1, m2)
+ self.assertEqual(m1, msgid % {'test1': 'foo', 'test2': 'bar'})
+ self.assertEqual(m2, msgid % {'test1': 'foo2', 'test2': 'bar2'})
+
+ def test_comparator_operators(self):
+ """Verify Message comparison is equivalent to string comparison."""
+ m1 = self._get_full_test_message()
+ m2 = copy.deepcopy(m1)
+ m3 = "1" + m1
+
+ # m1 and m2 are equal
+ self.assertEqual(m1 >= m2, str(m1) >= str(m2))
+ self.assertEqual(m1 <= m2, str(m1) <= str(m2))
+ self.assertEqual(m2 >= m1, str(m2) >= str(m1))
+ self.assertEqual(m2 <= m1, str(m2) <= str(m1))
+
+ # m1 is greater than m3
+ self.assertEqual(m1 >= m3, str(m1) >= str(m3))
+ self.assertEqual(m1 > m3, str(m1) > str(m3))
+
+ # m3 is not greater than m1
+ self.assertEqual(m3 >= m1, str(m3) >= str(m1))
+ self.assertEqual(m3 > m1, str(m3) > str(m1))
+
+ # m3 is less than m1
+ self.assertEqual(m3 <= m1, str(m3) <= str(m1))
+ self.assertEqual(m3 < m1, str(m3) < str(m1))
+
+ # m1 is not less than m3
+ self.assertEqual(m1 <= m3, str(m1) <= str(m3))
+ self.assertEqual(m1 < m3, str(m1) < str(m3))
+
+ def test_mul_operator(self):
+ message = self._get_full_test_message()
+ message_str = str(message)
+
+ self.assertEqual(message * 10, message_str * 10)
+ self.assertEqual(message * 20, message_str * 20)
+ self.assertEqual(10 * message, 10 * message_str)
+ self.assertEqual(20 * message, 20 * message_str)
+
+ def test_to_unicode(self):
+ message = self._get_full_test_message()
+ message_str = unicode(message)
+
+ self.assertEqual(message, message_str)
+ self.assertTrue(isinstance(message_str, unicode))
+
+
+class LocaleHandlerTestCase(utils.BaseTestCase):
+
+ def setUp(self):
+ super(LocaleHandlerTestCase, self).setUp()
+ self._lazy_gettext = gettextutils.get_lazy_gettext('oslo')
+ self.buffer_handler = logging.handlers.BufferingHandler(40)
+ self.locale_handler = gettextutils.LocaleHandler(
+ 'zh_CN', self.buffer_handler)
+ self.logger = logging.getLogger('localehandler_logger')
+ self.logger.propagate = False
+ self.logger.setLevel(logging.DEBUG)
+ self.logger.addHandler(self.locale_handler)
+
+ def test_emit_message(self):
+ msgid = 'Some logrecord message.'
+ message = self._lazy_gettext(msgid)
+ self.emit_called = False
+
+ def emit(record):
+ self.assertEqual(record.msg.locale, 'zh_CN')
+ self.assertEqual(record.msg, msgid)
+ self.assertTrue(isinstance(record.msg,
+ gettextutils.Message))
+ self.emit_called = True
+ self.stubs.Set(self.buffer_handler, 'emit', emit)
+
+ self.logger.info(message)
+
+ self.assertTrue(self.emit_called)
+
+ def test_emit_nonmessage(self):
+ msgid = 'Some logrecord message.'
+ self.emit_called = False
+
+ def emit(record):
+ self.assertEqual(record.msg, msgid)
+ self.assertFalse(isinstance(record.msg,
+ gettextutils.Message))
+ self.emit_called = True
+ self.stubs.Set(self.buffer_handler, 'emit', emit)
+
+ self.logger.info(msgid)
+
+ self.assertTrue(self.emit_called)
+
+
+class SomeObject(object):
+
+ def __init__(self, tag='default'):
+ self.tag = tag
+
+ def __str__(self):
+ return self.tag
+
+ def __getstate__(self):
+ return self.__dict__
+
+ def __setstate__(self, state):
+ for (k, v) in state.items():
+ setattr(self, k, v)
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.tag == other.tag
+ return False
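
The Message semantics the new tests establish, in miniature (a usage sketch assuming the lazy-gettext factory behaves as exercised above):

    from openstack.common import gettextutils

    _ = gettextutils.get_lazy_gettext('oslo')

    msg = _('Some string with params: %s')  # a Message, not a plain str
    filled = msg % 'blah'                   # interpolation returns a copy

    assert isinstance(filled, gettextutils.Message)
    assert filled == 'Some string with params: blah'  # compares as a string
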
diff --git a/tests/unit/test_jsonutils.py b/tests/unit/test_jsonutils.py
index 35a9487..758455b 100644
--- a/tests/unit/test_jsonutils.py
+++ b/tests/unit/test_jsonutils.py
@@ -16,9 +16,10 @@
# under the License.
import datetime
-import StringIO
import xmlrpclib
+from six import StringIO
+
from openstack.common import jsonutils
from tests import utils
@@ -32,7 +33,7 @@ class JSONUtilsTestCase(utils.BaseTestCase):
self.assertEqual(jsonutils.loads('{"a": "b"}'), {'a': 'b'})
def test_load(self):
- x = StringIO.StringIO('{"a": "b"}')
+ x = StringIO('{"a": "b"}')
self.assertEqual(jsonutils.load(x), {'a': 'b'})
diff --git a/tests/unit/test_lockutils.py b/tests/unit/test_lockutils.py
index c37b030..84afa2d 100644
--- a/tests/unit/test_lockutils.py
+++ b/tests/unit/test_lockutils.py
@@ -75,7 +75,7 @@ class LockTestCase(utils.BaseTestCase):
"got mangled")
def test_synchronized_internally(self):
- """We can lock across multiple green threads"""
+ """We can lock across multiple green threads."""
saved_sem_num = len(lockutils._semaphores)
seen_threads = list()
@@ -105,7 +105,7 @@ class LockTestCase(utils.BaseTestCase):
"Semaphore leak detected")
def test_nested_external_works(self):
- """We can nest external syncs"""
+ """We can nest external syncs."""
tempdir = tempfile.mkdtemp()
try:
self.config(lock_path=tempdir)
@@ -126,7 +126,7 @@ class LockTestCase(utils.BaseTestCase):
shutil.rmtree(tempdir)
def _do_test_synchronized_externally(self):
- """We can lock across multiple processes"""
+ """We can lock across multiple processes."""
@lockutils.synchronized('external', 'test-', external=True)
def lock_files(handles_dir):
diff --git a/tests/unit/test_log.py b/tests/unit/test_log.py
index 301e3a8..a65801b 100644
--- a/tests/unit/test_log.py
+++ b/tests/unit/test_log.py
@@ -1,11 +1,11 @@
import cStringIO
import logging
import os
-import StringIO
import sys
import tempfile
from oslo.config import cfg
+from six import StringIO
from openstack.common import context
from openstack.common import jsonutils
@@ -238,7 +238,7 @@ class ContextFormatterTestCase(test_utils.BaseTestCase):
class ExceptionLoggingTestCase(test_utils.BaseTestCase):
- """Test that Exceptions are logged"""
+ """Test that Exceptions are logged."""
def test_excepthook_logs_exception(self):
product_name = 'somename'
@@ -267,7 +267,8 @@ class ExceptionLoggingTestCase(test_utils.BaseTestCase):
class FancyRecordTestCase(test_utils.BaseTestCase):
"""Test how we handle fancy record keys that are not in the
- base python logging"""
+ base python logging.
+ """
def setUp(self):
super(FancyRecordTestCase, self).setUp()
@@ -350,7 +351,7 @@ class SetDefaultsTestCase(test_utils.BaseTestCase):
class LogConfigOptsTestCase(test_utils.BaseTestCase):
def test_print_help(self):
- f = StringIO.StringIO()
+ f = StringIO()
CONF([])
CONF.print_help(file=f)
self.assertTrue('debug' in f.getvalue())
diff --git a/tests/unit/test_loopingcall.py b/tests/unit/test_loopingcall.py
index f7d21b3..89cf336 100644
--- a/tests/unit/test_loopingcall.py
+++ b/tests/unit/test_loopingcall.py
@@ -58,7 +58,7 @@ class LoopingCallTestCase(utils.BaseTestCase):
self.assertFalse(timer.start(interval=0.5).wait())
def test_interval_adjustment(self):
- """Ensure the interval is adjusted to account for task duration"""
+ """Ensure the interval is adjusted to account for task duration."""
self.num_runs = 3
now = datetime.datetime.utcnow()
diff --git a/tests/unit/test_notifier.py b/tests/unit/test_notifier.py
index 90d811f..6c3b886 100644
--- a/tests/unit/test_notifier.py
+++ b/tests/unit/test_notifier.py
@@ -29,7 +29,7 @@ ctxt2 = context.get_admin_context()
class NotifierTestCase(test_utils.BaseTestCase):
- """Test case for notifications"""
+ """Test case for notifications."""
def setUp(self):
super(NotifierTestCase, self).setUp()
notification_driver = [
@@ -54,7 +54,8 @@ class NotifierTestCase(test_utils.BaseTestCase):
def test_verify_message_format(self):
"""A test to ensure changing the message format is prohibitively
- annoying"""
+ annoying.
+ """
def message_assert(context, message):
fields = [('publisher_id', 'publisher_id'),
@@ -208,7 +209,7 @@ class SimpleNotifier(object):
class MultiNotifierTestCase(test_utils.BaseTestCase):
- """Test case for notifications"""
+ """Test case for notifications."""
def setUp(self):
super(MultiNotifierTestCase, self).setUp()
diff --git a/tests/unit/test_periodic.py b/tests/unit/test_periodic.py
index 1fb1574..d663f8b 100644
--- a/tests/unit/test_periodic.py
+++ b/tests/unit/test_periodic.py
@@ -48,7 +48,7 @@ class AService(periodic_task.PeriodicTasks):
class PeriodicTasksTestCase(utils.BaseTestCase):
- """Test cases for PeriodicTasks"""
+ """Test cases for PeriodicTasks."""
def test_is_called(self):
serv = AService()
diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py
index 8ee405a..fd653d7 100644
--- a/tests/unit/test_plugin.py
+++ b/tests/unit/test_plugin.py
@@ -37,7 +37,7 @@ class ManagerTestCase(utils.BaseTestCase):
class NotifyTestCase(utils.BaseTestCase):
- """Test case for the plugin notification interface"""
+ """Test case for the plugin notification interface."""
def test_add_notifier(self):
notifier1 = SimpleNotifier()
@@ -99,7 +99,7 @@ class MockExtManager():
class APITestCase(utils.BaseTestCase):
- """Test case for the plugin api extension interface"""
+ """Test case for the plugin api extension interface."""
def test_add_extension(self):
def mock_load(_s):
return TestPluginClass()
diff --git a/tests/unit/test_policy.py b/tests/unit/test_policy.py
index 7e56796..33b291a 100644
--- a/tests/unit/test_policy.py
+++ b/tests/unit/test_policy.py
@@ -17,16 +17,25 @@
"""Test of Policy Engine"""
-import StringIO
+import os
import urllib
+import urllib2
import mock
-import urllib2
+from oslo.config import cfg
+from six import StringIO
from openstack.common import jsonutils
from openstack.common import policy
from tests import utils
+CONF = cfg.CONF
+
+TEST_VAR_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__),
+ '..', 'var'))
+
+ENFORCER = policy.Enforcer()
+
class TestException(Exception):
def __init__(self, *args, **kwargs):
@@ -35,6 +44,7 @@ class TestException(Exception):
class RulesTestCase(utils.BaseTestCase):
+
def test_init_basic(self):
rules = policy.Rules()
@@ -104,24 +114,50 @@ class RulesTestCase(utils.BaseTestCase):
self.assertEqual(str(rules), exemplar)
-class PolicySetAndResetTestCase(utils.BaseTestCase):
+class PolicyBaseTestCase(utils.BaseTestCase):
def setUp(self):
- super(PolicySetAndResetTestCase, self).setUp()
+ super(PolicyBaseTestCase, self).setUp()
+ CONF(args=['--config-dir', TEST_VAR_DIR])
+ self.enforcer = ENFORCER
+
+ def tearDown(self):
+ super(PolicyBaseTestCase, self).tearDown()
# Make sure the policy rules are reset for remaining tests
- self.addCleanup(setattr, policy, '_rules', None)
+ self.enforcer.clear()
+
+
+class EnforcerTest(PolicyBaseTestCase):
+
+ def test_load_file(self):
+ self.enforcer.load_rules(True)
+ self.assertIsNotNone(self.enforcer.rules)
+ self.assertIn('default', self.enforcer.rules)
+ self.assertIn('admin', self.enforcer.rules)
+
+ def test_reload(self):
+ self.enforcer.set_rules({'test': 'test'})
+ self.enforcer.load_rules(force_reload=True)
+
+ self.assertNotEqual(self.enforcer.rules, {'test': 'test'})
+ self.assertIn('default', self.enforcer.rules)
def test_set_rules(self):
# Make sure the rules are set properly
- policy._rules = None
- policy.set_rules('spam')
- self.assertEqual(policy._rules, 'spam')
+ self.enforcer.rules = None
+ self.enforcer.set_rules({'test': 1})
+ self.assertEqual(self.enforcer.rules, {'test': 1})
+
+ def test_set_rules_type(self):
+ self.assertRaises(TypeError,
+ self.enforcer.set_rules,
+ 'dummy')
- def test_reset(self):
+ def test_clear(self):
# Make sure the rules are reset
- policy._rules = 'spam'
- policy.reset()
- self.assertEqual(policy._rules, None)
+ self.enforcer.rules = 'spam'
+ self.enforcer.clear()
+ self.assertEqual(self.enforcer.rules, {})
class FakeCheck(policy.BaseCheck):
@@ -129,59 +165,58 @@ class FakeCheck(policy.BaseCheck):
self.result = result
def __str__(self):
- return self.result
+ return str(self.result)
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
if self.result is not None:
return self.result
- return (target, creds)
+ return (target, creds, enforcer)
-class CheckFunctionTestCase(utils.BaseTestCase):
-
- def setUp(self):
- super(CheckFunctionTestCase, self).setUp()
- # Make sure the policy rules are reset for remaining tests
- self.addCleanup(setattr, policy, '_rules', None)
+class CheckFunctionTestCase(PolicyBaseTestCase):
def test_check_explicit(self):
- policy._rules = None
+ self.enforcer.load_rules()
+ self.enforcer.rules = None
rule = FakeCheck()
- result = policy.check(rule, "target", "creds")
+ result = self.enforcer.enforce(rule, "target", "creds")
- self.assertEqual(result, ("target", "creds"))
- self.assertEqual(policy._rules, None)
+ self.assertEqual(result, ("target", "creds", self.enforcer))
+ self.assertEqual(self.enforcer.rules, None)
def test_check_no_rules(self):
- policy._rules = None
- result = policy.check('rule', "target", "creds")
+ self.enforcer.load_rules()
+ self.enforcer.rules = None
+ result = self.enforcer.enforce('rule', "target", "creds")
self.assertEqual(result, False)
- self.assertEqual(policy._rules, None)
+ self.assertEqual(self.enforcer.rules, None)
def test_check_missing_rule(self):
- policy._rules = {}
- result = policy.check('rule', 'target', 'creds')
+ self.enforcer.rules = {}
+ result = self.enforcer.enforce('rule', 'target', 'creds')
self.assertEqual(result, False)
def test_check_with_rule(self):
- policy._rules = dict(default=FakeCheck())
- result = policy.check("default", "target", "creds")
+ self.enforcer.load_rules()
+ self.enforcer.rules = dict(default=FakeCheck())
+ result = self.enforcer.enforce("default", "target", "creds")
- self.assertEqual(result, ("target", "creds"))
+ self.assertEqual(result, ("target", "creds", self.enforcer))
def test_check_raises(self):
- policy._rules = None
+ self.enforcer.rules = None
try:
- policy.check('rule', 'target', 'creds', TestException,
- "arg1", "arg2", kw1="kwarg1", kw2="kwarg2")
+ self.enforcer.enforce('rule', 'target', 'creds',
+ True, TestException, "arg1",
+ "arg2", kw1="kwarg1", kw2="kwarg2")
except TestException as exc:
self.assertEqual(exc.args, ("arg1", "arg2"))
self.assertEqual(exc.kwargs, dict(kw1="kwarg1", kw2="kwarg2"))
else:
- self.fail("policy.check() failed to raise requested exception")
+ self.fail("enforcer.enforce() failed to raise requested exception")
class FalseCheckTestCase(utils.BaseTestCase):
@@ -209,7 +244,7 @@ class TrueCheckTestCase(utils.BaseTestCase):
class CheckForTest(policy.Check):
- def __call__(self, target, creds):
+ def __call__(self, target, creds, enforcer):
pass
@@ -693,42 +728,48 @@ class CheckRegisterTestCase(utils.BaseTestCase):
class RuleCheckTestCase(utils.BaseTestCase):
- @mock.patch.object(policy, '_rules', {})
+ @mock.patch.object(ENFORCER, 'rules', {})
def test_rule_missing(self):
check = policy.RuleCheck('rule', 'spam')
- self.assertEqual(check('target', 'creds'), False)
+ self.assertEqual(check('target', 'creds', ENFORCER), False)
- @mock.patch.object(policy, '_rules',
+ @mock.patch.object(ENFORCER, 'rules',
dict(spam=mock.Mock(return_value=False)))
def test_rule_false(self):
+ enforcer = ENFORCER
+
check = policy.RuleCheck('rule', 'spam')
- self.assertEqual(check('target', 'creds'), False)
- policy._rules['spam'].assert_called_once_with('target', 'creds')
+ self.assertEqual(check('target', 'creds', enforcer), False)
+ enforcer.rules['spam'].assert_called_once_with('target', 'creds',
+ enforcer)
- @mock.patch.object(policy, '_rules',
+ @mock.patch.object(ENFORCER, 'rules',
dict(spam=mock.Mock(return_value=True)))
def test_rule_true(self):
+ enforcer = ENFORCER
check = policy.RuleCheck('rule', 'spam')
- self.assertEqual(check('target', 'creds'), True)
- policy._rules['spam'].assert_called_once_with('target', 'creds')
+ self.assertEqual(check('target', 'creds', enforcer), True)
+ enforcer.rules['spam'].assert_called_once_with('target', 'creds',
+ enforcer)
-class RoleCheckTestCase(utils.BaseTestCase):
+class RoleCheckTestCase(PolicyBaseTestCase):
def test_accept(self):
check = policy.RoleCheck('role', 'sPaM')
- self.assertEqual(check('target', dict(roles=['SpAm'])), True)
+ self.assertEqual(check('target', dict(roles=['SpAm']),
+ self.enforcer), True)
def test_reject(self):
check = policy.RoleCheck('role', 'spam')
- self.assertEqual(check('target', dict(roles=[])), False)
+ self.assertEqual(check('target', dict(roles=[]), self.enforcer), False)
-class HttpCheckTestCase(utils.BaseTestCase):
+class HttpCheckTestCase(PolicyBaseTestCase):
def decode_post_data(self, post_data):
result = {}
for item in post_data.split('&'):
@@ -738,12 +779,13 @@ class HttpCheckTestCase(utils.BaseTestCase):
return result
@mock.patch.object(urllib2, 'urlopen',
- return_value=StringIO.StringIO('True'))
+ return_value=StringIO('True'))
def test_accept(self, mock_urlopen):
check = policy.HttpCheck('http', '//example.com/%(name)s')
self.assertEqual(check(dict(name='target', spam='spammer'),
- dict(user='user', roles=['a', 'b', 'c'])),
+ dict(user='user', roles=['a', 'b', 'c']),
+ self.enforcer),
True)
self.assertEqual(mock_urlopen.call_count, 1)
@@ -756,12 +798,13 @@ class HttpCheckTestCase(utils.BaseTestCase):
))
@mock.patch.object(urllib2, 'urlopen',
- return_value=StringIO.StringIO('other'))
+ return_value=StringIO('other'))
def test_reject(self, mock_urlopen):
check = policy.HttpCheck('http', '//example.com/%(name)s')
self.assertEqual(check(dict(name='target', spam='spammer'),
- dict(user='user', roles=['a', 'b', 'c'])),
+ dict(user='user', roles=['a', 'b', 'c']),
+ self.enforcer),
False)
self.assertEqual(mock_urlopen.call_count, 1)
@@ -774,18 +817,22 @@ class HttpCheckTestCase(utils.BaseTestCase):
))
-class GenericCheckTestCase(utils.BaseTestCase):
+class GenericCheckTestCase(PolicyBaseTestCase):
def test_no_cred(self):
check = policy.GenericCheck('name', '%(name)s')
- self.assertEqual(check(dict(name='spam'), {}), False)
+ self.assertEqual(check(dict(name='spam'), {}, self.enforcer), False)
def test_cred_mismatch(self):
check = policy.GenericCheck('name', '%(name)s')
- self.assertEqual(check(dict(name='spam'), dict(name='ham')), False)
+ self.assertEqual(check(dict(name='spam'),
+ dict(name='ham'),
+ self.enforcer), False)
def test_accept(self):
check = policy.GenericCheck('name', '%(name)s')
- self.assertEqual(check(dict(name='spam'), dict(name='spam')), True)
+ self.assertEqual(check(dict(name='spam'),
+ dict(name='spam'),
+ self.enforcer), True)
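Taken together, the test_policy.py hunks above track the move from module-level policy state (policy._rules, policy.check(), policy.reset()) to an Enforcer instance: rules live on the object, set_rules() rejects non-dict input with a TypeError, clear() replaces reset(), and every check callable now receives the enforcer as a third argument. A minimal usage sketch against that API; parse_rule() is assumed to be available in openstack.common.policy, since these hunks never call it directly:

    from openstack.common import policy

    enforcer = policy.Enforcer()

    # set_rules() now wants a dict of name -> check; a bare string
    # raises TypeError (see test_set_rules_type above).
    enforcer.set_rules({'default': policy.parse_rule('is_admin:True')})

    creds = {'user': 'u1', 'is_admin': True}
    allowed = enforcer.enforce('default', {}, creds)  # -> True

    enforcer.clear()  # resets enforcer.rules to {}

Note that enforce() may first consult the policy file named in configuration; the tests above call load_rules() before overriding enforcer.rules for exactly that reason.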
diff --git a/tests/unit/test_processutils.py b/tests/unit/test_processutils.py
index e00a66e..4466d71 100644
--- a/tests/unit/test_processutils.py
+++ b/tests/unit/test_processutils.py
@@ -19,9 +19,10 @@ from __future__ import print_function
import fixtures
import os
-import StringIO
import tempfile
+from six import StringIO
+
from openstack.common import processutils
from tests import utils
@@ -113,7 +114,7 @@ echo $runs > "$1"
exit 1
''')
fp.close()
- os.chmod(tmpfilename, 0755)
+ os.chmod(tmpfilename, 0o755)
self.assertRaises(processutils.ProcessExecutionError,
processutils.execute,
tmpfilename, tmpfilename2, attempts=10,
@@ -158,7 +159,7 @@ echo foo > "$1"
grep foo
""")
fp.close()
- os.chmod(tmpfilename, 0755)
+ os.chmod(tmpfilename, 0o755)
processutils.execute(tmpfilename,
tmpfilename2,
process_input='foo',
@@ -213,7 +214,7 @@ class FakeSshChannel(object):
return self.rc
-class FakeSshStream(StringIO.StringIO):
+class FakeSshStream(StringIO):
def setup_channel(self, rc):
self.channel = FakeSshChannel(rc)
@@ -225,9 +226,9 @@ class FakeSshConnection(object):
def exec_command(self, cmd):
stdout = FakeSshStream('stdout')
stdout.setup_channel(self.rc)
- return (StringIO.StringIO(),
+ return (StringIO(),
stdout,
- StringIO.StringIO('stderr'))
+ StringIO('stderr'))
class SshExecuteTestCase(utils.BaseTestCase):
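Two Python 3 themes recur in this file (py33 joins the tox envlist further down): six.StringIO replaces the removed StringIO module, and octal literals need the 0o prefix. The portable spellings, as a quick sketch:

    import os
    import tempfile

    from six import StringIO  # io.StringIO on py3, StringIO.StringIO on py2

    buf = StringIO('stdout')  # in-memory text stream; also subclassable,
                              # which FakeSshStream above relies on

    fd, path = tempfile.mkstemp()
    os.close(fd)
    os.chmod(path, 0o755)  # bare 0755 is a SyntaxError on Python 3
    os.remove(path)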
diff --git a/tests/unit/test_rootwrap.py b/tests/unit/test_rootwrap.py
index 5a5d9ca..25b2051 100644
--- a/tests/unit/test_rootwrap.py
+++ b/tests/unit/test_rootwrap.py
@@ -113,7 +113,7 @@ class RootwrapTestCase(utils.BaseTestCase):
p.wait()
def test_KillFilter_no_raise(self):
- """Makes sure ValueError from bug 926412 is gone"""
+ """Makes sure ValueError from bug 926412 is gone."""
f = filters.KillFilter("root", "")
# Providing anything other than kill should be False
usercmd = ['notkill', 999999]
@@ -123,7 +123,7 @@ class RootwrapTestCase(utils.BaseTestCase):
self.assertFalse(f.match(usercmd))
def test_KillFilter_deleted_exe(self):
- """Makes sure deleted exe's are killed correctly"""
+ """Makes sure deleted exe's are killed correctly."""
# See bug #967931.
def fake_readlink(blah):
return '/bin/commandddddd (deleted)'
@@ -135,7 +135,7 @@ class RootwrapTestCase(utils.BaseTestCase):
self.assertTrue(f.match(usercmd))
def test_KillFilter_upgraded_exe(self):
- """Makes sure upgraded exe's are killed correctly"""
+ """Makes sure upgraded exe's are killed correctly."""
# See bug #1179793.
def fake_readlink(blah):
return '/bin/commandddddd\0\05190bfb2 (deleted)'
diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py
index b7ba4f7..4a2827e 100644
--- a/tests/unit/test_service.py
+++ b/tests/unit/test_service.py
@@ -44,7 +44,7 @@ class ExtendedService(service.Service):
class ServiceManagerTestCase(utils.BaseTestCase):
- """Test cases for Services"""
+ """Test cases for Services."""
def test_override_manager_method(self):
serv = ExtendedService()
serv.start()
diff --git a/tests/unit/test_strutils.py b/tests/unit/test_strutils.py
index bad50c8..42160a6 100644
--- a/tests/unit/test_strutils.py
+++ b/tests/unit/test_strutils.py
@@ -166,3 +166,35 @@ class StrUtilsTest(utils.BaseTestCase):
# Forcing incoming to ascii so it falls back to utf-8
self.assertEqual('ni\xc3\xb1o', safe_encode('ni\xc3\xb1o',
incoming='ascii'))
+
+ def test_string_conversions(self):
+ working_examples = {
+ '1024KB': 1048576,
+ '1024TB': 1125899906842624,
+ '1024K': 1048576,
+ '1024T': 1125899906842624,
+ '1TB': 1099511627776,
+ '1T': 1099511627776,
+ '1KB': 1024,
+ '1K': 1024,
+ '1B': 1,
+ '1': 1,
+ '1MB': 1048576,
+ '7MB': 7340032,
+ '0MB': 0,
+ '0KB': 0,
+ '0TB': 0,
+ '': 0,
+ }
+ for (in_value, expected_value) in working_examples.items():
+ b_value = strutils.to_bytes(in_value)
+ self.assertEqual(expected_value, b_value)
+ if in_value:
+ in_value = "-" + in_value
+ b_value = strutils.to_bytes(in_value)
+ self.assertEqual(expected_value * -1, b_value)
+ breaking_examples = [
+ 'junk1KB', '1023BBBB',
+ ]
+ for v in breaking_examples:
+ self.assertRaises(TypeError, strutils.to_bytes, v)
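The values asserted above pin down the contract of the new strutils.to_bytes(): unit suffixes are powers of 1024, a trailing B (or no suffix) means bytes, the empty string maps to 0, a leading '-' negates, and unparseable input raises TypeError. For example:

    from openstack.common import strutils

    strutils.to_bytes('1KB')      # -> 1024
    strutils.to_bytes('7MB')      # -> 7340032, i.e. 7 * 1024 ** 2
    strutils.to_bytes('-1T')      # -> -1099511627776
    strutils.to_bytes('')         # -> 0
    strutils.to_bytes('junk1KB')  # raises TypeError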
diff --git a/tests/unit/test_threadgroup.py b/tests/unit/test_threadgroup.py
index f627215..5af6653 100644
--- a/tests/unit/test_threadgroup.py
+++ b/tests/unit/test_threadgroup.py
@@ -27,7 +27,7 @@ LOG = logging.getLogger(__name__)
class ThreadGroupTestCase(utils.BaseTestCase):
- """Test cases for thread group"""
+ """Test cases for thread group."""
def setUp(self):
super(ThreadGroupTestCase, self).setUp()
self.tg = threadgroup.ThreadGroup()
diff --git a/tests/utils.py b/tests/utils.py
index 4682428..7d0cc85 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,7 +34,9 @@ class BaseTestCase(testtools.TestCase):
def setUp(self):
super(BaseTestCase, self).setUp()
- self.stubs = self.useFixture(moxstubout.MoxStubout()).stubs
+ moxfixture = self.useFixture(moxstubout.MoxStubout())
+ self.mox = moxfixture.mox
+ self.stubs = moxfixture.stubs
self.addCleanup(CONF.reset)
self.useFixture(fixtures.FakeLogger('openstack.common'))
self.useFixture(fixtures.Timeout(30, True))
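Exposing moxfixture.mox alongside its stubs lets subclasses record and replay their own mocks, while the fixture's cleanup still runs VerifyAll()/UnsetStubs(). A hypothetical subclass test (names are illustrative only):

    class ExampleTestCase(BaseTestCase):
        def test_ping(self):
            fake = self.mox.CreateMockAnything()
            fake.ping().AndReturn('pong')  # record phase
            self.mox.ReplayAll()           # switch to replay
            self.assertEqual(fake.ping(), 'pong')
            # VerifyAll()/UnsetStubs() run in the fixture's cleanup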
diff --git a/tests/var/policy.json b/tests/var/policy.json
new file mode 100644
index 0000000..73730ae
--- /dev/null
+++ b/tests/var/policy.json
@@ -0,0 +1,4 @@
+{
+ "default": "rule:admin",
+ "admin": "is_admin:True"
+}
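This fixture is what EnforcerTest picks up through --config-dir tests/var: 'default' delegates to the 'admin' rule, which passes only when the credentials carry is_admin: True. Sketched against the Enforcer API from this diff (assuming the policy_file option resolves to this file, as in the tests above):

    from openstack.common import policy

    enforcer = policy.Enforcer()
    enforcer.load_rules(True)  # force-read policy.json, as test_load_file does

    enforcer.enforce('default', {}, {'is_admin': True})  # -> True
    enforcer.enforce('default', {}, {'user': 'nobody'})  # -> False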
diff --git a/tox.ini b/tox.ini
index 6b9f5bc..1570db5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py26,py27,pep8,pylint
+envlist = py26,py27,py33,pep8,pylint
[testenv]
setenv = VIRTUAL_ENV={envdir}
@@ -9,15 +9,15 @@ setenv = VIRTUAL_ENV={envdir}
NOSE_OPENSTACK_YELLOW=0.025
NOSE_OPENSTACK_SHOW_ELAPSED=1
NOSE_OPENSTACK_STDOUT=1
-deps = -r{toxinidir}/tools/pip-requires
- -r{toxinidir}/tools/test-requires
+deps = -r{toxinidir}/requirements.txt
+ -r{toxinidir}/test-requirements.txt
commands =
python tools/patch_tox_venv.py
nosetests --with-doctest --exclude-dir=tests/testmods {posargs}
[flake8]
show-source = True
-ignore = H201,H202,H302,H304,H306,H401,H402,H403,H404
+ignore = H202,H302,H304,H404
exclude = .venv,.tox,dist,doc,*.egg,.update-venv
[testenv:pep8]