-rw-r--r--  .gitignore  2
-rw-r--r--  MANIFEST.in  2
-rwxr-xr-x  bin/keystone-all  2
-rw-r--r--  doc/source/apache-httpd.rst  12
-rw-r--r--  doc/source/developing.rst  38
-rw-r--r--  etc/keystone.conf.sample  3
-rw-r--r--  keystone/assignment/backends/ldap.py  27
-rw-r--r--  keystone/assignment/core.py  22
-rw-r--r--  keystone/auth/controllers.py  4
-rw-r--r--  keystone/auth/core.py  72
-rw-r--r--  keystone/auth/plugins/password.py  2
-rw-r--r--  keystone/auth/plugins/token.py  2
-rw-r--r--  keystone/catalog/backends/sql.py  77
-rw-r--r--  keystone/catalog/backends/templated.py  2
-rw-r--r--  keystone/catalog/core.py  2
-rw-r--r--  keystone/clean.py  7
-rw-r--r--  keystone/cli.py  54
-rw-r--r--  keystone/common/cms.py  2
-rw-r--r--  keystone/common/config.py  100
-rw-r--r--  keystone/common/controller.py  7
-rw-r--r--  keystone/common/environment/__init__.py  2
-rw-r--r--  keystone/common/environment/eventlet_server.py  5
-rw-r--r--  keystone/common/ldap/core.py  20
-rw-r--r--  keystone/common/ldap/fakeldap.py  4
-rw-r--r--  keystone/common/openssl.py  39
-rw-r--r--  keystone/common/sql/core.py  7
-rw-r--r--  keystone/common/sql/legacy.py  2
-rw-r--r--  keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py  40
-rw-r--r--  keystone/common/sql/migrate_repo/versions/032_username_length.py  31
-rw-r--r--  keystone/common/sql/migration.py  42
-rw-r--r--  keystone/common/sql/nova.py  2
-rw-r--r--  keystone/common/utils.py  15
-rw-r--r--  keystone/common/wsgi.py  19
-rw-r--r--  keystone/contrib/access/core.py  3
-rw-r--r--  keystone/contrib/example/__init__.py (renamed from tests/tmp/.gitkeep)  0
-rw-r--r--  keystone/contrib/example/migrate_repo/__init__.py  0
-rw-r--r--  keystone/contrib/example/migrate_repo/migrate.cfg  25
-rw-r--r--  keystone/contrib/example/migrate_repo/versions/001_example_table.py  45
-rw-r--r--  keystone/contrib/example/migrate_repo/versions/__init__.py  0
-rw-r--r--  keystone/contrib/stats/core.py  2
-rw-r--r--  keystone/contrib/user_crud/core.py  4
-rw-r--r--  keystone/controllers.py  2
-rw-r--r--  keystone/credential/core.py  2
-rw-r--r--  keystone/exception.py  2
-rw-r--r--  keystone/identity/backends/ldap.py  35
-rw-r--r--  keystone/identity/backends/sql.py  2
-rw-r--r--  keystone/identity/controllers.py  19
-rw-r--r--  keystone/identity/core.py  2
-rw-r--r--  keystone/middleware/core.py  3
-rw-r--r--  keystone/middleware/s3_token.py  2
-rw-r--r--  keystone/openstack/common/context.py  83
-rw-r--r--  keystone/openstack/common/crypto/utils.py  10
-rw-r--r--  keystone/openstack/common/eventlet_backdoor.py  146
-rw-r--r--  keystone/openstack/common/excutils.py  99
-rw-r--r--  keystone/openstack/common/gettextutils.py  263
-rw-r--r--  keystone/openstack/common/importutils.py  7
-rw-r--r--  keystone/openstack/common/jsonutils.py  67
-rw-r--r--  keystone/openstack/common/local.py  47
-rw-r--r--  keystone/openstack/common/log.py  559
-rw-r--r--  keystone/openstack/common/loopingcall.py  147
-rw-r--r--  keystone/openstack/common/network_utils.py  81
-rw-r--r--  keystone/openstack/common/notifier/__init__.py  14
-rw-r--r--  keystone/openstack/common/notifier/api.py  173
-rw-r--r--  keystone/openstack/common/notifier/log_notifier.py  37
-rw-r--r--  keystone/openstack/common/notifier/no_op_notifier.py  19
-rw-r--r--  keystone/openstack/common/notifier/rpc_notifier.py  46
-rw-r--r--  keystone/openstack/common/notifier/rpc_notifier2.py  52
-rw-r--r--  keystone/openstack/common/notifier/test_notifier.py  22
-rw-r--r--  keystone/openstack/common/rpc/__init__.py  307
-rw-r--r--  keystone/openstack/common/rpc/amqp.py  615
-rw-r--r--  keystone/openstack/common/rpc/common.py  509
-rw-r--r--  keystone/openstack/common/rpc/dispatcher.py  178
-rw-r--r--  keystone/openstack/common/rpc/impl_fake.py  195
-rw-r--r--  keystone/openstack/common/rpc/impl_kombu.py  861
-rw-r--r--  keystone/openstack/common/rpc/impl_qpid.py  739
-rw-r--r--  keystone/openstack/common/rpc/impl_zmq.py  817
-rw-r--r--  keystone/openstack/common/rpc/matchmaker.py  324
-rw-r--r--  keystone/openstack/common/rpc/matchmaker_redis.py  145
-rw-r--r--  keystone/openstack/common/rpc/matchmaker_ring.py  108
-rw-r--r--  keystone/openstack/common/rpc/proxy.py  226
-rw-r--r--  keystone/openstack/common/rpc/securemessage.py  521
-rw-r--r--  keystone/openstack/common/rpc/serializer.py  52
-rw-r--r--  keystone/openstack/common/rpc/service.py  78
-rwxr-xr-x  keystone/openstack/common/rpc/zmq_receiver.py  41
-rw-r--r--  keystone/openstack/common/service.py  450
-rw-r--r--  keystone/openstack/common/sslutils.py  100
-rw-r--r--  keystone/openstack/common/threadgroup.py  121
-rw-r--r--  keystone/openstack/common/timeutils.py  32
-rw-r--r--  keystone/openstack/common/uuidutils.py  39
-rw-r--r--  keystone/policy/backends/rules.py  2
-rw-r--r--  keystone/service.py  29
-rw-r--r--  keystone/tests/__init__.py  0
-rw-r--r--  keystone/tests/_ldap_livetest.py (renamed from tests/_ldap_livetest.py)  13
-rw-r--r--  keystone/tests/_ldap_tls_livetest.py (renamed from tests/_ldap_tls_livetest.py)  2
-rw-r--r--  keystone/tests/_sql_livetest.py (renamed from tests/_sql_livetest.py)  0
-rw-r--r--  keystone/tests/_test_import_auth_token.py (renamed from tests/_test_import_auth_token.py)  0
-rw-r--r--  keystone/tests/auth_plugin_external_disabled.conf (renamed from tests/auth_plugin_external_disabled.conf)  0
-rw-r--r--  keystone/tests/auth_plugin_external_domain.conf (renamed from tests/auth_plugin_external_domain.conf)  0
-rw-r--r--  keystone/tests/backend_db2.conf (renamed from tests/backend_db2.conf)  0
-rw-r--r--  keystone/tests/backend_ldap.conf (renamed from tests/backend_ldap.conf)  0
-rw-r--r--  keystone/tests/backend_ldap_sql.conf (renamed from tests/backend_ldap_sql.conf)  1
-rw-r--r--  keystone/tests/backend_liveldap.conf (renamed from tests/backend_liveldap.conf)  0
-rw-r--r--  keystone/tests/backend_mysql.conf (renamed from tests/backend_mysql.conf)  0
-rw-r--r--  keystone/tests/backend_pam.conf (renamed from tests/backend_pam.conf)  0
-rw-r--r--  keystone/tests/backend_postgresql.conf (renamed from tests/backend_postgresql.conf)  0
-rw-r--r--  keystone/tests/backend_sql.conf (renamed from tests/backend_sql.conf)  0
-rw-r--r--  keystone/tests/backend_sql_disk.conf (renamed from tests/backend_sql_disk.conf)  0
-rw-r--r--  keystone/tests/backend_tls_liveldap.conf (renamed from tests/backend_tls_liveldap.conf)  0
-rw-r--r--  keystone/tests/core.py (renamed from keystone/test.py)  9
-rw-r--r--  keystone/tests/default_catalog.templates (renamed from tests/default_catalog.templates)  0
-rw-r--r--  keystone/tests/default_fixtures.py (renamed from tests/default_fixtures.py)  0
-rw-r--r--  keystone/tests/legacy_d5.mysql (renamed from tests/legacy_d5.mysql)  0
-rw-r--r--  keystone/tests/legacy_d5.sqlite (renamed from tests/legacy_d5.sqlite)  0
-rw-r--r--  keystone/tests/legacy_diablo.mysql (renamed from tests/legacy_diablo.mysql)  0
-rw-r--r--  keystone/tests/legacy_diablo.sqlite (renamed from tests/legacy_diablo.sqlite)  0
-rw-r--r--  keystone/tests/legacy_essex.mysql (renamed from tests/legacy_essex.mysql)  0
-rw-r--r--  keystone/tests/legacy_essex.sqlite (renamed from tests/legacy_essex.sqlite)  0
-rw-r--r--  keystone/tests/test_auth.py (renamed from tests/test_auth.py)  5
-rw-r--r--  keystone/tests/test_auth_plugin.conf (renamed from tests/test_auth_plugin.conf)  0
-rw-r--r--  keystone/tests/test_auth_plugin.py (renamed from tests/test_auth_plugin.py)  2
-rw-r--r--  keystone/tests/test_backend.py (renamed from tests/test_backend.py)  18
-rw-r--r--  keystone/tests/test_backend_kvs.py (renamed from tests/test_backend_kvs.py)  3
-rw-r--r--  keystone/tests/test_backend_ldap.py (renamed from tests/test_backend_ldap.py)  6
-rw-r--r--  keystone/tests/test_backend_memcache.py (renamed from tests/test_backend_memcache.py)  2
-rw-r--r--  keystone/tests/test_backend_pam.py (renamed from tests/test_backend_pam.py)  2
-rw-r--r--  keystone/tests/test_backend_sql.py (renamed from tests/test_backend_sql.py)  5
-rw-r--r--  keystone/tests/test_backend_templated.py (renamed from tests/test_backend_templated.py)  2
-rw-r--r--  keystone/tests/test_catalog.py (renamed from tests/test_catalog.py)  0
-rw-r--r--  keystone/tests/test_cert_setup.py (renamed from tests/test_cert_setup.py)  2
-rw-r--r--  keystone/tests/test_config.py (renamed from tests/test_config.py)  2
-rw-r--r--  keystone/tests/test_content_types.py (renamed from tests/test_content_types.py)  2
-rw-r--r--  keystone/tests/test_contrib_s3_core.py (renamed from tests/test_contrib_s3_core.py)  2
-rw-r--r--  keystone/tests/test_contrib_stats_core.py (renamed from tests/test_contrib_stats_core.py)  2
-rw-r--r--  keystone/tests/test_drivers.py (renamed from tests/test_drivers.py)  0
-rw-r--r--  keystone/tests/test_exception.py (renamed from tests/test_exception.py)  2
-rw-r--r--  keystone/tests/test_import_legacy.py (renamed from tests/test_import_legacy.py)  2
-rw-r--r--  keystone/tests/test_injection.py (renamed from tests/test_injection.py)  0
-rw-r--r--  keystone/tests/test_ipv6.py (renamed from tests/test_ipv6.py)  2
-rw-r--r--  keystone/tests/test_keystoneclient.py (renamed from tests/test_keystoneclient.py)  49
-rw-r--r--  keystone/tests/test_keystoneclient_sql.py (renamed from tests/test_keystoneclient_sql.py)  3
-rw-r--r--  keystone/tests/test_middleware.py (renamed from tests/test_middleware.py)  2
-rw-r--r--  keystone/tests/test_no_admin_token_auth.py (renamed from tests/test_no_admin_token_auth.py)  2
-rw-r--r--  keystone/tests/test_overrides.conf (renamed from tests/test_overrides.conf)  6
-rw-r--r--  keystone/tests/test_pki_token_provider.conf (renamed from tests/test_pki_token_provider.conf)  0
-rw-r--r--  keystone/tests/test_policy.py (renamed from tests/test_policy.py)  2
-rw-r--r--  keystone/tests/test_s3_token_middleware.py (renamed from tests/test_s3_token_middleware.py)  0
-rw-r--r--  keystone/tests/test_serializer.py (renamed from tests/test_serializer.py)  2
-rw-r--r--  keystone/tests/test_singular_plural.py (renamed from tests/test_singular_plural.py)  0
-rw-r--r--  keystone/tests/test_sizelimit.py (renamed from tests/test_sizelimit.py)  2
-rw-r--r--  keystone/tests/test_sql_core.py (renamed from tests/test_sql_core.py)  2
-rw-r--r--  keystone/tests/test_sql_migrate_extensions.py  47
-rw-r--r--  keystone/tests/test_sql_upgrade.py (renamed from tests/test_sql_upgrade.py)  182
-rw-r--r--  keystone/tests/test_ssl.py (renamed from tests/test_ssl.py)  2
-rw-r--r--  keystone/tests/test_token_bind.py (renamed from tests/test_token_bind.py)  2
-rw-r--r--  keystone/tests/test_token_provider.py (renamed from tests/test_token_provider.py)  2
-rw-r--r--  keystone/tests/test_url_middleware.py (renamed from tests/test_url_middleware.py)  2
-rw-r--r--  keystone/tests/test_utils.py (renamed from tests/test_utils.py)  2
-rw-r--r--  keystone/tests/test_uuid_token_provider.conf (renamed from tests/test_uuid_token_provider.conf)  0
-rw-r--r--  keystone/tests/test_v3.py (renamed from tests/test_v3.py)  3
-rw-r--r--  keystone/tests/test_v3_auth.py (renamed from tests/test_v3_auth.py)  63
-rw-r--r--  keystone/tests/test_v3_catalog.py (renamed from tests/test_v3_catalog.py)  0
-rw-r--r--  keystone/tests/test_v3_credential.py (renamed from tests/test_v3_credential.py)  0
-rw-r--r--  keystone/tests/test_v3_identity.py (renamed from tests/test_v3_identity.py)  0
-rw-r--r--  keystone/tests/test_v3_policy.py (renamed from tests/test_v3_policy.py)  0
-rw-r--r--  keystone/tests/test_v3_protection.py (renamed from tests/test_v3_protection.py)  0
-rw-r--r--  keystone/tests/test_versions.py (renamed from tests/test_versions.py)  2
-rw-r--r--  keystone/tests/test_wsgi.py (renamed from tests/test_wsgi.py)  67
-rw-r--r--  keystone/tests/tmp/.gitkeep  0
-rw-r--r--  keystone/token/backends/kvs.py  2
-rw-r--r--  keystone/token/backends/memcache.py  2
-rw-r--r--  keystone/token/backends/sql.py  9
-rw-r--r--  keystone/token/controllers.py  10
-rw-r--r--  keystone/token/core.py  2
-rw-r--r--  keystone/token/provider.py  2
-rw-r--r--  keystone/token/providers/pki.py  2
-rw-r--r--  keystone/trust/backends/sql.py  8
-rw-r--r--  keystone/trust/controllers.py  2
-rw-r--r--  keystone/trust/core.py  2
-rw-r--r--  requirements.txt  2
-rwxr-xr-x  run_tests.sh  9
-rw-r--r--  setup.cfg  2
-rw-r--r--  test-requirements.txt  3
182 files changed, 9207 insertions, 556 deletions
diff --git a/.gitignore b/.gitignore
index d4915b0b..1297ba42 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,6 +24,6 @@ build/
dist/
etc/keystone.conf
etc/logging.conf
-tests/tmp/
+keystone/tests/tmp/
.project
.pydevproject
diff --git a/MANIFEST.in b/MANIFEST.in
index 2373ea28..9c59a76b 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -16,7 +16,7 @@ include etc/*
include httpd/*
graft bin
graft doc
-graft tests
+graft keystone/tests
graft tools
graft examples
recursive-include keystone *.json *.xml *.cfg *.pem README *.po *.pot *.sql
diff --git a/bin/keystone-all b/bin/keystone-all
index 53b50199..bb755606 100755
--- a/bin/keystone-all
+++ b/bin/keystone-all
@@ -80,7 +80,7 @@ if __name__ == '__main__':
version=pbr.version.VersionInfo('keystone').version_string(),
default_config_files=config_files)
- config.setup_logging(CONF)
+ config.setup_logging(CONF, product_name='keystone')
# Log the options used when starting if we're in debug mode...
if CONF.debug:
diff --git a/doc/source/apache-httpd.rst b/doc/source/apache-httpd.rst
index 41437780..5bc0dbe8 100644
--- a/doc/source/apache-httpd.rst
+++ b/doc/source/apache-httpd.rst
@@ -87,7 +87,17 @@ Putting it somewhere else requires you set up your SELinux policy accordingly.
Keystone Configuration
----------------------
-Make sure you use the ``SQL`` driver for ``tokens``, otherwise the tokens will not be shared between the processes of the Apache HTTPD server. To do that, in ``/etc/keystone/keystone.conf`` make sure you have set::
+Make sure you use either the ``SQL`` or the ``memcached`` driver for ``tokens``, otherwise the tokens will not be shared between the processes of the Apache HTTPD server.
+
+For ``SQL``, in ``/etc/keystone/keystone.conf`` make sure you have set::
[token]
driver = keystone.token.backends.sql.Token
+
+For ``memcache``, in ``/etc/keystone/keystone.conf`` make sure you have set::
+
+ [token]
+ driver = keystone.token.backends.memcache.Token
+
+In both cases, all servers that are storing tokens need a shared backend. This means that they must either all
+point to the same database server or all point to a common memcached instance.
diff --git a/doc/source/developing.rst b/doc/source/developing.rst
index c14ef7ab..7029e1c8 100644
--- a/doc/source/developing.rst
+++ b/doc/source/developing.rst
@@ -71,6 +71,36 @@ place::
.. _`python-keystoneclient`: https://github.com/openstack/python-keystoneclient
+Database Schema Migrations
+--------------------------
+
+Keystone uses `SQLAlchemy-migrate
+<http://code.google.com/p/sqlalchemy-migrate/>`_ to migrate the SQL database
+between revisions. For core components, the migrations are kept in a central
+repository under keystone/common/sql/migrate_repo.
+
+Extensions should be created as directories under `keystone/contrib`. An
+extension that requires sql migrations should not change the common repository,
+but should instead have its own repository. This repository must be in the
+extension's directory in `keystone/contrib/<extension>/migrate_repo`. In
+addition it needs a subdirectory named `versions`. For example, if the
+extension name is `my_extension` then the directory structure would be
+`keystone/contrib/my_extension/migrate_repo/versions/`. For the migration
+to work, both the migrate_repo and versions subdirectories must have empty
+__init__.py files. SQLAlchemy-migrate will look for a configuration file in
+the migrate_repo named migrate.cfg. This conforms to a key/value ini file
+format. A sample config file with the minimal set of values is::
+
+ [db_settings]
+ repository_id=my_extension
+ version_table=migrate_version
+ required_dbs=[]
+
+The directory `keystone/contrib/example` contains a sample extension migration.
+
+Migrations for extensions must be run explicitly. To run a migration for a
+specific extension, run `keystone-manage --extension <name> db_sync`.
+
Initial Sample Data
-------------------
@@ -103,8 +133,8 @@ Test Structure
--------------
``./run_test.sh`` uses its python cohort (``run_tests.py``) to iterate
-through the ``tests`` directory, using Nosetest to collect the tests and
-invoke them using an OpenStack custom test running that displays the tests
+through the ``keystone/tests`` directory, using Nosetest to collect the tests
+and invoke them using an OpenStack custom test runner that displays the tests
as well as the time taken to run those tests.
Not all of the tests in the tests directory are strictly unit tests. Keystone
@@ -193,9 +223,9 @@ and set environment variables ``KEYSTONE_IDENTITY_BACKEND=ldap`` and
``KEYSTONE_CLEAR_LDAP=yes`` in your ``localrc`` file.
The unit tests can be run against a live server with
-``tests/_ldap_livetest.py``. The default password is ``test`` but if you have
+``keystone/tests/_ldap_livetest.py``. The default password is ``test`` but if you have
installed devstack with a different LDAP password, modify the file
-``tests/backend_liveldap.conf`` to reflect your password.
+``keystone/tests/backend_liveldap.conf`` to reflect your password.
Building the Documentation
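
The developing.rst text above spells out the layout an extension's migrate_repo needs. For orientation, a minimal sketch of a version script is shown below; the extension name, table, and columns are illustrative and not part of this change::

    import sqlalchemy as sql


    def upgrade(migrate_engine):
        # Bind to the engine sqlalchemy-migrate hands us and create the table.
        meta = sql.MetaData()
        meta.bind = migrate_engine
        thing_table = sql.Table(
            'my_extension_thing', meta,
            sql.Column('id', sql.String(64), primary_key=True),
            sql.Column('extra', sql.Text()))
        thing_table.create(migrate_engine, checkfirst=True)


    def downgrade(migrate_engine):
        # Drop the table again so the schema can be rolled back.
        meta = sql.MetaData()
        meta.bind = migrate_engine
        thing_table = sql.Table('my_extension_thing', meta, autoload=True)
        thing_table.drop(migrate_engine, checkfirst=True)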
diff --git a/etc/keystone.conf.sample b/etc/keystone.conf.sample
index a49a9a5e..90efe5f6 100644
--- a/etc/keystone.conf.sample
+++ b/etc/keystone.conf.sample
@@ -100,6 +100,9 @@
# exist to order to maintain support for your v2 clients.
# default_domain_id = default
+# Maximum supported length for user passwords; decrease to improve performance.
+# max_password_length = 4096
+
[credential]
# driver = keystone.credential.backends.sql.Credential
diff --git a/keystone/assignment/backends/ldap.py b/keystone/assignment/backends/ldap.py
index 9b273e40..45ce6432 100644
--- a/keystone/assignment/backends/ldap.py
+++ b/keystone/assignment/backends/ldap.py
@@ -23,11 +23,11 @@ from keystone import assignment
from keystone import clean
from keystone.common import dependency
from keystone.common import ldap as common_ldap
-from keystone.common import logging
from keystone.common import models
from keystone import config
from keystone import exception
from keystone.identity.backends import ldap as ldap_identity
+from keystone.openstack.common import log as logging
CONF = config.CONF
@@ -263,28 +263,19 @@ class ProjectApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
DEFAULT_OBJECTCLASS = 'groupOfNames'
DEFAULT_ID_ATTR = 'cn'
DEFAULT_MEMBER_ATTRIBUTE = 'member'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.ProjectNotFound
notfound_arg = 'project_id' # NOTE(yorik-sar): while options_name = tenant
options_name = 'tenant'
- attribute_mapping = {'name': 'ou',
- 'description': 'description',
- 'tenantId': 'cn',
- 'enabled': 'enabled',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'name': 'name',
+ 'description': 'desc',
+ 'enabled': 'enabled',
+ 'domain_id': 'domain_id'}
model = models.Project
def __init__(self, conf):
super(ProjectApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.tenant_name_attribute
- self.attribute_mapping['description'] = conf.ldap.tenant_desc_attribute
- self.attribute_mapping['enabled'] = conf.ldap.tenant_enabled_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.tenant_domain_id_attribute)
self.member_attribute = (getattr(conf.ldap, 'tenant_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'tenant_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def create(self, values):
self.affirm_unique(values)
@@ -381,21 +372,15 @@ class RoleApi(common_ldap.BaseLdap):
DEFAULT_STRUCTURAL_CLASSES = []
DEFAULT_OBJECTCLASS = 'organizationalRole'
DEFAULT_MEMBER_ATTRIBUTE = 'roleOccupant'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.RoleNotFound
options_name = 'role'
- attribute_mapping = {'name': 'ou',
- #'serviceId': 'service_id',
- }
+ attribute_options_names = {'name': 'name'}
model = models.Role
def __init__(self, conf):
super(RoleApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.role_name_attribute
self.member_attribute = (getattr(conf.ldap, 'role_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'role_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def get(self, id, filter=None):
model = super(RoleApi, self).get(id, filter)
diff --git a/keystone/assignment/core.py b/keystone/assignment/core.py
index 64edb3fa..d78d3485 100644
--- a/keystone/assignment/core.py
+++ b/keystone/assignment/core.py
@@ -17,10 +17,10 @@
"""Main entry point into the assignment service."""
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
@@ -178,9 +178,23 @@ class Manager(manager.Manager):
keystone.exception.UserNotFound
"""
- self.driver.add_role_to_user_and_project(user_id,
- tenant_id,
- config.CONF.member_role_id)
+ try:
+ self.driver.add_role_to_user_and_project(
+ user_id,
+ tenant_id,
+ config.CONF.member_role_id)
+ except exception.RoleNotFound:
+ LOG.info(_("Creating the default role %s "
+ "because it does not exist.") %
+ config.CONF.member_role_id)
+ role = {'id': CONF.member_role_id,
+ 'name': CONF.member_role_name}
+ self.driver.create_role(config.CONF.member_role_id, role)
+ #now that default role exists, the add should succeed
+ self.driver.add_role_to_user_and_project(
+ user_id,
+ tenant_id,
+ config.CONF.member_role_id)
def remove_user_from_project(self, tenant_id, user_id):
"""Remove user from a tenant
diff --git a/keystone/auth/controllers.py b/keystone/auth/controllers.py
index d1bd764f..e6557be5 100644
--- a/keystone/auth/controllers.py
+++ b/keystone/auth/controllers.py
@@ -17,12 +17,12 @@
from keystone.common import controller
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import wsgi
from keystone import config
from keystone import exception
from keystone import identity
from keystone.openstack.common import importutils
+from keystone.openstack.common import log as logging
from keystone import token
from keystone import trust
@@ -328,7 +328,7 @@ class Auth(controller.V3Controller):
def authenticate(self, context, auth_info, auth_context):
"""Authenticate user."""
- # user have been authenticated externally
+ # user has been authenticated externally
if 'REMOTE_USER' in context:
external = get_auth_method('external')
external.authenticate(context, auth_info, auth_context)
diff --git a/keystone/auth/core.py b/keystone/auth/core.py
index b7bdb7c6..26e7a470 100644
--- a/keystone/auth/core.py
+++ b/keystone/auth/core.py
@@ -35,46 +35,52 @@ class AuthMethodHandler(object):
by default. "method_names" is a list and "extras" is
a dictionary.
- If successful, plugin must set "user_id" in "auth_context".
- "method_name" is used to convey any additional authentication methods
- in case authentication is for re-scoping. For example,
- if the authentication is for re-scoping, plugin must append the
- previous method names into "method_names". Also, plugin may add
- any additional information into "extras". Anything in "extras"
- will be conveyed in the token's "extras" field. Here's an example of
- "auth_context" on successful authentication.
+ If successful, plugin must set ``user_id`` in ``auth_context``.
+ ``method_name`` is used to convey any additional authentication methods
+ in case authentication is for re-scoping. For example, if the
+ authentication is for re-scoping, plugin must append the previous
+ method names into ``method_names``. Also, plugin may add any additional
+ information into ``extras``. Anything in ``extras`` will be conveyed in
+ the token's ``extras`` attribute. Here's an example of ``auth_context``
+ on successful authentication::
- {"user_id": "abc123",
- "methods": ["password", "token"],
- "extras": {}}
+ {
+ "extras": {},
+ "methods": [
+ "password",
+ "token"
+ ],
+ "user_id": "abc123"
+ }
Plugins are invoked in the order in which they are specified in the
- "methods" attribute of the "identity" object.
- For example, with the following authentication request,
+ ``methods`` attribute of the ``identity`` object. For example,
+ ``custom-plugin`` is invoked before ``password``, which is invoked
+ before ``token`` in the following authentication request::
- {"auth": {
- "identity": {
- "methods": ["custom-plugin", "password", "token"],
- "token": {
- "id": "sdfafasdfsfasfasdfds"
- },
- "custom-plugin": {
- "custom-data": "sdfdfsfsfsdfsf"
- },
- "password": {
- "user": {
- "id": "s23sfad1",
- "password": "secrete"
+ {
+ "auth": {
+ "identity": {
+ "custom-plugin": {
+ "custom-data": "sdfdfsfsfsdfsf"
+ },
+ "methods": [
+ "custom-plugin",
+ "password",
+ "token"
+ ],
+ "password": {
+ "user": {
+ "id": "s23sfad1",
+ "password": "secrete"
+ }
+ },
+ "token": {
+ "id": "sdfafasdfsfasfasdfds"
+ }
}
}
}
- }}
-
- plugins will be invoked in this order:
-
- 1. custom-plugin
- 2. password
- 3. token
:returns: None if authentication is successful.
Authentication payload in the form of a dictionary for the
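
The docstring above defines the contract an auth plugin implements: on success it records ``user_id`` in ``auth_context``, and anything placed in ``extras`` rides along into the token. A minimal sketch of a hypothetical plugin honoring that contract follows; the plugin name, the stand-in credential check, and the exact parameter names are illustrative::

    from keystone import auth
    from keystone import exception


    class CustomPlugin(auth.AuthMethodHandler):

        def _lookup_user(self, custom_data):
            # Stand-in credential check for the sketch; a real plugin would
            # consult its own backend here.
            return custom_data and custom_data.get('user_id')

        def authenticate(self, context, auth_payload, auth_context):
            user_id = self._lookup_user(auth_payload.get('custom-data'))
            if not user_id:
                raise exception.Unauthorized()
            # Record who authenticated; any extra data added to
            # auth_context['extras'] would be carried into the token.
            auth_context['user_id'] = user_id
            # Returning None signals success (no follow-up payload needed).
            return None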
diff --git a/keystone/auth/plugins/password.py b/keystone/auth/plugins/password.py
index f3cfeba8..66c6d05b 100644
--- a/keystone/auth/plugins/password.py
+++ b/keystone/auth/plugins/password.py
@@ -15,9 +15,9 @@
# under the License.
from keystone import auth
-from keystone.common import logging
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
METHOD_NAME = 'password'
diff --git a/keystone/auth/plugins/token.py b/keystone/auth/plugins/token.py
index 720eccac..b82c0311 100644
--- a/keystone/auth/plugins/token.py
+++ b/keystone/auth/plugins/token.py
@@ -15,9 +15,9 @@
# under the License.
from keystone import auth
-from keystone.common import logging
from keystone.common import wsgi
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone import token
diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py
index 1dad5a80..d7b2123a 100644
--- a/keystone/catalog/backends/sql.py
+++ b/keystone/catalog/backends/sql.py
@@ -32,6 +32,7 @@ class Service(sql.ModelBase, sql.DictBase):
id = sql.Column(sql.String(64), primary_key=True)
type = sql.Column(sql.String(255))
extra = sql.Column(sql.JsonBlob())
+ endpoints = sql.relationship("Endpoint", backref="service")
class Endpoint(sql.ModelBase, sql.DictBase):
@@ -40,12 +41,12 @@ class Endpoint(sql.ModelBase, sql.DictBase):
'legacy_endpoint_id']
id = sql.Column(sql.String(64), primary_key=True)
legacy_endpoint_id = sql.Column(sql.String(64))
- interface = sql.Column(sql.String(8), primary_key=True)
- region = sql.Column('region', sql.String(255))
+ interface = sql.Column(sql.String(8), nullable=False)
+ region = sql.Column(sql.String(255))
service_id = sql.Column(sql.String(64),
sql.ForeignKey('service.id'),
nullable=False)
- url = sql.Column(sql.Text())
+ url = sql.Column(sql.Text(), nullable=False)
extra = sql.Column(sql.JsonBlob())
@@ -150,28 +151,26 @@ class Catalog(sql.Base, catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
+ session = self.get_session()
+ endpoints = (session.query(Endpoint).
+ options(sql.joinedload(Endpoint.service)).
+ all())
+
catalog = {}
- services = {}
- for endpoint in self.list_endpoints():
- # look up the service
- services.setdefault(
- endpoint['service_id'],
- self.get_service(endpoint['service_id']))
- service = services[endpoint['service_id']]
-
- # add the endpoint to the catalog if it's not already there
- catalog.setdefault(endpoint['region'], {})
- catalog[endpoint['region']].setdefault(
- service['type'], {
- 'id': endpoint['id'],
- 'name': service['name'],
- 'publicURL': '', # this may be overridden, but must exist
- })
-
- # add the interface's url
- url = core.format_url(endpoint.get('url'), d)
+
+ for endpoint in endpoints:
+ region = endpoint['region']
+ service_type = endpoint.service['type']
+ default_service = {
+ 'id': endpoint['id'],
+ 'name': endpoint.service['name'],
+ 'publicURL': ''
+ }
+ catalog.setdefault(region, {})
+ catalog[region].setdefault(service_type, default_service)
+ url = core.format_url(endpoint['url'], d)
interface_url = '%sURL' % endpoint['interface']
- catalog[endpoint['region']][service['type']][interface_url] = url
+ catalog[region][service_type][interface_url] = url
return catalog
@@ -180,27 +179,19 @@ class Catalog(sql.Base, catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
- services = {}
- for endpoint in self.list_endpoints():
- # look up the service
- service_id = endpoint['service_id']
- services.setdefault(
- service_id,
- self.get_service(service_id))
- service = services[service_id]
+ session = self.get_session()
+ services = (session.query(Service).
+ options(sql.joinedload(Service.endpoints)).
+ all())
+
+ def make_v3_endpoint(endpoint):
del endpoint['service_id']
endpoint['url'] = core.format_url(endpoint['url'], d)
- if 'endpoints' in services[service_id]:
- services[service_id]['endpoints'].append(endpoint)
- else:
- services[service_id]['endpoints'] = [endpoint]
-
- catalog = []
- for service_id, service in services.iteritems():
- formatted_service = {}
- formatted_service['id'] = service['id']
- formatted_service['type'] = service['type']
- formatted_service['endpoints'] = service['endpoints']
- catalog.append(formatted_service)
+ return endpoint
+
+ catalog = [{'endpoints': [make_v3_endpoint(ep.to_dict())
+ for ep in svc.endpoints],
+ 'id': svc.id,
+ 'type': svc.type} for svc in services]
return catalog
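
For reference, the v2 catalog built by the rewritten get_catalog() above keeps the same nested shape as before, just assembled from a single joined query. With illustrative values (one ``<interface>URL`` key is added per endpoint row) it looks roughly like::

    {
        'RegionOne': {
            'identity': {
                'id': 'endpoint-id',
                'name': 'keystone',
                'publicURL': 'http://host:5000/v2.0',
                'internalURL': 'http://host:5000/v2.0',
                'adminURL': 'http://host:35357/v2.0',
            }
        }
    }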
diff --git a/keystone/catalog/backends/templated.py b/keystone/catalog/backends/templated.py
index a96902d3..7fe73e91 100644
--- a/keystone/catalog/backends/templated.py
+++ b/keystone/catalog/backends/templated.py
@@ -18,8 +18,8 @@ import os.path
from keystone.catalog.backends import kvs
from keystone.catalog import core
-from keystone.common import logging
from keystone import config
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
diff --git a/keystone/catalog/core.py b/keystone/catalog/core.py
index b8a081ac..61b7e8ac 100644
--- a/keystone/catalog/core.py
+++ b/keystone/catalog/core.py
@@ -18,10 +18,10 @@
"""Main entry point into the Catalog service."""
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/keystone/clean.py b/keystone/clean.py
index c1d01ec8..7684210a 100644
--- a/keystone/clean.py
+++ b/keystone/clean.py
@@ -44,10 +44,11 @@ def check_enabled(property_name, enabled):
return bool(enabled)
-def check_name(property_name, name):
+def check_name(property_name, name, min_length=1, max_length=64):
check_type('%s name' % property_name, name, basestring, 'str or unicode')
name = name.strip()
- check_length('%s name' % property_name, name)
+ check_length('%s name' % property_name, name,
+ min_length=min_length, max_length=max_length)
return name
@@ -64,7 +65,7 @@ def project_enabled(enabled):
def user_name(name):
- return check_name('User', name)
+ return check_name('User', name, max_length=255)
def user_enabled(enabled):
diff --git a/keystone/cli.py b/keystone/cli.py
index 21d2ad40..18c095ce 100644
--- a/keystone/cli.py
+++ b/keystone/cli.py
@@ -20,12 +20,15 @@ import grp
import os
import pwd
+from migrate import exceptions
+
from oslo.config import cfg
import pbr.version
from keystone.common import openssl
from keystone.common.sql import migration
from keystone import config
+from keystone import contrib
from keystone.openstack.common import importutils
from keystone.openstack.common import jsonutils
from keystone import token
@@ -57,14 +60,35 @@ class DbSync(BaseApp):
'version. If not provided, db_sync will '
'migrate the database to the latest known '
'version.'))
+ parser.add_argument('--extension', default=None,
+ help=('Migrate the database for the specified '
+ 'extension. If not provided, db_sync will '
+ 'migrate the common repository.'))
+
return parser
@staticmethod
def main():
- for k in ['identity', 'catalog', 'policy', 'token', 'credential']:
- driver = importutils.import_object(getattr(CONF, k).driver)
- if hasattr(driver, 'db_sync'):
- driver.db_sync(CONF.command.version)
+ version = CONF.command.version
+ extension = CONF.command.extension
+ if not extension:
+ migration.db_sync(version=version)
+ else:
+ package_name = "%s.%s.migrate_repo" % (contrib.__name__, extension)
+ try:
+ package = importutils.import_module(package_name)
+ repo_path = os.path.abspath(os.path.dirname(package.__file__))
+ except ImportError:
+ print _("This extension does not provide migrations.")
+ exit(0)
+ try:
+ # Register the repo with the version control API
+ # If it already knows about the repo, it will throw
+ # an exception that we can safely ignore
+ migration.db_version_control(version=None, repo_path=repo_path)
+ except exceptions.DatabaseAlreadyControlledError:
+ pass
+ migration.db_sync(version=None, repo_path=repo_path)
class DbVersion(BaseApp):
@@ -72,9 +96,29 @@ class DbVersion(BaseApp):
name = 'db_version'
+ @classmethod
+ def add_argument_parser(cls, subparsers):
+ parser = super(DbVersion, cls).add_argument_parser(subparsers)
+ parser.add_argument('--extension', default=None,
+ help=('Migrate the database for the specified '
+ 'extension. If not provided, db_sync will '
+ 'migrate the common repository.'))
+
@staticmethod
def main():
- print(migration.db_version())
+ extension = CONF.command.extension
+ if extension:
+ try:
+ package_name = ("%s.%s.migrate_repo" %
+ (contrib.__name__, extension))
+ package = importutils.import_module(package_name)
+ repo_path = os.path.abspath(os.path.dirname(package.__file__))
+ print(migration.db_version(repo_path))
+ except ImportError:
+ print _("This extension does not provide migrations.")
+ exit(1)
+ else:
+ print(migration.db_version())
class BaseCertificateSetup(BaseApp):
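
With the ``--extension`` argument wired up above, the migration commands can be driven like this (shown for the bundled ``example`` extension; the common repository is migrated when no extension is given)::

    keystone-manage db_sync
    keystone-manage --extension example db_sync
    keystone-manage --extension example db_version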
diff --git a/keystone/common/cms.py b/keystone/common/cms.py
index 6ec740f8..09a98cdc 100644
--- a/keystone/common/cms.py
+++ b/keystone/common/cms.py
@@ -1,7 +1,7 @@
import hashlib
from keystone.common import environment
-from keystone.common import logging
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
diff --git a/keystone/common/config.py b/keystone/common/config.py
index 10c47a35..5a961d4a 100644
--- a/keystone/common/config.py
+++ b/keystone/common/config.py
@@ -14,110 +14,30 @@
# License for the specific language governing permissions and limitations
# under the License.
-import os
-import sys
-
from oslo.config import cfg
-from keystone.common import logging
+from keystone.openstack.common import log as logging
_DEFAULT_LOG_FORMAT = "%(asctime)s %(levelname)8s [%(name)s] %(message)s"
_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
_DEFAULT_AUTH_METHODS = ['external', 'password', 'token']
-COMMON_CLI_OPTS = [
- cfg.BoolOpt('debug',
- short='d',
- default=False,
- help='Print debugging output (set logging level to '
- 'DEBUG instead of default WARNING level).'),
- cfg.BoolOpt('verbose',
- short='v',
- default=False,
- help='Print more verbose output (set logging level to '
- 'INFO instead of default WARNING level).'),
-]
-
-LOGGING_CLI_OPTS = [
- cfg.StrOpt('log-config',
- metavar='PATH',
- help='If this option is specified, the logging configuration '
- 'file specified is used and overrides any other logging '
- 'options specified. Please see the Python logging module '
- 'documentation for details on logging configuration '
- 'files.'),
- cfg.StrOpt('log-format',
- default=_DEFAULT_LOG_FORMAT,
- metavar='FORMAT',
- help='A logging.Formatter log message format string which may '
- 'use any of the available logging.LogRecord attributes.'),
- cfg.StrOpt('log-date-format',
- default=_DEFAULT_LOG_DATE_FORMAT,
- metavar='DATE_FORMAT',
- help='Format string for %%(asctime)s in log records.'),
- cfg.StrOpt('log-file',
- metavar='PATH',
- help='Name of log file to output. '
- 'If not set, logging will go to stdout.'),
- cfg.StrOpt('log-dir',
- help='The directory in which to store log files. '
- '(will be prepended to --log-file)'),
- cfg.BoolOpt('use-syslog',
- default=False,
- help='Use syslog for logging.'),
- cfg.StrOpt('syslog-log-facility',
- default='LOG_USER',
- help='syslog facility to receive log lines.')
-]
CONF = cfg.CONF
-def setup_logging(conf):
+def setup_logging(conf, product_name='keystone'):
"""Sets up the logging options for a log with supplied name
:param conf: a cfg.ConfOpts object
"""
-
- if conf.log_config:
- # Use a logging configuration file for all settings...
- if os.path.exists(conf.log_config):
- logging.config.fileConfig(conf.log_config)
- return
- else:
- raise RuntimeError(_('Unable to locate specified logging '
- 'config file: %s') % conf.log_config)
-
- root_logger = logging.root
- if conf.debug:
- root_logger.setLevel(logging.DEBUG)
- elif conf.verbose:
- root_logger.setLevel(logging.INFO)
- else:
- root_logger.setLevel(logging.WARNING)
-
- formatter = logging.Formatter(conf.log_format, conf.log_date_format)
-
- if conf.use_syslog:
- try:
- facility = getattr(logging.SysLogHandler,
- conf.syslog_log_facility)
- except AttributeError:
- raise ValueError(_('Invalid syslog facility'))
-
- handler = logging.SysLogHandler(address='/dev/log',
- facility=facility)
- elif conf.log_file:
- logfile = conf.log_file
- if conf.log_dir:
- logfile = os.path.join(conf.log_dir, logfile)
- handler = logging.WatchedFileHandler(logfile)
- else:
- handler = logging.StreamHandler(sys.stdout)
-
- handler.setFormatter(formatter)
- root_logger.addHandler(handler)
+ # NOTE(ldbragst): This method will be removed along with other
+ # refactoring in favor of using the
+ # keystone/openstack/common/log.py implementation. This just ensures
+ # that in the time between introduction and refactoring, we still have
+ # a working logging implementation.
+ logging.setup(product_name)
def setup_authentication():
@@ -176,9 +96,6 @@ def register_cli_int(*args, **kw):
def configure():
- CONF.register_cli_opts(COMMON_CLI_OPTS)
- CONF.register_cli_opts(LOGGING_CLI_OPTS)
-
register_cli_bool('standard-threads', default=False,
help='Do not monkey-patch threading system modules.')
@@ -210,6 +127,7 @@ def configure():
# identity
register_str('default_domain_id', group='identity', default='default')
+ register_int('max_password_length', group='identity', default=4096)
# trust
register_bool('enabled', group='trust', default=True)
diff --git a/keystone/common/controller.py b/keystone/common/controller.py
index affc34de..1bf65cda 100644
--- a/keystone/common/controller.py
+++ b/keystone/common/controller.py
@@ -3,11 +3,10 @@ import functools
import uuid
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import wsgi
from keystone import config
from keystone import exception
-
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
CONF = config.CONF
@@ -169,6 +168,10 @@ class V2Controller(wsgi.Application):
self._delete_tokens_for_trust(trust['trustee_user_id'],
trust['id'])
+ def _delete_tokens_for_project(self, project_id):
+ for user_ref in self.identity_api.get_project_users(project_id):
+ self._delete_tokens_for_user(user_ref['id'], project_id=project_id)
+
def _require_attribute(self, ref, attr):
"""Ensures the reference contains the specified attribute."""
if ref.get(attr) is None or ref.get(attr) == '':
diff --git a/keystone/common/environment/__init__.py b/keystone/common/environment/__init__.py
index 2993536a..7ec82002 100644
--- a/keystone/common/environment/__init__.py
+++ b/keystone/common/environment/__init__.py
@@ -2,7 +2,7 @@ import functools
import os
from keystone.common import config
-from keystone.common import logging
+from keystone.openstack.common import log as logging
CONF = config.CONF
LOG = logging.getLogger(__name__)
diff --git a/keystone/common/environment/eventlet_server.py b/keystone/common/environment/eventlet_server.py
index fae0884e..874c4831 100644
--- a/keystone/common/environment/eventlet_server.py
+++ b/keystone/common/environment/eventlet_server.py
@@ -26,8 +26,7 @@ import eventlet
import eventlet.wsgi
import greenlet
-from keystone.common import logging
-from keystone.common import wsgi
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
@@ -108,7 +107,7 @@ class Server(object):
log = logging.getLogger('eventlet.wsgi.server')
try:
eventlet.wsgi.server(socket, application, custom_pool=self.pool,
- log=wsgi.WritableLogger(log))
+ log=logging.WritableLogger(log))
except Exception:
LOG.exception(_('Server error'))
raise
diff --git a/keystone/common/ldap/core.py b/keystone/common/ldap/core.py
index 7a2dfee7..48e4121f 100644
--- a/keystone/common/ldap/core.py
+++ b/keystone/common/ldap/core.py
@@ -20,9 +20,8 @@ import ldap
from ldap import filter as ldap_filter
from keystone.common.ldap import fakeldap
-from keystone.common import logging
from keystone import exception
-
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
@@ -114,7 +113,7 @@ class BaseLdap(object):
notfound_arg = None
options_name = None
model = None
- attribute_mapping = {}
+ attribute_options_names = {}
attribute_ignore = []
tree_dn = None
@@ -129,6 +128,7 @@ class BaseLdap(object):
self.tls_cacertfile = conf.ldap.tls_cacertfile
self.tls_cacertdir = conf.ldap.tls_cacertdir
self.tls_req_cert = parse_tls_cert(conf.ldap.tls_req_cert)
+ self.attribute_mapping = {}
if self.options_name is not None:
self.suffix = conf.ldap.suffix
@@ -145,6 +145,10 @@ class BaseLdap(object):
self.object_class = (getattr(conf.ldap, objclass)
or self.DEFAULT_OBJECTCLASS)
+ for k, v in self.attribute_options_names.iteritems():
+ v = '%s_%s_attribute' % (self.options_name, v)
+ self.attribute_mapping[k] = getattr(conf.ldap, v)
+
attr_mapping_opt = ('%s_additional_attribute_mapping' %
self.options_name)
attr_mapping = (getattr(conf.ldap, attr_mapping_opt)
@@ -167,6 +171,10 @@ class BaseLdap(object):
if self.notfound_arg is None:
self.notfound_arg = self.options_name + '_id'
+
+ attribute_ignore = '%s_attribute_ignore' % self.options_name
+ self.attribute_ignore = getattr(conf.ldap, attribute_ignore)
+
self.use_dumb_member = getattr(conf.ldap, 'use_dumb_member')
self.dumb_member = (getattr(conf.ldap, 'dumb_member') or
self.DUMB_MEMBER_DN)
@@ -500,7 +508,7 @@ class LdapWrapper(object):
def add_s(self, dn, attrs):
ldap_attrs = [(kind, [py2ldap(x) for x in safe_iter(values)])
for kind, values in attrs]
- if LOG.isEnabledFor(logging.DEBUG):
+ if LOG.isEnabledFor(LOG.debug):
sane_attrs = [(kind, values
if kind != 'userPassword'
else ['****'])
@@ -510,7 +518,7 @@ class LdapWrapper(object):
return self.conn.add_s(dn, ldap_attrs)
def search_s(self, dn, scope, query, attrlist=None):
- if LOG.isEnabledFor(logging.DEBUG):
+ if LOG.isEnabledFor(LOG.debug):
LOG.debug(_(
'LDAP search: dn=%(dn)s, scope=%(scope)s, query=%(query)s, '
'attrs=%(attrlist)s') % {
@@ -577,7 +585,7 @@ class LdapWrapper(object):
else [py2ldap(x) for x in safe_iter(values)]))
for op, kind, values in modlist]
- if LOG.isEnabledFor(logging.DEBUG):
+ if LOG.isEnabledFor(LOG.debug):
sane_modlist = [(op, kind, (values if kind != 'userPassword'
else ['****']))
for op, kind, values in ldap_modlist]
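
To make the new mapping mechanism above concrete, here is a standalone sketch of how ``attribute_options_names`` is expanded into ``attribute_mapping`` in ``BaseLdap.__init__``; the config values are illustrative, and for ``ProjectApi`` the ``options_name`` is ``tenant``::

    attribute_options_names = {'name': 'name', 'description': 'desc'}
    options_name = 'tenant'
    # sample conf.ldap values
    conf_ldap = {'tenant_name_attribute': 'ou',
                 'tenant_desc_attribute': 'description'}

    attribute_mapping = {}
    for k, v in attribute_options_names.items():
        # e.g. 'name' -> 'tenant_name_attribute'
        opt = '%s_%s_attribute' % (options_name, v)
        attribute_mapping[k] = conf_ldap[opt]
    # attribute_mapping == {'name': 'ou', 'description': 'description'}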
diff --git a/keystone/common/ldap/fakeldap.py b/keystone/common/ldap/fakeldap.py
index f6c95895..c19e1355 100644
--- a/keystone/common/ldap/fakeldap.py
+++ b/keystone/common/ldap/fakeldap.py
@@ -29,8 +29,8 @@ import shelve
import ldap
-from keystone.common import logging
from keystone.common import utils
+from keystone.openstack.common import log as logging
SCOPE_NAMES = {
@@ -41,8 +41,6 @@ SCOPE_NAMES = {
LOG = logging.getLogger(__name__)
-#Only enable a lower level than WARN if you are actively debugging
-LOG.level = logging.WARN
def _match_query(query, attrs):
diff --git a/keystone/common/openssl.py b/keystone/common/openssl.py
index fa09e37c..280815ae 100644
--- a/keystone/common/openssl.py
+++ b/keystone/common/openssl.py
@@ -19,9 +19,8 @@ import os
import stat
from keystone.common import environment
-from keystone.common import logging
from keystone import config
-
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
CONF = config.CONF
@@ -51,6 +50,7 @@ class BaseCertificateConfigure(object):
self.request_file_name = os.path.join(self.conf_dir, "req.pem")
self.ssl_dictionary = {'conf_dir': self.conf_dir,
'ca_cert': conf_obj.ca_certs,
+ 'default_md': 'default',
'ssl_config': self.ssl_config_file_name,
'ca_private_key': conf_obj.ca_key,
'request_file': self.request_file_name,
@@ -60,6 +60,17 @@ class BaseCertificateConfigure(object):
'valid_days': int(conf_obj.valid_days),
'cert_subject': conf_obj.cert_subject,
'ca_password': conf_obj.ca_password}
+
+ try:
+            # OpenSSL 1.0 and newer support default_md = default, older versions do not

+ openssl_ver = environment.subprocess.Popen(
+ ['openssl', 'version'],
+ stdout=environment.subprocess.PIPE).stdout.read()
+ if "OpenSSL 0." in openssl_ver:
+ self.ssl_dictionary['default_md'] = 'sha1'
+ except OSError:
+ LOG.warn('Failed to invoke ``openssl version``, '
+                     'assuming it is v1.0 or newer')
self.ssl_dictionary.update(kwargs)
def _make_dirs(self, file_name):
@@ -198,7 +209,7 @@ new_certs_dir = $dir
serial = $dir/serial
database = $dir/index.txt
default_days = 365
-default_md = default # use public key default MD
+default_md = %(default_md)s
preserve = no
email_in_dn = no
nameopt = default_ca
@@ -218,35 +229,35 @@ emailAddress = optional
[ req ]
default_bits = 2048 # Size of keys
default_keyfile = key.pem # name of generated keys
-default_md = default # message digest algorithm
-string_mask = nombstr # permitted characters
+string_mask = utf8only # permitted characters
distinguished_name = req_distinguished_name
req_extensions = v3_req
+x509_extensions = v3_ca
[ req_distinguished_name ]
-0.organizationName = Organization Name (company)
-organizationalUnitName = Organizational Unit Name (department, division)
-emailAddress = Email Address
-emailAddress_max = 40
-localityName = Locality Name (city, district)
-stateOrProvinceName = State or Province Name (full name)
countryName = Country Name (2 letter code)
countryName_min = 2
countryName_max = 2
+stateOrProvinceName = State or Province Name (full name)
+localityName = Locality Name (city, district)
+0.organizationName = Organization Name (company)
+organizationalUnitName = Organizational Unit Name (department, division)
commonName = Common Name (hostname, IP, or your name)
commonName_max = 64
+emailAddress = Email Address
+emailAddress_max = 64
[ v3_ca ]
basicConstraints = CA:TRUE
subjectKeyIdentifier = hash
-authorityKeyIdentifier = keyid:always,issuer:always
+authorityKeyIdentifier = keyid:always,issuer
[ v3_req ]
basicConstraints = CA:FALSE
-subjectKeyIdentifier = hash
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
[ usr_cert ]
basicConstraints = CA:FALSE
subjectKeyIdentifier = hash
-authorityKeyIdentifier = keyid:always,issuer:always
+authorityKeyIdentifier = keyid:always
"""
diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py
index 2d3114f2..fdb45c74 100644
--- a/keystone/common/sql/core.py
+++ b/keystone/common/sql/core.py
@@ -26,10 +26,10 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
import sqlalchemy.pool
from sqlalchemy import types as sql_types
-from keystone.common import logging
from keystone import config
from keystone import exception
from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
@@ -45,6 +45,7 @@ ModelBase = declarative.declarative_base()
# For exporting to other modules
Column = sql.Column
+Index = sql.Index
String = sql.String
ForeignKey = sql.ForeignKey
DateTime = sql.DateTime
@@ -54,6 +55,8 @@ NotFound = sql.orm.exc.NoResultFound
Boolean = sql.Boolean
Text = sql.Text
UniqueConstraint = sql.UniqueConstraint
+relationship = sql.orm.relationship
+joinedload = sql.orm.joinedload
def initialize_decorator(init):
@@ -179,6 +182,8 @@ class DictBase(object):
setattr(self, key, value)
def __getitem__(self, key):
+ if key in self.extra:
+ return self.extra[key]
return getattr(self, key)
def get(self, key, default=None):
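
A standalone sketch (not Keystone code) of the lookup order the new ``__getitem__`` gives model rows: keys stored in the ``extra`` JSON blob are served from there, and everything else falls through to the mapped column attributes::

    class DictRow(object):
        def __init__(self, name, extra):
            self.name = name
            self.extra = extra

        def __getitem__(self, key):
            if key in self.extra:
                return self.extra[key]
            return getattr(self, key)

    row = DictRow(name='keystone', extra={'description': 'kept in the blob'})
    row['name']         # -> 'keystone' (mapped column)
    row['description']  # -> 'kept in the blob' (from extra)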
diff --git a/keystone/common/sql/legacy.py b/keystone/common/sql/legacy.py
index c8adc900..d88e5a46 100644
--- a/keystone/common/sql/legacy.py
+++ b/keystone/common/sql/legacy.py
@@ -21,10 +21,10 @@ from sqlalchemy import exc
from keystone.assignment.backends import sql as assignment_sql
-from keystone.common import logging
from keystone import config
from keystone.contrib.ec2.backends import sql as ec2_sql
from keystone.identity.backends import sql as identity_sql
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
diff --git a/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py
new file mode 100644
index 00000000..89ca04f0
--- /dev/null
+++ b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py
@@ -0,0 +1,40 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy
+
+
+def upgrade(migrate_engine):
+ #This migration is relevant only for mysql because for all other
+ #migrate engines these indexes were successfully dropped.
+ if migrate_engine.name != 'mysql':
+ return
+ meta = sqlalchemy.MetaData(bind=migrate_engine)
+ table = sqlalchemy.Table('credential', meta, autoload=True)
+ for index in table.indexes:
+ index.drop()
+
+
+def downgrade(migrate_engine):
+ if migrate_engine.name != 'mysql':
+ return
+ meta = sqlalchemy.MetaData(bind=migrate_engine)
+ table = sqlalchemy.Table('credential', meta, autoload=True)
+ index = sqlalchemy.Index('user_id', table.c['user_id'])
+ index.create()
+ index = sqlalchemy.Index('credential_project_id_fkey',
+ table.c['project_id'])
+ index.create()
diff --git a/keystone/common/sql/migrate_repo/versions/032_username_length.py b/keystone/common/sql/migrate_repo/versions/032_username_length.py
new file mode 100644
index 00000000..636ebd75
--- /dev/null
+++ b/keystone/common/sql/migrate_repo/versions/032_username_length.py
@@ -0,0 +1,31 @@
+import sqlalchemy as sql
+from sqlalchemy.orm import sessionmaker
+
+
+def upgrade(migrate_engine):
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+ user_table = sql.Table('user', meta, autoload=True)
+ user_table.c.name.alter(type=sql.String(255))
+
+
+def downgrade(migrate_engine):
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+ user_table = sql.Table('user', meta, autoload=True)
+ if migrate_engine.name != 'mysql':
+ # NOTE(aloga): sqlite does not enforce length on the
+ # VARCHAR types: http://www.sqlite.org/faq.html#q9
+ # postgresql and DB2 do not truncate.
+ maker = sessionmaker(bind=migrate_engine)
+ session = maker()
+ for user in session.query(user_table).all():
+ values = {'name': user.name[:64]}
+ update = (user_table.update().
+ where(user_table.c.id == user.id).
+ values(values))
+ migrate_engine.execute(update)
+
+ session.commit()
+ session.close()
+ user_table.c.name.alter(type=sql.String(64))
diff --git a/keystone/common/sql/migration.py b/keystone/common/sql/migration.py
index 86e0254c..3cb9cd63 100644
--- a/keystone/common/sql/migration.py
+++ b/keystone/common/sql/migration.py
@@ -39,39 +39,51 @@ except ImportError:
sys.exit('python-migrate is not installed. Exiting.')
-def db_sync(version=None):
+def migrate_repository(version, current_version, repo_path):
+ if version is None or version > current_version:
+ result = versioning_api.upgrade(CONF.sql.connection,
+ repo_path, version)
+ else:
+ result = versioning_api.downgrade(
+ CONF.sql.connection, repo_path, version)
+ return result
+
+
+def db_sync(version=None, repo_path=None):
if version is not None:
try:
version = int(version)
except ValueError:
raise Exception(_('version should be an integer'))
+ if repo_path is None:
+ repo_path = find_migrate_repo()
+ current_version = db_version(repo_path=repo_path)
+ return migrate_repository(version, current_version, repo_path)
- current_version = db_version()
- repo_path = _find_migrate_repo()
- if version is None or version > current_version:
- return versioning_api.upgrade(CONF.sql.connection, repo_path, version)
- else:
- return versioning_api.downgrade(
- CONF.sql.connection, repo_path, version)
-
-def db_version():
- repo_path = _find_migrate_repo()
+def db_version(repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
try:
return versioning_api.db_version(CONF.sql.connection, repo_path)
except versioning_exceptions.DatabaseNotControlledError:
return db_version_control(0)
-def db_version_control(version=None):
- repo_path = _find_migrate_repo()
+def db_version_control(version=None, repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
versioning_api.version_control(CONF.sql.connection, repo_path, version)
return version
-def _find_migrate_repo():
+def find_migrate_repo(package=None):
"""Get the path for the migrate repository."""
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
+ if package is None:
+ file = __file__
+ else:
+ file = package.__file__
+ path = os.path.join(os.path.abspath(os.path.dirname(file)),
'migrate_repo')
assert os.path.exists(path)
return path
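
A sketch of driving the new ``repo_path`` parameters from Python, mirroring what ``keystone-manage --extension example db_sync`` does (the ``example`` extension is the one added by this change)::

    from migrate import exceptions

    from keystone.common.sql import migration
    from keystone.contrib import example

    # find_migrate_repo() appends 'migrate_repo' to the package's directory.
    repo_path = migration.find_migrate_repo(package=example)
    try:
        # Put the extension repo under version control; ignore the error if
        # it already is.
        migration.db_version_control(version=None, repo_path=repo_path)
    except exceptions.DatabaseAlreadyControlledError:
        pass
    migration.db_sync(version=None, repo_path=repo_path)
    print(migration.db_version(repo_path=repo_path))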
diff --git a/keystone/common/sql/nova.py b/keystone/common/sql/nova.py
index fd8d2481..c7abfb81 100644
--- a/keystone/common/sql/nova.py
+++ b/keystone/common/sql/nova.py
@@ -19,10 +19,10 @@
import uuid
from keystone import assignment
-from keystone.common import logging
from keystone import config
from keystone.contrib.ec2.backends import sql as ec2_sql
from keystone import identity
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
diff --git a/keystone/common/utils.py b/keystone/common/utils.py
index fd2d7567..4abad57a 100644
--- a/keystone/common/utils.py
+++ b/keystone/common/utils.py
@@ -27,8 +27,8 @@ import passlib.hash
from keystone.common import config
from keystone.common import environment
-from keystone.common import logging
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
@@ -36,8 +36,6 @@ config.register_int('crypt_strength', default=40000)
LOG = logging.getLogger(__name__)
-MAX_PASSWORD_LENGTH = 4096
-
def read_cached_file(filename, cache_info, reload_func=None):
"""Read from a file if it has been modified.
@@ -68,12 +66,13 @@ class SmarterEncoder(json.JSONEncoder):
def trunc_password(password):
- """Truncate passwords to the MAX_PASSWORD_LENGTH."""
+ """Truncate passwords to the max_length."""
+ max_length = CONF.identity.max_password_length
try:
- if len(password) > MAX_PASSWORD_LENGTH:
- return password[:MAX_PASSWORD_LENGTH]
- else:
- return password
+ if len(password) > max_length:
+ LOG.warning(
+ _('Truncating user password to %s characters.') % max_length)
+ return password[:max_length]
except TypeError:
raise exception.ValidationError(attribute='string', target='password')
diff --git a/keystone/common/wsgi.py b/keystone/common/wsgi.py
index f47cde13..d515fde6 100644
--- a/keystone/common/wsgi.py
+++ b/keystone/common/wsgi.py
@@ -27,11 +27,11 @@ import webob.dec
import webob.exc
from keystone.common import config
-from keystone.common import logging
from keystone.common import utils
from keystone import exception
from keystone.openstack.common import importutils
from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
CONF = config.CONF
@@ -122,17 +122,6 @@ def validate_token_bind(context, token_ref):
raise exception.Unauthorized()
-class WritableLogger(object):
- """A thin wrapper that responds to `write` and logs."""
-
- def __init__(self, logger, level=logging.DEBUG):
- self.logger = logger
- self.level = level
-
- def write(self, msg):
- self.logger.log(self.level, msg)
-
-
class Request(webob.Request):
pass
@@ -394,7 +383,7 @@ class Debug(Middleware):
@webob.dec.wsgify(RequestClass=Request)
def __call__(self, req):
- if LOG.isEnabledFor(logging.DEBUG):
+ if LOG.isEnabledFor(LOG.debug):
LOG.debug('%s %s %s', ('*' * 20), 'REQUEST ENVIRON', ('*' * 20))
for key, value in req.environ.items():
LOG.debug('%s = %s', key, mask_password(value,
@@ -406,7 +395,7 @@ class Debug(Middleware):
LOG.debug('')
resp = req.get_response(self.application)
- if LOG.isEnabledFor(logging.DEBUG):
+ if LOG.isEnabledFor(LOG.debug):
LOG.debug('%s %s %s', ('*' * 20), 'RESPONSE HEADERS', ('*' * 20))
for (key, value) in resp.headers.iteritems():
LOG.debug('%s = %s', key, value)
@@ -455,7 +444,7 @@ class Router(object):
# if we're only running in debug, bump routes' internal logging up a
# notch, as it's very spammy
if CONF.debug:
- logging.getLogger('routes.middleware').setLevel(logging.INFO)
+ logging.getLogger('routes.middleware')
self.map = mapper
self._router = routes.middleware.RoutesMiddleware(self._dispatch,
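WritableLogger disappears from wsgi.py because an equivalent class is carried by
the new keystone.openstack.common.log module added later in this change (its
default level becomes INFO rather than DEBUG). Code that hands eventlet a
file-like log object would only need to switch imports; a hedged sketch:

    from keystone.openstack.common import log as logging

    LOG = logging.getLogger(__name__)
    # eventlet.wsgi.server(sock, application, log=logging.WritableLogger(LOG))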
diff --git a/keystone/contrib/access/core.py b/keystone/contrib/access/core.py
index f0221200..fbe09a24 100644
--- a/keystone/contrib/access/core.py
+++ b/keystone/contrib/access/core.py
@@ -14,12 +14,11 @@
# License for the specific language governing permissions and limitations
# under the License.
-import webob
import webob.dec
-from keystone.common import logging
from keystone.common import wsgi
from keystone import config
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
diff --git a/tests/tmp/.gitkeep b/keystone/contrib/example/__init__.py
index e69de29b..e69de29b 100644
--- a/tests/tmp/.gitkeep
+++ b/keystone/contrib/example/__init__.py
diff --git a/keystone/contrib/example/migrate_repo/__init__.py b/keystone/contrib/example/migrate_repo/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/__init__.py
diff --git a/keystone/contrib/example/migrate_repo/migrate.cfg b/keystone/contrib/example/migrate_repo/migrate.cfg
new file mode 100644
index 00000000..5b1b1c0a
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/migrate.cfg
@@ -0,0 +1,25 @@
+[db_settings]
+# Used to identify which repository this database is versioned under.
+# You can use the name of your project.
+repository_id=example
+
+# The name of the database table used to track the schema version.
+# This name shouldn't already be used by your project.
+# If this is changed once a database is under version control, you'll need to
+# change the table name in each database too.
+version_table=migrate_version
+
+# When committing a change script, Migrate will attempt to generate the
+# sql for all supported databases; normally, if one of them fails - probably
+# because you don't have that database installed - it is ignored and the
+# commit continues, perhaps ending successfully.
+# Databases in this list MUST compile successfully during a commit, or the
+# entire commit will fail. List the databases your application will actually
+# be using to ensure your updates to that database work properly.
+# This must be a list; example: ['postgres','sqlite']
+required_dbs=[]
+
+# When creating new change scripts, Migrate will stamp the new script with
+# a version number. By default this is latest_version + 1. You can set this
+# to 'true' to tell Migrate to use the UTC timestamp instead.
+use_timestamp_numbering=False
diff --git a/keystone/contrib/example/migrate_repo/versions/001_example_table.py b/keystone/contrib/example/migrate_repo/versions/001_example_table.py
new file mode 100644
index 00000000..bb2203d3
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/versions/001_example_table.py
@@ -0,0 +1,45 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy as sql
+
+
+def upgrade(migrate_engine):
+ # Upgrade operations go here. Don't create your own engine; bind
+ # migrate_engine to your metadata
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+
+ # catalog
+
+ service_table = sql.Table(
+ 'example',
+ meta,
+ sql.Column('id', sql.String(64), primary_key=True),
+ sql.Column('type', sql.String(255)),
+ sql.Column('extra', sql.Text()))
+ service_table.create(migrate_engine, checkfirst=True)
+
+
+def downgrade(migrate_engine):
+ # Operations to reverse the above upgrade go here.
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+
+ tables = ['example']
+ for t in tables:
+ table = sql.Table(t, meta, autoload=True)
+ table.drop(migrate_engine, checkfirst=True)
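The example repository can be exercised through sqlalchemy-migrate's Python API;
a minimal sketch (the database URL below is a placeholder, not part of this patch):

    from migrate.versioning import api as versioning_api

    from keystone.common.sql import migration
    from keystone.contrib import example

    repo = migration.find_migrate_repo(package=example)
    db_url = 'sqlite:///example.db'                # placeholder connection string
    versioning_api.version_control(db_url, repo)   # start tracking at version 0
    versioning_api.upgrade(db_url, repo)           # runs 001_example_table.py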
diff --git a/keystone/contrib/example/migrate_repo/versions/__init__.py b/keystone/contrib/example/migrate_repo/versions/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/versions/__init__.py
diff --git a/keystone/contrib/stats/core.py b/keystone/contrib/stats/core.py
index 1d7b2cdf..9e6538db 100644
--- a/keystone/contrib/stats/core.py
+++ b/keystone/contrib/stats/core.py
@@ -15,12 +15,12 @@
# under the License.
from keystone.common import extension
-from keystone.common import logging
from keystone.common import manager
from keystone.common import wsgi
from keystone import config
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
from keystone import policy
from keystone import token
diff --git a/keystone/contrib/user_crud/core.py b/keystone/contrib/user_crud/core.py
index f9f09b89..2129af40 100644
--- a/keystone/contrib/user_crud/core.py
+++ b/keystone/contrib/user_crud/core.py
@@ -18,10 +18,10 @@ import copy
import uuid
from keystone.common import extension
-from keystone.common import logging
from keystone.common import wsgi
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
@@ -82,7 +82,7 @@ class UserController(identity.controllers.User):
new_token_ref = copy.copy(token_ref)
new_token_ref['id'] = token_id
self.token_api.create_token(token_id, new_token_ref)
- logging.debug('TOKEN_REF %s', new_token_ref)
+ LOG.debug('TOKEN_REF %s', new_token_ref)
return {'access': {'token': new_token_ref}}
diff --git a/keystone/controllers.py b/keystone/controllers.py
index 8ffa073a..be3c57fa 100644
--- a/keystone/controllers.py
+++ b/keystone/controllers.py
@@ -15,10 +15,10 @@
# under the License.
from keystone.common import extension
-from keystone.common import logging
from keystone.common import wsgi
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
LOG = logging.getLogger(__name__)
diff --git a/keystone/credential/core.py b/keystone/credential/core.py
index a8921ba0..97cfc1c1 100644
--- a/keystone/credential/core.py
+++ b/keystone/credential/core.py
@@ -17,10 +17,10 @@
"""Main entry point into the Credentials service."""
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/keystone/exception.py b/keystone/exception.py
index 5e1defba..c0edc263 100644
--- a/keystone/exception.py
+++ b/keystone/exception.py
@@ -15,8 +15,8 @@
# under the License.
from keystone.common import config
-from keystone.common import logging
from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/keystone/identity/backends/ldap.py b/keystone/identity/backends/ldap.py
index 91ea1e41..ef3b5d61 100644
--- a/keystone/identity/backends/ldap.py
+++ b/keystone/identity/backends/ldap.py
@@ -21,12 +21,12 @@ import ldap
from keystone import clean
from keystone.common import dependency
from keystone.common import ldap as common_ldap
-from keystone.common import logging
from keystone.common import models
from keystone.common import utils
from keystone import config
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
CONF = config.CONF
@@ -210,29 +210,20 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
DEFAULT_STRUCTURAL_CLASSES = ['person']
DEFAULT_ID_ATTR = 'cn'
DEFAULT_OBJECTCLASS = 'inetOrgPerson'
- DEFAULT_ATTRIBUTE_IGNORE = ['tenant_id', 'tenants']
NotFound = exception.UserNotFound
options_name = 'user'
- attribute_mapping = {'password': 'userPassword',
- 'email': 'mail',
- 'name': 'sn',
- 'enabled': 'enabled',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'password': 'pass',
+ 'email': 'mail',
+ 'name': 'name',
+ 'enabled': 'enabled',
+ 'domain_id': 'domain_id'}
model = models.User
def __init__(self, conf):
super(UserApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.user_name_attribute
- self.attribute_mapping['email'] = conf.ldap.user_mail_attribute
- self.attribute_mapping['password'] = conf.ldap.user_pass_attribute
- self.attribute_mapping['enabled'] = conf.ldap.user_enabled_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.user_domain_id_attribute)
self.enabled_mask = conf.ldap.user_enabled_mask
self.enabled_default = conf.ldap.user_enabled_default
- self.attribute_ignore = (getattr(conf.ldap, 'user_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def _ldap_res_to_model(self, res):
obj = super(UserApi, self)._ldap_res_to_model(res)
@@ -277,25 +268,17 @@ class GroupApi(common_ldap.BaseLdap):
DEFAULT_OBJECTCLASS = 'groupOfNames'
DEFAULT_ID_ATTR = 'cn'
DEFAULT_MEMBER_ATTRIBUTE = 'member'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.GroupNotFound
options_name = 'group'
- attribute_mapping = {'name': 'ou',
- 'description': 'description',
- 'groupId': 'cn',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'description': 'desc',
+ 'name': 'name',
+ 'domain_id': 'domain_id'}
model = models.Group
def __init__(self, conf):
super(GroupApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.group_name_attribute
- self.attribute_mapping['description'] = conf.ldap.group_desc_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.group_domain_id_attribute)
self.member_attribute = (getattr(conf.ldap, 'group_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'group_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def create(self, values):
self.affirm_unique(values)
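The per-class attribute_mapping dictionaries and their __init__ overrides give way
to attribute_options_names, which maps model attributes to option-name suffixes.
Presumably the shared LDAP base class now derives the actual mapping from options
such as user_pass_attribute or group_desc_attribute; the base-class code is not in
this hunk, so the following derivation is only an assumption:

    # hypothetical logic in common_ldap.BaseLdap.__init__
    for attr, option in self.attribute_options_names.items():
        option_name = '%s_%s_attribute' % (self.options_name, option)
        self.attribute_mapping[attr] = getattr(conf.ldap, option_name)
    # e.g. the user 'password' attribute maps via conf.ldap.user_pass_attribute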
diff --git a/keystone/identity/backends/sql.py b/keystone/identity/backends/sql.py
index bff41106..65a34a8a 100644
--- a/keystone/identity/backends/sql.py
+++ b/keystone/identity/backends/sql.py
@@ -26,7 +26,7 @@ class User(sql.ModelBase, sql.DictBase):
__tablename__ = 'user'
attributes = ['id', 'name', 'domain_id', 'password', 'enabled']
id = sql.Column(sql.String(64), primary_key=True)
- name = sql.Column(sql.String(64), nullable=False)
+ name = sql.Column(sql.String(255), nullable=False)
domain_id = sql.Column(sql.String(64), sql.ForeignKey('domain.id'),
nullable=False)
password = sql.Column(sql.String(128))
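Widening User.name from 64 to 255 characters only changes the model; an existing
database also needs a schema migration that alters the column in place. A hedged
sqlalchemy-migrate sketch of what such a migration looks like:

    import sqlalchemy as sql

    def upgrade(migrate_engine):
        meta = sql.MetaData()
        meta.bind = migrate_engine
        user_table = sql.Table('user', meta, autoload=True)
        # sqlalchemy-migrate's changeset extension adds Column.alter()
        user_table.c.name.alter(type=sql.String(255))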
diff --git a/keystone/identity/controllers.py b/keystone/identity/controllers.py
index 7ca1f8bf..67f3beac 100644
--- a/keystone/identity/controllers.py
+++ b/keystone/identity/controllers.py
@@ -22,10 +22,9 @@ import urlparse
import uuid
from keystone.common import controller
-from keystone.common import logging
from keystone import config
from keystone import exception
-
+from keystone.openstack.common import log as logging
CONF = config.CONF
DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id
@@ -109,12 +108,20 @@ class Tenant(controller.V2Controller):
# be specifying that
clean_tenant = tenant.copy()
clean_tenant.pop('domain_id', None)
+
+ # If the project has been disabled (or enabled=False) we are
+ # deleting the tokens for that project.
+ if not tenant.get('enabled', True):
+ self._delete_tokens_for_project(tenant_id)
+
tenant_ref = self.identity_api.update_project(
tenant_id, clean_tenant)
return {'tenant': tenant_ref}
def delete_project(self, context, tenant_id):
self.assert_admin(context)
+ # Delete all tokens belonging to the users for that project
+ self._delete_tokens_for_project(tenant_id)
self.identity_api.delete_project(tenant_id)
def get_project_users(self, context, tenant_id, **kw):
@@ -572,6 +579,10 @@ class ProjectV3(controller.V3Controller):
def update_project(self, context, project_id, project):
self._require_matching_id(project_id, project)
+ # The project was disabled so we delete the tokens
+ if not project.get('enabled', True):
+ self._delete_tokens_for_project(project_id)
+
ref = self.identity_api.update_project(project_id, project)
return ProjectV3.wrap_member(context, ref)
@@ -580,6 +591,10 @@ class ProjectV3(controller.V3Controller):
for cred in self.credential_api.list_credentials():
if cred['project_id'] == project_id:
self.credential_api.delete_credential(cred['id'])
+
+ # Delete all tokens belonging to the users for that project
+ self._delete_tokens_for_project(project_id)
+
# Finally delete the project itself - the backend is
# responsible for deleting any role assignments related
# to this project
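Both the v2 and v3 project controllers now purge tokens through a
_delete_tokens_for_project helper whenever a project is disabled or deleted. Its
implementation is outside this section; a purely hypothetical sketch is to walk
the project's users and revoke each one's tokens:

    # hypothetical helper on the controller base class
    def _delete_tokens_for_project(self, project_id):
        for user_ref in self.identity_api.get_project_users(project_id):
            try:
                self.token_api.delete_tokens_for_user(user_ref['id'], project_id)
            except exception.NotFound:
                pass  # nothing to revoke for this user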
diff --git a/keystone/identity/core.py b/keystone/identity/core.py
index b2b3eaf0..7fb630e2 100644
--- a/keystone/identity/core.py
+++ b/keystone/identity/core.py
@@ -18,10 +18,10 @@
from keystone import clean
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/keystone/middleware/core.py b/keystone/middleware/core.py
index 863ef948..92b179c3 100644
--- a/keystone/middleware/core.py
+++ b/keystone/middleware/core.py
@@ -17,13 +17,12 @@
import webob.dec
from keystone.common import config
-from keystone.common import logging
from keystone.common import serializer
from keystone.common import utils
from keystone.common import wsgi
from keystone import exception
from keystone.openstack.common import jsonutils
-
+from keystone.openstack.common import log as logging
CONF = config.CONF
LOG = logging.getLogger(__name__)
diff --git a/keystone/middleware/s3_token.py b/keystone/middleware/s3_token.py
index b346893b..39678591 100644
--- a/keystone/middleware/s3_token.py
+++ b/keystone/middleware/s3_token.py
@@ -37,8 +37,8 @@ import httplib
import urllib
import webob
-from keystone.common import logging
from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
PROTOCOL_NAME = 'S3 Token Authentication'
diff --git a/keystone/openstack/common/context.py b/keystone/openstack/common/context.py
new file mode 100644
index 00000000..643e62b4
--- /dev/null
+++ b/keystone/openstack/common/context.py
@@ -0,0 +1,83 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Simple class that stores security context information in the web request.
+
+Projects should subclass this class if they wish to enhance the request
+context or provide additional information in their specific WSGI pipeline.
+"""
+
+import itertools
+
+from keystone.openstack.common import uuidutils
+
+
+def generate_request_id():
+ return 'req-%s' % uuidutils.generate_uuid()
+
+
+class RequestContext(object):
+
+ """Helper class to represent useful information about a request context.
+
+ Stores information about the security context under which the user
+ accesses the system, as well as additional request information.
+ """
+
+ def __init__(self, auth_token=None, user=None, tenant=None, is_admin=False,
+ read_only=False, show_deleted=False, request_id=None):
+ self.auth_token = auth_token
+ self.user = user
+ self.tenant = tenant
+ self.is_admin = is_admin
+ self.read_only = read_only
+ self.show_deleted = show_deleted
+ if not request_id:
+ request_id = generate_request_id()
+ self.request_id = request_id
+
+ def to_dict(self):
+ return {'user': self.user,
+ 'tenant': self.tenant,
+ 'is_admin': self.is_admin,
+ 'read_only': self.read_only,
+ 'show_deleted': self.show_deleted,
+ 'auth_token': self.auth_token,
+ 'request_id': self.request_id}
+
+
+def get_admin_context(show_deleted=False):
+ context = RequestContext(None,
+ tenant=None,
+ is_admin=True,
+ show_deleted=show_deleted)
+ return context
+
+
+def get_context_from_function_and_args(function, args, kwargs):
+ """Find an arg of type RequestContext and return it.
+
+ This is useful in a couple of decorators where we don't
+ know much about the function we're wrapping.
+ """
+
+ for arg in itertools.chain(kwargs.values(), args):
+ if isinstance(arg, RequestContext):
+ return arg
+
+ return None
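A short usage sketch of the new request-context helper, using only what the module
above defines:

    from keystone.openstack.common import context

    ctx = context.RequestContext(auth_token='secret', user='alice', tenant='demo')
    admin_ctx = context.get_admin_context()

    print(ctx.request_id)                    # e.g. 'req-<uuid>'
    print(ctx.to_dict()['is_admin'])         # False
    print(admin_ctx.is_admin)                # True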
diff --git a/keystone/openstack/common/crypto/utils.py b/keystone/openstack/common/crypto/utils.py
index ef178cab..717989d4 100644
--- a/keystone/openstack/common/crypto/utils.py
+++ b/keystone/openstack/common/crypto/utils.py
@@ -19,8 +19,8 @@ import base64
from Crypto.Hash import HMAC
from Crypto import Random
-from keystone.openstack.common.gettextutils import _
-from keystone.openstack.common.importutils import import_module
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
class CryptoutilsException(Exception):
@@ -54,7 +54,7 @@ class HKDF(object):
"""
def __init__(self, hashtype='SHA256'):
- self.hashfn = import_module('Crypto.Hash.' + hashtype)
+ self.hashfn = importutils.import_module('Crypto.Hash.' + hashtype)
self.max_okm_length = 255 * self.hashfn.digest_size
def extract(self, ikm, salt=None):
@@ -107,8 +107,8 @@ class SymmetricCrypto(object):
"""
def __init__(self, enctype='AES', hashtype='SHA256'):
- self.cipher = import_module('Crypto.Cipher.' + enctype)
- self.hashfn = import_module('Crypto.Hash.' + hashtype)
+ self.cipher = importutils.import_module('Crypto.Cipher.' + enctype)
+ self.hashfn = importutils.import_module('Crypto.Hash.' + hashtype)
def new_key(self, size):
return Random.new().read(size)
diff --git a/keystone/openstack/common/eventlet_backdoor.py b/keystone/openstack/common/eventlet_backdoor.py
new file mode 100644
index 00000000..c4d18ddb
--- /dev/null
+++ b/keystone/openstack/common/eventlet_backdoor.py
@@ -0,0 +1,146 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2012 OpenStack Foundation.
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+
+import errno
+import gc
+import os
+import pprint
+import socket
+import sys
+import traceback
+
+import eventlet
+import eventlet.backdoor
+import greenlet
+from oslo.config import cfg
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+
+help_for_backdoor_port = (
+ "Acceptable values are 0, <port>, and <start>:<end>, where 0 results "
+ "in listening on a random tcp port number; <port> results in listening "
+ "on the specified port number (and not enabling backdoor if that port "
+ "is in use); and <start>:<end> results in listening on the smallest "
+ "unused port number within the specified range of port numbers. The "
+ "chosen port is displayed in the service's log file.")
+eventlet_backdoor_opts = [
+ cfg.StrOpt('backdoor_port',
+ default=None,
+ help="Enable eventlet backdoor. %s" % help_for_backdoor_port)
+]
+
+CONF = cfg.CONF
+CONF.register_opts(eventlet_backdoor_opts)
+LOG = logging.getLogger(__name__)
+
+
+class EventletBackdoorConfigValueError(Exception):
+ def __init__(self, port_range, help_msg, ex):
+ msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. '
+ '%(help)s' %
+ {'range': port_range, 'ex': ex, 'help': help_msg})
+ super(EventletBackdoorConfigValueError, self).__init__(msg)
+ self.port_range = port_range
+
+
+def _dont_use_this():
+ print("Don't use this, just disconnect instead")
+
+
+def _find_objects(t):
+ return filter(lambda o: isinstance(o, t), gc.get_objects())
+
+
+def _print_greenthreads():
+ for i, gt in enumerate(_find_objects(greenlet.greenlet)):
+ print(i, gt)
+ traceback.print_stack(gt.gr_frame)
+ print()
+
+
+def _print_nativethreads():
+ for threadId, stack in sys._current_frames().items():
+ print(threadId)
+ traceback.print_stack(stack)
+ print()
+
+
+def _parse_port_range(port_range):
+ if ':' not in port_range:
+ start, end = port_range, port_range
+ else:
+ start, end = port_range.split(':', 1)
+ try:
+ start, end = int(start), int(end)
+ if end < start:
+ raise ValueError
+ return start, end
+ except ValueError as ex:
+ raise EventletBackdoorConfigValueError(port_range, ex,
+ help_for_backdoor_port)
+
+
+def _listen(host, start_port, end_port, listen_func):
+ try_port = start_port
+ while True:
+ try:
+ return listen_func((host, try_port))
+ except socket.error as exc:
+ if (exc.errno != errno.EADDRINUSE or
+ try_port >= end_port):
+ raise
+ try_port += 1
+
+
+def initialize_if_enabled():
+ backdoor_locals = {
+ 'exit': _dont_use_this, # So we don't exit the entire process
+ 'quit': _dont_use_this, # So we don't exit the entire process
+ 'fo': _find_objects,
+ 'pgt': _print_greenthreads,
+ 'pnt': _print_nativethreads,
+ }
+
+ if CONF.backdoor_port is None:
+ return None
+
+ start_port, end_port = _parse_port_range(str(CONF.backdoor_port))
+
+ # NOTE(johannes): The standard sys.displayhook will print the value of
+ # the last expression and set it to __builtin__._, which overwrites
+ # the __builtin__._ that gettext sets. Let's switch to using pprint
+ # since it won't interact poorly with gettext, and it's easier to
+ # read the output too.
+ def displayhook(val):
+ if val is not None:
+ pprint.pprint(val)
+ sys.displayhook = displayhook
+
+ sock = _listen('localhost', start_port, end_port, eventlet.listen)
+
+ # In the case of backdoor port being zero, a port number is assigned by
+ # listen(). In any case, pull the port number out here.
+ port = sock.getsockname()[1]
+ LOG.info(_('Eventlet backdoor listening on %(port)s for process %(pid)d') %
+ {'port': port, 'pid': os.getpid()})
+ eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock,
+ locals=backdoor_locals)
+ return port
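The backdoor is opt-in: once backdoor_port is set, connecting to the chosen port
with a plain TCP client drops into a Python shell offering the helpers registered
above. A small sketch of enabling it programmatically:

    from oslo.config import cfg

    from keystone.openstack.common import eventlet_backdoor

    cfg.CONF.set_override('backdoor_port', '0')        # '0' = pick any free port
    port = eventlet_backdoor.initialize_if_enabled()
    # then, e.g.:  telnet 127.0.0.1 <port>   and call pgt() / pnt() in the shell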
diff --git a/keystone/openstack/common/excutils.py b/keystone/openstack/common/excutils.py
new file mode 100644
index 00000000..28d59f90
--- /dev/null
+++ b/keystone/openstack/common/excutils.py
@@ -0,0 +1,99 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# Copyright 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Exception related utilities.
+"""
+
+import logging
+import sys
+import time
+import traceback
+
+from keystone.openstack.common.gettextutils import _ # noqa
+
+
+class save_and_reraise_exception(object):
+ """Save current exception, run some code and then re-raise.
+
+ In some cases the exception context can be cleared, resulting in None
+ being attempted to be re-raised after an exception handler is run. This
+ can happen when eventlet switches greenthreads or when running an
+ exception handler, code raises and catches an exception. In both
+ cases the exception context will be cleared.
+
+ To work around this, we save the exception state, run handler code, and
+ then re-raise the original exception. If another exception occurs, the
+ saved exception is logged and the new exception is re-raised.
+
+ In some cases the caller may not want to re-raise the exception, and
+ for those circumstances this context provides a reraise flag that
+ can be used to suppress the exception. For example:
+
+ except Exception:
+ with save_and_reraise_exception() as ctxt:
+ decide_if_need_reraise()
+ if not should_be_reraised:
+ ctxt.reraise = False
+ """
+ def __init__(self):
+ self.reraise = True
+
+ def __enter__(self):
+ self.type_, self.value, self.tb, = sys.exc_info()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None:
+ logging.error(_('Original exception being dropped: %s'),
+ traceback.format_exception(self.type_,
+ self.value,
+ self.tb))
+ return False
+ if self.reraise:
+ raise self.type_, self.value, self.tb
+
+
+def forever_retry_uncaught_exceptions(infunc):
+ def inner_func(*args, **kwargs):
+ last_log_time = 0
+ last_exc_message = None
+ exc_count = 0
+ while True:
+ try:
+ return infunc(*args, **kwargs)
+ except Exception as exc:
+ this_exc_message = unicode(exc)
+ if this_exc_message == last_exc_message:
+ exc_count += 1
+ else:
+ exc_count = 1
+ # Do not log any more frequently than once a minute unless
+ # the exception message changes
+ cur_time = int(time.time())
+ if (cur_time - last_log_time > 60 or
+ this_exc_message != last_exc_message):
+ logging.exception(
+ _('Unexpected exception occurred %d time(s)... '
+ 'retrying.') % exc_count)
+ last_log_time = cur_time
+ last_exc_message = this_exc_message
+ exc_count = 0
+ # This should be a very rare event. In case it isn't, do
+ # a sleep.
+ time.sleep(1)
+ return inner_func
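Two small usage sketches for the new helpers, mirroring the docstring above
(the do_work/roll_back/poll_queue names are placeholders):

    from keystone.openstack.common import excutils

    def cleanup_then_reraise():
        try:
            do_work()
        except Exception:
            with excutils.save_and_reraise_exception() as ctxt:
                roll_back_partial_state()
                if failure_is_expected():
                    ctxt.reraise = False     # swallow instead of re-raising

    @excutils.forever_retry_uncaught_exceptions
    def worker_loop():
        poll_queue_forever()                 # crashes are logged (rate-limited) and retried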
diff --git a/keystone/openstack/common/gettextutils.py b/keystone/openstack/common/gettextutils.py
index 55ba3387..1a24c24c 100644
--- a/keystone/openstack/common/gettextutils.py
+++ b/keystone/openstack/common/gettextutils.py
@@ -1,6 +1,7 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
+# Copyright 2013 IBM Corp.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -23,18 +24,27 @@ Usual usage in an openstack.common module:
from keystone.openstack.common.gettextutils import _
"""
+import copy
import gettext
+import logging.handlers
import os
+import re
+import UserString
+
+from babel import localedata
+import six
_localedir = os.environ.get('keystone'.upper() + '_LOCALEDIR')
_t = gettext.translation('keystone', localedir=_localedir, fallback=True)
+_AVAILABLE_LANGUAGES = []
+
def _(msg):
return _t.ugettext(msg)
-def install(domain):
+def install(domain, lazy=False):
"""Install a _() function using the given translation domain.
Given a translation domain, install a _() function using gettext's
@@ -44,7 +54,252 @@ def install(domain):
overriding the default localedir (e.g. /usr/share/locale) using
a translation-domain-specific environment variable (e.g.
NOVA_LOCALEDIR).
+
+ :param domain: the translation domain
+ :param lazy: indicates whether or not to install the lazy _() function.
+ The lazy _() introduces a way to do deferred translation
+ of messages by installing a _ that builds Message objects,
+ instead of strings, which can then be lazily translated into
+ any available locale.
+ """
+ if lazy:
+ # NOTE(mrodden): Lazy gettext functionality.
+ #
+ # The following introduces a deferred way to do translations on
+ # messages in OpenStack. We override the standard _() function
+ # and % (format string) operation to build Message objects that can
+ # later be translated when we have more information.
+ #
+ # Also included below is an example LocaleHandler that translates
+ # Messages to an associated locale, effectively allowing many logs,
+ # each with their own locale.
+
+ def _lazy_gettext(msg):
+ """Create and return a Message object.
+
+ Lazy gettext function for a given domain, it is a factory method
+ for a project/module to get a lazy gettext function for its own
+ translation domain (i.e. nova, glance, cinder, etc.)
+
+ Message encapsulates a string so that we can translate
+ it later when needed.
+ """
+ return Message(msg, domain)
+
+ import __builtin__
+ __builtin__.__dict__['_'] = _lazy_gettext
+ else:
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ gettext.install(domain,
+ localedir=os.environ.get(localedir),
+ unicode=True)
+
+
+class Message(UserString.UserString, object):
+ """Class used to encapsulate translatable messages."""
+ def __init__(self, msg, domain):
+ # _msg is the gettext msgid and should never change
+ self._msg = msg
+ self._left_extra_msg = ''
+ self._right_extra_msg = ''
+ self.params = None
+ self.locale = None
+ self.domain = domain
+
+ @property
+ def data(self):
+ # NOTE(mrodden): this should always resolve to a unicode string
+ # that best represents the state of the message currently
+
+ localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR')
+ if self.locale:
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ languages=[self.locale],
+ fallback=True)
+ else:
+ # use system locale for translations
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ fallback=True)
+
+ full_msg = (self._left_extra_msg +
+ lang.ugettext(self._msg) +
+ self._right_extra_msg)
+
+ if self.params is not None:
+ full_msg = full_msg % self.params
+
+ return six.text_type(full_msg)
+
+ def _save_dictionary_parameter(self, dict_param):
+ full_msg = self.data
+ # look for %(blah) fields in string;
+ # ignore %% and deal with the
+ # case where % is first character on the line
+ keys = re.findall('(?:[^%]|^)?%\((\w*)\)[a-z]', full_msg)
+
+ # if we don't find any %(blah) blocks but have a %s
+ if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg):
+ # apparently the full dictionary is the parameter
+ params = copy.deepcopy(dict_param)
+ else:
+ params = {}
+ for key in keys:
+ try:
+ params[key] = copy.deepcopy(dict_param[key])
+ except TypeError:
+ # cast uncopyable thing to unicode string
+ params[key] = unicode(dict_param[key])
+
+ return params
+
+ def _save_parameters(self, other):
+ # we check for None later to see if
+ # we actually have parameters to inject,
+ # so encapsulate if our parameter is actually None
+ if other is None:
+ self.params = (other, )
+ elif isinstance(other, dict):
+ self.params = self._save_dictionary_parameter(other)
+ else:
+ # fallback to casting to unicode,
+ # this will handle the problematic python code-like
+ # objects that cannot be deep-copied
+ try:
+ self.params = copy.deepcopy(other)
+ except TypeError:
+ self.params = unicode(other)
+
+ return self
+
+ # overrides to be more string-like
+ def __unicode__(self):
+ return self.data
+
+ def __str__(self):
+ return self.data.encode('utf-8')
+
+ def __getstate__(self):
+ to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg',
+ 'domain', 'params', 'locale']
+ new_dict = self.__dict__.fromkeys(to_copy)
+ for attr in to_copy:
+ new_dict[attr] = copy.deepcopy(self.__dict__[attr])
+
+ return new_dict
+
+ def __setstate__(self, state):
+ for (k, v) in state.items():
+ setattr(self, k, v)
+
+ # operator overloads
+ def __add__(self, other):
+ copied = copy.deepcopy(self)
+ copied._right_extra_msg += other.__str__()
+ return copied
+
+ def __radd__(self, other):
+ copied = copy.deepcopy(self)
+ copied._left_extra_msg += other.__str__()
+ return copied
+
+ def __mod__(self, other):
+ # do a format string to catch and raise
+ # any possible KeyErrors from missing parameters
+ self.data % other
+ copied = copy.deepcopy(self)
+ return copied._save_parameters(other)
+
+ def __mul__(self, other):
+ return self.data * other
+
+ def __rmul__(self, other):
+ return other * self.data
+
+ def __getitem__(self, key):
+ return self.data[key]
+
+ def __getslice__(self, start, end):
+ return self.data.__getslice__(start, end)
+
+ def __getattribute__(self, name):
+ # NOTE(mrodden): handle lossy operations that we can't deal with yet
+ # These override the UserString implementation, since UserString
+ # uses our __class__ attribute to try and build a new message
+ # after running the inner data string through the operation.
+ # At that point, we have lost the gettext message id and can just
+ # safely resolve to a string instead.
+ ops = ['capitalize', 'center', 'decode', 'encode',
+ 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip',
+ 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill']
+ if name in ops:
+ return getattr(self.data, name)
+ else:
+ return UserString.UserString.__getattribute__(self, name)
+
+
+def get_available_languages(domain):
+ """Lists the available languages for the given translation domain.
+
+ :param domain: the domain to get languages for
"""
- gettext.install(domain,
- localedir=os.environ.get(domain.upper() + '_LOCALEDIR'),
- unicode=True)
+ if _AVAILABLE_LANGUAGES:
+ return _AVAILABLE_LANGUAGES
+
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ find = lambda x: gettext.find(domain,
+ localedir=os.environ.get(localedir),
+ languages=[x])
+
+ # NOTE(mrodden): en_US should always be available (and first in case
+ # order matters) since our in-line message strings are en_US
+ _AVAILABLE_LANGUAGES.append('en_US')
+ # NOTE(luisg): Babel <1.0 used a function called list(), which was
+ # renamed to locale_identifiers() in >=1.0, the requirements master list
+ # requires >=0.9.6, uncapped, so defensively work with both. We can remove
+# this check when the master list updates to >=1.0, and all projects update
+ list_identifiers = (getattr(localedata, 'list', None) or
+ getattr(localedata, 'locale_identifiers'))
+ locale_identifiers = list_identifiers()
+ for i in locale_identifiers:
+ if find(i) is not None:
+ _AVAILABLE_LANGUAGES.append(i)
+ return _AVAILABLE_LANGUAGES
+
+
+def get_localized_message(message, user_locale):
+ """Gets a localized version of the given message in the given locale."""
+ if (isinstance(message, Message)):
+ if user_locale:
+ message.locale = user_locale
+ return unicode(message)
+ else:
+ return message
+
+
+class LocaleHandler(logging.Handler):
+ """Handler that can have a locale associated to translate Messages.
+
+ A quick example of how to utilize the Message class above.
+ LocaleHandler takes a locale and a target logging.Handler object
+ to forward LogRecord objects to after translating the internal Message.
+ """
+
+ def __init__(self, locale, target):
+ """Initialize a LocaleHandler
+
+ :param locale: locale to use for translating messages
+ :param target: logging.Handler object to forward
+ LogRecord objects to after translation
+ """
+ logging.Handler.__init__(self)
+ self.locale = locale
+ self.target = target
+
+ def emit(self, record):
+ if isinstance(record.msg, Message):
+ # set the locale and resolve to a string
+ record.msg.locale = self.locale
+
+ self.target.emit(record)
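A sketch of how the lazy-translation pieces fit together; the handler wiring is
illustrative only:

    import logging as std_logging

    from keystone.openstack.common import gettextutils

    gettextutils.install('keystone', lazy=True)      # _() now builds Message objects
    msg = _('Hello %(name)s') % {'name': 'world'}    # parameters saved, not yet rendered

    # translate explicitly for a chosen locale...
    print(gettextutils.get_localized_message(msg, 'zh_CN'))

    # ...or let a handler translate log records as they are emitted
    target = std_logging.StreamHandler()
    std_logging.getLogger().addHandler(gettextutils.LocaleHandler('zh_CN', target))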
diff --git a/keystone/openstack/common/importutils.py b/keystone/openstack/common/importutils.py
index 3bd277f4..7a303f93 100644
--- a/keystone/openstack/common/importutils.py
+++ b/keystone/openstack/common/importutils.py
@@ -24,7 +24,7 @@ import traceback
def import_class(import_str):
- """Returns a class from a string including module and class"""
+ """Returns a class from a string including module and class."""
mod_str, _sep, class_str = import_str.rpartition('.')
try:
__import__(mod_str)
@@ -41,8 +41,9 @@ def import_object(import_str, *args, **kwargs):
def import_object_ns(name_space, import_str, *args, **kwargs):
- """
- Import a class and return an instance of it, first by trying
+ """Tries to import object from default namespace.
+
+ Imports a class and return an instance of it, first by trying
to find the class in a default namespace, then failing back to
a full path if not found in the default namespace.
"""
diff --git a/keystone/openstack/common/jsonutils.py b/keystone/openstack/common/jsonutils.py
index d73e4c26..ecea09bb 100644
--- a/keystone/openstack/common/jsonutils.py
+++ b/keystone/openstack/common/jsonutils.py
@@ -38,11 +38,24 @@ import functools
import inspect
import itertools
import json
+import types
import xmlrpclib
+import netaddr
+import six
+
from keystone.openstack.common import timeutils
+_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod,
+ inspect.isfunction, inspect.isgeneratorfunction,
+ inspect.isgenerator, inspect.istraceback, inspect.isframe,
+ inspect.iscode, inspect.isbuiltin, inspect.isroutine,
+ inspect.isabstract]
+
+_simple_types = (types.NoneType, int, basestring, bool, float, long)
+
+
def to_primitive(value, convert_instances=False, convert_datetime=True,
level=0, max_depth=3):
"""Convert a complex object into primitives.
@@ -58,19 +71,32 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
Therefore, convert_instances=True is lossy ... be aware.
"""
- nasty = [inspect.ismodule, inspect.isclass, inspect.ismethod,
- inspect.isfunction, inspect.isgeneratorfunction,
- inspect.isgenerator, inspect.istraceback, inspect.isframe,
- inspect.iscode, inspect.isbuiltin, inspect.isroutine,
- inspect.isabstract]
- for test in nasty:
- if test(value):
- return unicode(value)
-
- # value of itertools.count doesn't get caught by inspects
- # above and results in infinite loop when list(value) is called.
+ # handle obvious types first - order of basic types determined by running
+ # full tests on nova project, resulting in the following counts:
+ # 572754 <type 'NoneType'>
+ # 460353 <type 'int'>
+ # 379632 <type 'unicode'>
+ # 274610 <type 'str'>
+ # 199918 <type 'dict'>
+ # 114200 <type 'datetime.datetime'>
+ # 51817 <type 'bool'>
+ # 26164 <type 'list'>
+ # 6491 <type 'float'>
+ # 283 <type 'tuple'>
+ # 19 <type 'long'>
+ if isinstance(value, _simple_types):
+ return value
+
+ if isinstance(value, datetime.datetime):
+ if convert_datetime:
+ return timeutils.strtime(value)
+ else:
+ return value
+
+ # value of itertools.count doesn't get caught by nasty_type_tests
+ # and results in infinite loop when list(value) is called.
if type(value) == itertools.count:
- return unicode(value)
+ return six.text_type(value)
# FIXME(vish): Workaround for LP bug 852095. Without this workaround,
# tests that raise an exception in a mocked method that
@@ -91,17 +117,18 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
convert_datetime=convert_datetime,
level=level,
max_depth=max_depth)
+ if isinstance(value, dict):
+ return dict((k, recursive(v)) for k, v in value.iteritems())
+ elif isinstance(value, (list, tuple)):
+ return [recursive(lv) for lv in value]
+
# It's not clear why xmlrpclib created their own DateTime type, but
# for our purposes, make it a datetime type which is explicitly
# handled
if isinstance(value, xmlrpclib.DateTime):
value = datetime.datetime(*tuple(value.timetuple())[:6])
- if isinstance(value, (list, tuple)):
- return [recursive(v) for v in value]
- elif isinstance(value, dict):
- return dict((k, recursive(v)) for k, v in value.iteritems())
- elif convert_datetime and isinstance(value, datetime.datetime):
+ if convert_datetime and isinstance(value, datetime.datetime):
return timeutils.strtime(value)
elif hasattr(value, 'iteritems'):
return recursive(dict(value.iteritems()), level=level + 1)
@@ -111,12 +138,16 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
# Likely an instance of something. Watch for cycles.
# Ignore class member vars.
return recursive(value.__dict__, level=level + 1)
+ elif isinstance(value, netaddr.IPAddress):
+ return six.text_type(value)
else:
+ if any(test(value) for test in _nasty_type_tests):
+ return six.text_type(value)
return value
except TypeError:
# Class objects are tricky since they may define something like
# __iter__ defined but it isn't callable as list().
- return unicode(value)
+ return six.text_type(value)
def dumps(value, default=to_primitive, **kwargs):
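The reordered type checks keep the to_primitive/dumps contract; a small sketch of
the value kinds it now short-circuits on:

    import datetime

    import netaddr

    from keystone.openstack.common import jsonutils

    data = {'when': datetime.datetime.utcnow(),       # rendered via timeutils.strtime
            'addr': netaddr.IPAddress('192.0.2.1'),   # rendered as a unicode string
            'count': 3,                               # simple types pass straight through
            'name': u'demo'}
    print(jsonutils.dumps(data))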
diff --git a/keystone/openstack/common/local.py b/keystone/openstack/common/local.py
new file mode 100644
index 00000000..e82f17d0
--- /dev/null
+++ b/keystone/openstack/common/local.py
@@ -0,0 +1,47 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Local storage of variables using weak references"""
+
+import threading
+import weakref
+
+
+class WeakLocal(threading.local):
+ def __getattribute__(self, attr):
+ rval = super(WeakLocal, self).__getattribute__(attr)
+ if rval:
+ # NOTE(mikal): this bit is confusing. What is stored is a weak
+ # reference, not the value itself. We therefore need to lookup
+ # the weak reference and return the inner value here.
+ rval = rval()
+ return rval
+
+ def __setattr__(self, attr, value):
+ value = weakref.ref(value)
+ return super(WeakLocal, self).__setattr__(attr, value)
+
+
+# NOTE(mikal): the name "store" should be deprecated in the future
+store = WeakLocal()
+
+# A "weak" store uses weak references and allows an object to fall out of scope
+# when it falls out of scope in the code that uses the thread local storage. A
+# "strong" store will hold a reference to the object so that it never falls out
+# of scope.
+weak_store = WeakLocal()
+strong_store = threading.local()
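A short sketch of the difference between the weak and strong stores defined above:

    from keystone.openstack.common import local

    class Context(object):
        pass

    ctx = Context()
    local.strong_store.context = ctx    # held until explicitly cleared
    local.weak_store.context = ctx      # only a weak reference is stored
    assert local.weak_store.context is ctx
    del ctx
    # once ctx is garbage-collected, local.weak_store.context returns None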
diff --git a/keystone/openstack/common/log.py b/keystone/openstack/common/log.py
new file mode 100644
index 00000000..5a43c326
--- /dev/null
+++ b/keystone/openstack/common/log.py
@@ -0,0 +1,559 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Openstack logging handler.
+
+This module adds to logging functionality by adding the option to specify
+a context object when calling the various log methods. If the context object
+is not specified, default formatting is used. Additionally, an instance uuid
+may be passed as part of the log message, which is intended to make it easier
+for admins to find messages related to a specific instance.
+
+It also allows setting of formatting information through conf.
+
+"""
+
+import inspect
+import itertools
+import logging
+import logging.config
+import logging.handlers
+import os
+import sys
+import traceback
+
+from oslo.config import cfg
+from six import moves
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import local
+
+
+_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+common_cli_opts = [
+ cfg.BoolOpt('debug',
+ short='d',
+ default=False,
+ help='Print debugging output (set logging level to '
+ 'DEBUG instead of default WARNING level).'),
+ cfg.BoolOpt('verbose',
+ short='v',
+ default=False,
+ help='Print more verbose output (set logging level to '
+ 'INFO instead of default WARNING level).'),
+]
+
+logging_cli_opts = [
+ cfg.StrOpt('log-config',
+ metavar='PATH',
+ help='If this option is specified, the logging configuration '
+ 'file specified is used and overrides any other logging '
+ 'options specified. Please see the Python logging module '
+ 'documentation for details on logging configuration '
+ 'files.'),
+ cfg.StrOpt('log-format',
+ default=None,
+ metavar='FORMAT',
+ help='DEPRECATED. '
+ 'A logging.Formatter log message format string which may '
+ 'use any of the available logging.LogRecord attributes. '
+ 'This option is deprecated. Please use '
+ 'logging_context_format_string and '
+ 'logging_default_format_string instead.'),
+ cfg.StrOpt('log-date-format',
+ default=_DEFAULT_LOG_DATE_FORMAT,
+ metavar='DATE_FORMAT',
+ help='Format string for %%(asctime)s in log records. '
+ 'Default: %(default)s'),
+ cfg.StrOpt('log-file',
+ metavar='PATH',
+ deprecated_name='logfile',
+ help='(Optional) Name of log file to output to. '
+ 'If no default is set, logging will go to stdout.'),
+ cfg.StrOpt('log-dir',
+ deprecated_name='logdir',
+ help='(Optional) The base directory used for relative '
+ '--log-file paths'),
+ cfg.BoolOpt('use-syslog',
+ default=False,
+ help='Use syslog for logging.'),
+ cfg.StrOpt('syslog-log-facility',
+ default='LOG_USER',
+ help='syslog facility to receive log lines')
+]
+
+generic_log_opts = [
+ cfg.BoolOpt('use_stderr',
+ default=True,
+ help='Log output to standard error')
+]
+
+log_opts = [
+ cfg.StrOpt('logging_context_format_string',
+ default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
+ '%(name)s [%(request_id)s %(user)s %(tenant)s] '
+ '%(instance)s%(message)s',
+ help='format string to use for log messages with context'),
+ cfg.StrOpt('logging_default_format_string',
+ default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
+ '%(name)s [-] %(instance)s%(message)s',
+ help='format string to use for log messages without context'),
+ cfg.StrOpt('logging_debug_format_suffix',
+ default='%(funcName)s %(pathname)s:%(lineno)d',
+ help='data to append to log format when level is DEBUG'),
+ cfg.StrOpt('logging_exception_prefix',
+ default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s '
+ '%(instance)s',
+ help='prefix each line of exception output with this format'),
+ cfg.ListOpt('default_log_levels',
+ default=[
+ 'amqplib=WARN',
+ 'sqlalchemy=WARN',
+ 'boto=WARN',
+ 'suds=INFO',
+ 'keystone=INFO',
+ 'eventlet.wsgi.server=WARN'
+ ],
+ help='list of logger=LEVEL pairs'),
+ cfg.BoolOpt('publish_errors',
+ default=False,
+ help='publish error events'),
+ cfg.BoolOpt('fatal_deprecations',
+ default=False,
+ help='make deprecations fatal'),
+
+ # NOTE(mikal): there are two options here because sometimes we are handed
+ # a full instance (and could include more information), and other times we
+ # are just handed a UUID for the instance.
+ cfg.StrOpt('instance_format',
+ default='[instance: %(uuid)s] ',
+ help='If an instance is passed with the log message, format '
+ 'it like this'),
+ cfg.StrOpt('instance_uuid_format',
+ default='[instance: %(uuid)s] ',
+ help='If an instance UUID is passed with the log message, '
+ 'format it like this'),
+]
+
+CONF = cfg.CONF
+CONF.register_cli_opts(common_cli_opts)
+CONF.register_cli_opts(logging_cli_opts)
+CONF.register_opts(generic_log_opts)
+CONF.register_opts(log_opts)
+
+# our new audit level
+# NOTE(jkoelker) Since we synthesized an audit level, make the logging
+# module aware of it so it acts like other levels.
+logging.AUDIT = logging.INFO + 1
+logging.addLevelName(logging.AUDIT, 'AUDIT')
+
+
+try:
+ NullHandler = logging.NullHandler
+except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7
+ class NullHandler(logging.Handler):
+ def handle(self, record):
+ pass
+
+ def emit(self, record):
+ pass
+
+ def createLock(self):
+ self.lock = None
+
+
+def _dictify_context(context):
+ if context is None:
+ return None
+ if not isinstance(context, dict) and getattr(context, 'to_dict', None):
+ context = context.to_dict()
+ return context
+
+
+def _get_binary_name():
+ return os.path.basename(inspect.stack()[-1][1])
+
+
+def _get_log_file_path(binary=None):
+ logfile = CONF.log_file
+ logdir = CONF.log_dir
+
+ if logfile and not logdir:
+ return logfile
+
+ if logfile and logdir:
+ return os.path.join(logdir, logfile)
+
+ if logdir:
+ binary = binary or _get_binary_name()
+ return '%s.log' % (os.path.join(logdir, binary),)
+
+
+class BaseLoggerAdapter(logging.LoggerAdapter):
+
+ def audit(self, msg, *args, **kwargs):
+ self.log(logging.AUDIT, msg, *args, **kwargs)
+
+
+class LazyAdapter(BaseLoggerAdapter):
+ def __init__(self, name='unknown', version='unknown'):
+ self._logger = None
+ self.extra = {}
+ self.name = name
+ self.version = version
+
+ @property
+ def logger(self):
+ if not self._logger:
+ self._logger = getLogger(self.name, self.version)
+ return self._logger
+
+
+class ContextAdapter(BaseLoggerAdapter):
+ warn = logging.LoggerAdapter.warning
+
+ def __init__(self, logger, project_name, version_string):
+ self.logger = logger
+ self.project = project_name
+ self.version = version_string
+
+ @property
+ def handlers(self):
+ return self.logger.handlers
+
+ def deprecated(self, msg, *args, **kwargs):
+ stdmsg = _("Deprecated: %s") % msg
+ if CONF.fatal_deprecations:
+ self.critical(stdmsg, *args, **kwargs)
+ raise DeprecatedConfig(msg=stdmsg)
+ else:
+ self.warn(stdmsg, *args, **kwargs)
+
+ def process(self, msg, kwargs):
+ if 'extra' not in kwargs:
+ kwargs['extra'] = {}
+ extra = kwargs['extra']
+
+ context = kwargs.pop('context', None)
+ if not context:
+ context = getattr(local.store, 'context', None)
+ if context:
+ extra.update(_dictify_context(context))
+
+ instance = kwargs.pop('instance', None)
+ instance_extra = ''
+ if instance:
+ instance_extra = CONF.instance_format % instance
+ else:
+ instance_uuid = kwargs.pop('instance_uuid', None)
+ if instance_uuid:
+ instance_extra = (CONF.instance_uuid_format
+ % {'uuid': instance_uuid})
+ extra.update({'instance': instance_extra})
+
+ extra.update({"project": self.project})
+ extra.update({"version": self.version})
+ extra['extra'] = extra.copy()
+ return msg, kwargs
+
+
+class JSONFormatter(logging.Formatter):
+ def __init__(self, fmt=None, datefmt=None):
+ # NOTE(jkoelker) we ignore the fmt argument, but it's still there
+ # since logging.config.fileConfig passes it.
+ self.datefmt = datefmt
+
+ def formatException(self, ei, strip_newlines=True):
+ lines = traceback.format_exception(*ei)
+ if strip_newlines:
+ lines = [itertools.ifilter(
+ lambda x: x,
+ line.rstrip().splitlines()) for line in lines]
+ lines = list(itertools.chain(*lines))
+ return lines
+
+ def format(self, record):
+ message = {'message': record.getMessage(),
+ 'asctime': self.formatTime(record, self.datefmt),
+ 'name': record.name,
+ 'msg': record.msg,
+ 'args': record.args,
+ 'levelname': record.levelname,
+ 'levelno': record.levelno,
+ 'pathname': record.pathname,
+ 'filename': record.filename,
+ 'module': record.module,
+ 'lineno': record.lineno,
+ 'funcname': record.funcName,
+ 'created': record.created,
+ 'msecs': record.msecs,
+ 'relative_created': record.relativeCreated,
+ 'thread': record.thread,
+ 'thread_name': record.threadName,
+ 'process_name': record.processName,
+ 'process': record.process,
+ 'traceback': None}
+
+ if hasattr(record, 'extra'):
+ message['extra'] = record.extra
+
+ if record.exc_info:
+ message['traceback'] = self.formatException(record.exc_info)
+
+ return jsonutils.dumps(message)
+
+
+def _create_logging_excepthook(product_name):
+ def logging_excepthook(type, value, tb):
+ extra = {}
+ if CONF.verbose:
+ extra['exc_info'] = (type, value, tb)
+ getLogger(product_name).critical(str(value), **extra)
+ return logging_excepthook
+
+
+class LogConfigError(Exception):
+
+ message = _('Error loading logging config %(log_config)s: %(err_msg)s')
+
+ def __init__(self, log_config, err_msg):
+ self.log_config = log_config
+ self.err_msg = err_msg
+
+ def __str__(self):
+ return self.message % dict(log_config=self.log_config,
+ err_msg=self.err_msg)
+
+
+def _load_log_config(log_config):
+ try:
+ logging.config.fileConfig(log_config)
+ except moves.configparser.Error as exc:
+ raise LogConfigError(log_config, str(exc))
+
+
+def setup(product_name):
+ """Setup logging."""
+ if CONF.log_config:
+ _load_log_config(CONF.log_config)
+ else:
+ _setup_logging_from_conf()
+ sys.excepthook = _create_logging_excepthook(product_name)
+
+
+def set_defaults(logging_context_format_string):
+ cfg.set_defaults(log_opts,
+ logging_context_format_string=
+ logging_context_format_string)
+
+
+def _find_facility_from_conf():
+ facility_names = logging.handlers.SysLogHandler.facility_names
+ facility = getattr(logging.handlers.SysLogHandler,
+ CONF.syslog_log_facility,
+ None)
+
+ if facility is None and CONF.syslog_log_facility in facility_names:
+ facility = facility_names.get(CONF.syslog_log_facility)
+
+ if facility is None:
+ valid_facilities = facility_names.keys()
+ consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON',
+ 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS',
+ 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP',
+ 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3',
+ 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7']
+ valid_facilities.extend(consts)
+ raise TypeError(_('syslog facility must be one of: %s') %
+ ', '.join("'%s'" % fac
+ for fac in valid_facilities))
+
+ return facility
+
+
+def _setup_logging_from_conf():
+ log_root = getLogger(None).logger
+ for handler in log_root.handlers:
+ log_root.removeHandler(handler)
+
+ if CONF.use_syslog:
+ facility = _find_facility_from_conf()
+ syslog = logging.handlers.SysLogHandler(address='/dev/log',
+ facility=facility)
+ log_root.addHandler(syslog)
+
+ logpath = _get_log_file_path()
+ if logpath:
+ filelog = logging.handlers.WatchedFileHandler(logpath)
+ log_root.addHandler(filelog)
+
+ if CONF.use_stderr:
+ streamlog = ColorHandler()
+ log_root.addHandler(streamlog)
+
+ elif not CONF.log_file:
+ # pass sys.stdout as a positional argument
+ # python2.6 calls the argument strm, in 2.7 it's stream
+ streamlog = logging.StreamHandler(sys.stdout)
+ log_root.addHandler(streamlog)
+
+ if CONF.publish_errors:
+ handler = importutils.import_object(
+ "keystone.openstack.common.log_handler.PublishErrorsHandler",
+ logging.ERROR)
+ log_root.addHandler(handler)
+
+ datefmt = CONF.log_date_format
+ for handler in log_root.handlers:
+ # NOTE(alaski): CONF.log_format overrides everything currently. This
+ # should be deprecated in favor of context aware formatting.
+ if CONF.log_format:
+ handler.setFormatter(logging.Formatter(fmt=CONF.log_format,
+ datefmt=datefmt))
+ log_root.info('Deprecated: log_format is now deprecated and will '
+ 'be removed in the next release')
+ else:
+ handler.setFormatter(ContextFormatter(datefmt=datefmt))
+
+ if CONF.debug:
+ log_root.setLevel(logging.DEBUG)
+ elif CONF.verbose:
+ log_root.setLevel(logging.INFO)
+ else:
+ log_root.setLevel(logging.WARNING)
+
+ for pair in CONF.default_log_levels:
+ mod, _sep, level_name = pair.partition('=')
+ level = logging.getLevelName(level_name)
+ logger = logging.getLogger(mod)
+ logger.setLevel(level)
+
+_loggers = {}
+
+
+def getLogger(name='unknown', version='unknown'):
+ if name not in _loggers:
+ _loggers[name] = ContextAdapter(logging.getLogger(name),
+ name,
+ version)
+ return _loggers[name]
+
+
+def getLazyLogger(name='unknown', version='unknown'):
+ """Returns lazy logger.
+
+ Creates a pass-through logger that does not create the real logger
+ until it is really needed and delegates all calls to the real logger
+ once it is created.
+ """
+ return LazyAdapter(name, version)
+
+
+class WritableLogger(object):
+ """A thin wrapper that responds to `write` and logs."""
+
+ def __init__(self, logger, level=logging.INFO):
+ self.logger = logger
+ self.level = level
+
+ def write(self, msg):
+ self.logger.log(self.level, msg)
+
+
+class ContextFormatter(logging.Formatter):
+ """A context.RequestContext aware formatter configured through flags.
+
+ The flags used to set format strings are: logging_context_format_string
+ and logging_default_format_string. You can also specify
+ logging_debug_format_suffix to append extra formatting if the log level is
+ debug.
+
+ For information about what variables are available for the formatter see:
+ http://docs.python.org/library/logging.html#formatter
+
+ """
+
+ def format(self, record):
+ """Uses contextstring if request_id is set, otherwise default."""
+ # NOTE(sdague): default the fancier formating params
+ # to an empty string so we don't throw an exception if
+ # they get used
+ for key in ('instance', 'color'):
+ if key not in record.__dict__:
+ record.__dict__[key] = ''
+
+ if record.__dict__.get('request_id', None):
+ self._fmt = CONF.logging_context_format_string
+ else:
+ self._fmt = CONF.logging_default_format_string
+
+ if (record.levelno == logging.DEBUG and
+ CONF.logging_debug_format_suffix):
+ self._fmt += " " + CONF.logging_debug_format_suffix
+
+ # Cache this on the record, Logger will respect our formatted copy
+ if record.exc_info:
+ record.exc_text = self.formatException(record.exc_info, record)
+ return logging.Formatter.format(self, record)
+
+ def formatException(self, exc_info, record=None):
+ """Format exception output with CONF.logging_exception_prefix."""
+ if not record:
+ return logging.Formatter.formatException(self, exc_info)
+
+ stringbuffer = moves.StringIO()
+ traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
+ None, stringbuffer)
+ lines = stringbuffer.getvalue().split('\n')
+ stringbuffer.close()
+
+ if CONF.logging_exception_prefix.find('%(asctime)') != -1:
+ record.asctime = self.formatTime(record, self.datefmt)
+
+ formatted_lines = []
+ for line in lines:
+ pl = CONF.logging_exception_prefix % record.__dict__
+ fl = '%s%s' % (pl, line)
+ formatted_lines.append(fl)
+ return '\n'.join(formatted_lines)
+
+
+class ColorHandler(logging.StreamHandler):
+ LEVEL_COLORS = {
+ logging.DEBUG: '\033[00;32m', # GREEN
+ logging.INFO: '\033[00;36m', # CYAN
+ logging.AUDIT: '\033[01;36m', # BOLD CYAN
+ logging.WARN: '\033[01;33m', # BOLD YELLOW
+ logging.ERROR: '\033[01;31m', # BOLD RED
+ logging.CRITICAL: '\033[01;31m', # BOLD RED
+ }
+
+ def format(self, record):
+ record.color = self.LEVEL_COLORS[record.levelno]
+ return logging.StreamHandler.format(self, record)
+
+
+class DeprecatedConfig(Exception):
+ message = _("Fatal call to deprecated config: %(msg)s")
+
+ def __init__(self, msg):
+ super(Exception, self).__init__(self.message % dict(msg=msg))
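A minimal, stdlib-only sketch of the format-switching idea used by ContextFormatter above: pick a context-aware format string when a record carries a request_id, and a default one otherwise. The two format strings below are illustrative, not the keystone.conf defaults.

    import logging

    CONTEXT_FMT = '%(asctime)s %(levelname)s [%(request_id)s] %(name)s: %(message)s'
    DEFAULT_FMT = '%(asctime)s %(levelname)s %(name)s: %(message)s'

    class RequestIdFormatter(logging.Formatter):
        """Pick a format string based on the presence of request_id."""
        _context = logging.Formatter(CONTEXT_FMT)
        _default = logging.Formatter(DEFAULT_FMT)

        def format(self, record):
            if getattr(record, 'request_id', None):
                return self._context.format(record)
            return self._default.format(record)

    handler = logging.StreamHandler()
    handler.setFormatter(RequestIdFormatter())
    log = logging.getLogger('demo')
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    log.info('no context attached')                              # default format
    log.info('context attached', extra={'request_id': 'req-1'})  # context format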
diff --git a/keystone/openstack/common/loopingcall.py b/keystone/openstack/common/loopingcall.py
new file mode 100644
index 00000000..0801db09
--- /dev/null
+++ b/keystone/openstack/common/loopingcall.py
@@ -0,0 +1,147 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2011 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+from eventlet import event
+from eventlet import greenthread
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import timeutils
+
+LOG = logging.getLogger(__name__)
+
+
+class LoopingCallDone(Exception):
+ """Exception to break out and stop a LoopingCall.
+
+ The poll-function passed to LoopingCall can raise this exception to
+ break out of the loop normally. This is somewhat analogous to
+ StopIteration.
+
+ An optional return-value can be included as the argument to the exception;
+ this return-value will be returned by LoopingCall.wait()
+
+ """
+
+ def __init__(self, retvalue=True):
+ """:param retvalue: Value that LoopingCall.wait() should return."""
+ self.retvalue = retvalue
+
+
+class LoopingCallBase(object):
+ def __init__(self, f=None, *args, **kw):
+ self.args = args
+ self.kw = kw
+ self.f = f
+ self._running = False
+ self.done = None
+
+ def stop(self):
+ self._running = False
+
+ def wait(self):
+ return self.done.wait()
+
+
+class FixedIntervalLoopingCall(LoopingCallBase):
+ """A fixed interval looping call."""
+
+ def start(self, interval, initial_delay=None):
+ self._running = True
+ done = event.Event()
+
+ def _inner():
+ if initial_delay:
+ greenthread.sleep(initial_delay)
+
+ try:
+ while self._running:
+ start = timeutils.utcnow()
+ self.f(*self.args, **self.kw)
+ end = timeutils.utcnow()
+ if not self._running:
+ break
+ delay = interval - timeutils.delta_seconds(start, end)
+ if delay <= 0:
+ LOG.warn(_('task run outlasted interval by %s sec') %
+ -delay)
+ greenthread.sleep(delay if delay > 0 else 0)
+ except LoopingCallDone as e:
+ self.stop()
+ done.send(e.retvalue)
+ except Exception:
+ LOG.exception(_('in fixed duration looping call'))
+ done.send_exception(*sys.exc_info())
+ return
+ else:
+ done.send(True)
+
+ self.done = done
+
+ greenthread.spawn_n(_inner)
+ return self.done
+
+
+# TODO(mikal): this class name is deprecated in Havana and should be removed
+# in the I release
+LoopingCall = FixedIntervalLoopingCall
+
+
+class DynamicLoopingCall(LoopingCallBase):
+ """A looping call which sleeps until the next known event.
+
+ The function called should return how long to sleep for before being
+ called again.
+ """
+
+ def start(self, initial_delay=None, periodic_interval_max=None):
+ self._running = True
+ done = event.Event()
+
+ def _inner():
+ if initial_delay:
+ greenthread.sleep(initial_delay)
+
+ try:
+ while self._running:
+ idle = self.f(*self.args, **self.kw)
+ if not self._running:
+ break
+
+ if periodic_interval_max is not None:
+ idle = min(idle, periodic_interval_max)
+ LOG.debug(_('Dynamic looping call sleeping for %.02f '
+ 'seconds'), idle)
+ greenthread.sleep(idle)
+ except LoopingCallDone as e:
+ self.stop()
+ done.send(e.retvalue)
+ except Exception:
+ LOG.exception(_('in dynamic looping call'))
+ done.send_exception(*sys.exc_info())
+ return
+ else:
+ done.send(True)
+
+ self.done = done
+
+ greenthread.spawn(_inner)
+ return self.done
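A usage sketch for the classes above, assuming eventlet is installed and this tree is importable; the polling function and its stop condition are made up for illustration.

    from keystone.openstack.common import loopingcall

    state = {'tries': 0}

    def _poll():
        # Raise LoopingCallDone to stop the loop; wait() returns retvalue.
        state['tries'] += 1
        if state['tries'] >= 3:
            raise loopingcall.LoopingCallDone(retvalue=state['tries'])

    timer = loopingcall.FixedIntervalLoopingCall(_poll)
    timer.start(interval=0.1)   # runs _poll every 0.1 seconds in a greenthread
    print(timer.wait())         # blocks until LoopingCallDone; prints 3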
diff --git a/keystone/openstack/common/network_utils.py b/keystone/openstack/common/network_utils.py
new file mode 100644
index 00000000..dbed1ceb
--- /dev/null
+++ b/keystone/openstack/common/network_utils.py
@@ -0,0 +1,81 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Network-related utilities and helper functions.
+"""
+
+import urlparse
+
+
+def parse_host_port(address, default_port=None):
+ """Interpret a string as a host:port pair.
+
+ An IPv6 address MUST be escaped if accompanied by a port,
+ because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334
+ means both [2001:db8:85a3::8a2e:370:7334] and
+ [2001:db8:85a3::8a2e:370]:7334.
+
+ >>> parse_host_port('server01:80')
+ ('server01', 80)
+ >>> parse_host_port('server01')
+ ('server01', None)
+ >>> parse_host_port('server01', default_port=1234)
+ ('server01', 1234)
+ >>> parse_host_port('[::1]:80')
+ ('::1', 80)
+ >>> parse_host_port('[::1]')
+ ('::1', None)
+ >>> parse_host_port('[::1]', default_port=1234)
+ ('::1', 1234)
+ >>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234)
+ ('2001:db8:85a3::8a2e:370:7334', 1234)
+
+ """
+ if address[0] == '[':
+ # Escaped ipv6
+ _host, _port = address[1:].split(']')
+ host = _host
+ if ':' in _port:
+ port = _port.split(':')[1]
+ else:
+ port = default_port
+ else:
+ if address.count(':') == 1:
+ host, port = address.split(':')
+ else:
+ # 0 means ipv4, >1 means ipv6.
+ # We prohibit unescaped ipv6 addresses with port.
+ host = address
+ port = default_port
+
+ return (host, None if port is None else int(port))
+
+
+def urlsplit(url, scheme='', allow_fragments=True):
+ """Parse a URL using urlparse.urlsplit(), splitting query and fragments.
+ This function papers over Python issue9374 when needed.
+
+ The parameters are the same as urlparse.urlsplit.
+ """
+ scheme, netloc, path, query, fragment = urlparse.urlsplit(
+ url, scheme, allow_fragments)
+ if allow_fragments and '#' in path:
+ path, fragment = path.split('#', 1)
+ if '?' in path:
+ path, query = path.split('?', 1)
+ return urlparse.SplitResult(scheme, netloc, path, query, fragment)
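A hedged illustration of the urlsplit() wrapper above: on Python versions affected by issue9374 the stdlib leaves the query and fragment attached to the path for unrecognized schemes, while the wrapper splits them out either way. The qpid:// URL is only an example.

    from keystone.openstack.common import network_utils

    parts = network_utils.urlsplit('qpid://broker:5672/path?query=1#frag')
    print(parts.netloc)    # 'broker:5672'
    print(parts.path)      # '/path'
    print(parts.query)     # 'query=1'
    print(parts.fragment)  # 'frag'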
diff --git a/keystone/openstack/common/notifier/__init__.py b/keystone/openstack/common/notifier/__init__.py
new file mode 100644
index 00000000..45c3b46a
--- /dev/null
+++ b/keystone/openstack/common/notifier/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
diff --git a/keystone/openstack/common/notifier/api.py b/keystone/openstack/common/notifier/api.py
new file mode 100644
index 00000000..51eb7eae
--- /dev/null
+++ b/keystone/openstack/common/notifier/api.py
@@ -0,0 +1,173 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import socket
+import uuid
+
+from oslo.config import cfg
+
+from keystone.openstack.common import context
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+notifier_opts = [
+ cfg.MultiStrOpt('notification_driver',
+ default=[],
+ help='Driver or drivers to handle sending notifications'),
+ cfg.StrOpt('default_notification_level',
+ default='INFO',
+ help='Default notification level for outgoing notifications'),
+ cfg.StrOpt('default_publisher_id',
+ default=None,
+ help='Default publisher_id for outgoing notifications'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(notifier_opts)
+
+WARN = 'WARN'
+INFO = 'INFO'
+ERROR = 'ERROR'
+CRITICAL = 'CRITICAL'
+DEBUG = 'DEBUG'
+
+log_levels = (DEBUG, WARN, INFO, ERROR, CRITICAL)
+
+
+class BadPriorityException(Exception):
+ pass
+
+
+def notify_decorator(name, fn):
+ """Decorator for notify which is used from utils.monkey_patch().
+
+ :param name: name of the function
+ :param fn: the function object being decorated
+ :returns: function -- decorated function
+
+ """
+ def wrapped_func(*args, **kwarg):
+ body = {}
+ body['args'] = []
+ body['kwarg'] = {}
+ for arg in args:
+ body['args'].append(arg)
+ for key in kwarg:
+ body['kwarg'][key] = kwarg[key]
+
+ ctxt = context.get_context_from_function_and_args(fn, args, kwarg)
+ notify(ctxt,
+ CONF.default_publisher_id or socket.gethostname(),
+ name,
+ CONF.default_notification_level,
+ body)
+ return fn(*args, **kwarg)
+ return wrapped_func
+
+
+def publisher_id(service, host=None):
+ if not host:
+ try:
+ host = CONF.host
+ except AttributeError:
+ host = CONF.default_publisher_id or socket.gethostname()
+ return "%s.%s" % (service, host)
+
+
+def notify(context, publisher_id, event_type, priority, payload):
+ """Sends a notification using the specified driver
+
+ :param publisher_id: the source worker_type.host of the message
+ :param event_type: the literal type of event (ex. Instance Creation)
+ :param priority: patterned after the enumeration of Python logging
+ levels in the set (DEBUG, WARN, INFO, ERROR, CRITICAL)
+ :param payload: A python dictionary of attributes
+
+ Outgoing message format includes the above parameters, and appends the
+ following:
+
+ message_id
+ a UUID representing the id for this notification
+
+ timestamp
+ the GMT timestamp the notification was sent at
+
+ The composite message will be constructed as a dictionary of the above
+ attributes, which will then be sent via the transport mechanism defined
+ by the driver.
+
+ Message example::
+
+ {'message_id': str(uuid.uuid4()),
+ 'publisher_id': 'compute.host1',
+ 'timestamp': timeutils.utcnow(),
+ 'priority': 'WARN',
+ 'event_type': 'compute.create_instance',
+ 'payload': {'instance_id': 12, ... }}
+
+ """
+ if priority not in log_levels:
+ raise BadPriorityException(
+ _('%s not in valid priorities') % priority)
+
+ # Ensure everything is JSON serializable.
+ payload = jsonutils.to_primitive(payload, convert_instances=True)
+
+ msg = dict(message_id=str(uuid.uuid4()),
+ publisher_id=publisher_id,
+ event_type=event_type,
+ priority=priority,
+ payload=payload,
+ timestamp=str(timeutils.utcnow()))
+
+ for driver in _get_drivers():
+ try:
+ driver.notify(context, msg)
+ except Exception as e:
+ LOG.exception(_("Problem '%(e)s' attempting to "
+ "send to notification system. "
+ "Payload=%(payload)s")
+ % dict(e=e, payload=payload))
+
+
+_drivers = None
+
+
+def _get_drivers():
+ """Instantiate, cache, and return drivers based on the CONF."""
+ global _drivers
+ if _drivers is None:
+ _drivers = {}
+ for notification_driver in CONF.notification_driver:
+ try:
+ driver = importutils.import_module(notification_driver)
+ _drivers[notification_driver] = driver
+ except ImportError:
+ LOG.exception(_("Failed to load notifier %s. "
+ "These notifications will not be sent.") %
+ notification_driver)
+ return _drivers.values()
+
+
+def _reset_drivers():
+ """Used by unit tests to reset the drivers."""
+ global _drivers
+ _drivers = None
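A usage sketch for notify() above; it assumes this tree is importable and that a notification_driver has been configured. The 'identity' service name, the event type, and the payload are made-up examples.

    from keystone.openstack.common.notifier import api as notifier_api

    notifier_api.notify(None,                                   # context
                        notifier_api.publisher_id('identity'),  # typically 'identity.<hostname>'
                        'identity.user.created',                # event_type (illustrative)
                        notifier_api.INFO,                      # priority
                        {'user_id': '1234'})                    # payload (illustrative)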
diff --git a/keystone/openstack/common/notifier/log_notifier.py b/keystone/openstack/common/notifier/log_notifier.py
new file mode 100644
index 00000000..fcf1f98e
--- /dev/null
+++ b/keystone/openstack/common/notifier/log_notifier.py
@@ -0,0 +1,37 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from oslo.config import cfg
+
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
+
+
+CONF = cfg.CONF
+
+
+def notify(_context, message):
+ """Notifies the recipient of the desired event given the model.
+
+ Log notifications using OpenStack's default logging system.
+ """
+
+ priority = message.get('priority',
+ CONF.default_notification_level)
+ priority = priority.lower()
+ logger = logging.getLogger(
+ 'keystone.openstack.common.notification.%s' %
+ message['event_type'])
+ getattr(logger, priority)(jsonutils.dumps(message))
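A sketch of how the driver above behaves, assuming this tree is importable; the event type is made up. The message is logged, JSON-encoded, by a logger named after the event type, at the message's own priority.

    from keystone.openstack.common.notifier import log_notifier

    log_notifier.notify(None, {'event_type': 'identity.user.created',
                               'priority': 'INFO',
                               'payload': {'user_id': '1234'}})
    # handled by 'keystone.openstack.common.notification.identity.user.created'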
diff --git a/keystone/openstack/common/notifier/no_op_notifier.py b/keystone/openstack/common/notifier/no_op_notifier.py
new file mode 100644
index 00000000..13d946e3
--- /dev/null
+++ b/keystone/openstack/common/notifier/no_op_notifier.py
@@ -0,0 +1,19 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+def notify(_context, message):
+ """Notifies the recipient of the desired event given the model."""
+ pass
diff --git a/keystone/openstack/common/notifier/rpc_notifier.py b/keystone/openstack/common/notifier/rpc_notifier.py
new file mode 100644
index 00000000..dad3bef5
--- /dev/null
+++ b/keystone/openstack/common/notifier/rpc_notifier.py
@@ -0,0 +1,46 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from oslo.config import cfg
+
+from keystone.openstack.common import context as req_context
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import rpc
+
+LOG = logging.getLogger(__name__)
+
+notification_topic_opt = cfg.ListOpt(
+ 'notification_topics', default=['notifications', ],
+ help='AMQP topic used for OpenStack notifications')
+
+CONF = cfg.CONF
+CONF.register_opt(notification_topic_opt)
+
+
+def notify(context, message):
+ """Sends a notification via RPC."""
+ if not context:
+ context = req_context.get_admin_context()
+ priority = message.get('priority',
+ CONF.default_notification_level)
+ priority = priority.lower()
+ for topic in CONF.notification_topics:
+ topic = '%s.%s' % (topic, priority)
+ try:
+ rpc.notify(context, topic, message)
+ except Exception:
+ LOG.exception(_("Could not send notification to %(topic)s. "
+ "Payload=%(message)s"), locals())
diff --git a/keystone/openstack/common/notifier/rpc_notifier2.py b/keystone/openstack/common/notifier/rpc_notifier2.py
new file mode 100644
index 00000000..7b77bf12
--- /dev/null
+++ b/keystone/openstack/common/notifier/rpc_notifier2.py
@@ -0,0 +1,52 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+'''messaging based notification driver, with message envelopes'''
+
+from oslo.config import cfg
+
+from keystone.openstack.common import context as req_context
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import rpc
+
+LOG = logging.getLogger(__name__)
+
+notification_topic_opt = cfg.ListOpt(
+ 'topics', default=['notifications', ],
+ help='AMQP topic(s) used for OpenStack notifications')
+
+opt_group = cfg.OptGroup(name='rpc_notifier2',
+ title='Options for rpc_notifier2')
+
+CONF = cfg.CONF
+CONF.register_group(opt_group)
+CONF.register_opt(notification_topic_opt, opt_group)
+
+
+def notify(context, message):
+ """Sends a notification via RPC."""
+ if not context:
+ context = req_context.get_admin_context()
+ priority = message.get('priority',
+ CONF.default_notification_level)
+ priority = priority.lower()
+ for topic in CONF.rpc_notifier2.topics:
+ topic = '%s.%s' % (topic, priority)
+ try:
+ rpc.notify(context, topic, message, envelope=True)
+ except Exception:
+ LOG.exception(_("Could not send notification to %(topic)s. "
+ "Payload=%(message)s"), locals())
diff --git a/keystone/openstack/common/notifier/test_notifier.py b/keystone/openstack/common/notifier/test_notifier.py
new file mode 100644
index 00000000..96c1746b
--- /dev/null
+++ b/keystone/openstack/common/notifier/test_notifier.py
@@ -0,0 +1,22 @@
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+NOTIFICATIONS = []
+
+
+def notify(_context, message):
+ """Test notifier, stores notifications in memory for unittests."""
+ NOTIFICATIONS.append(message)
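A sketch of how a test might use the in-memory notifier above, assuming this tree and oslo.config are importable; set_override() is the standard oslo.config way to change an option at runtime.

    from oslo.config import cfg

    from keystone.openstack.common.notifier import api as notifier_api
    from keystone.openstack.common.notifier import test_notifier

    cfg.CONF.set_override(
        'notification_driver',
        ['keystone.openstack.common.notifier.test_notifier'])
    notifier_api._reset_drivers()   # drop any previously cached drivers

    notifier_api.notify(None, 'test.host', 'test.event',
                        notifier_api.INFO, {'key': 'value'})

    assert len(test_notifier.NOTIFICATIONS) == 1
    assert test_notifier.NOTIFICATIONS[0]['event_type'] == 'test.event'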
diff --git a/keystone/openstack/common/rpc/__init__.py b/keystone/openstack/common/rpc/__init__.py
new file mode 100644
index 00000000..248a7458
--- /dev/null
+++ b/keystone/openstack/common/rpc/__init__.py
@@ -0,0 +1,307 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A remote procedure call (rpc) abstraction.
+
+For some wrappers that add message versioning to rpc, see:
+ rpc.dispatcher
+ rpc.proxy
+"""
+
+import inspect
+
+from oslo.config import cfg
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import local
+from keystone.openstack.common import log as logging
+
+
+LOG = logging.getLogger(__name__)
+
+
+rpc_opts = [
+ cfg.StrOpt('rpc_backend',
+ default='%s.impl_kombu' % __package__,
+ help="The messaging module to use, defaults to kombu."),
+ cfg.IntOpt('rpc_thread_pool_size',
+ default=64,
+ help='Size of RPC thread pool'),
+ cfg.IntOpt('rpc_conn_pool_size',
+ default=30,
+ help='Size of RPC connection pool'),
+ cfg.IntOpt('rpc_response_timeout',
+ default=60,
+ help='Seconds to wait for a response from call or multicall'),
+ cfg.IntOpt('rpc_cast_timeout',
+ default=30,
+ help='Seconds to wait before a cast expires (TTL). '
+ 'Only supported by impl_zmq.'),
+ cfg.ListOpt('allowed_rpc_exception_modules',
+ default=['keystone.openstack.common.exception',
+ 'nova.exception',
+ 'cinder.exception',
+ 'exceptions',
+ ],
+ help='Modules of exceptions that are permitted to be recreated '
+ 'upon receiving exception data from an rpc call.'),
+ cfg.BoolOpt('fake_rabbit',
+ default=False,
+ help='If passed, use a fake RabbitMQ provider'),
+ cfg.StrOpt('control_exchange',
+ default='openstack',
+ help='AMQP exchange to connect to if using RabbitMQ or Qpid'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(rpc_opts)
+
+
+def set_defaults(control_exchange):
+ cfg.set_defaults(rpc_opts,
+ control_exchange=control_exchange)
+
+
+def create_connection(new=True):
+ """Create a connection to the message bus used for rpc.
+
+ For some example usage of creating a connection and some consumers on that
+ connection, see nova.service.
+
+ :param new: Whether or not to create a new connection. A new connection
+ will be created by default. If new is False, the
+ implementation is free to return an existing connection from a
+ pool.
+
+ :returns: An instance of openstack.common.rpc.common.Connection
+ """
+ return _get_impl().create_connection(CONF, new=new)
+
+
+def _check_for_lock():
+ if not CONF.debug:
+ return None
+
+ if ((hasattr(local.strong_store, 'locks_held')
+ and local.strong_store.locks_held)):
+ stack = ' :: '.join([frame[3] for frame in inspect.stack()])
+ LOG.warn(_('An RPC is being made while holding a lock. The locks '
+ 'currently held are %(locks)s. This is probably a bug. '
+ 'Please report it. Include the following: [%(stack)s].'),
+ {'locks': local.strong_store.locks_held,
+ 'stack': stack})
+ return True
+
+ return False
+
+
+def call(context, topic, msg, timeout=None, check_for_lock=False):
+ """Invoke a remote method that returns something.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ openstack.common.rpc.common.Connection.create_consumer()
+ and only applies when the consumer was created with
+ fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+ :param timeout: int, number of seconds to use for a response timeout.
+ If set, this overrides the rpc_response_timeout option.
+ :param check_for_lock: if True, a warning is emitted if an RPC call is made
+ with a lock held.
+
+ :returns: A dict from the remote method.
+
+ :raises: openstack.common.rpc.common.Timeout if a complete response
+ is not received before the timeout is reached.
+ """
+ if check_for_lock:
+ _check_for_lock()
+ return _get_impl().call(CONF, context, topic, msg, timeout)
+
+
+def cast(context, topic, msg):
+ """Invoke a remote method that does not return anything.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ openstack.common.rpc.common.Connection.create_consumer()
+ and only applies when the consumer was created with
+ fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().cast(CONF, context, topic, msg)
+
+
+def fanout_cast(context, topic, msg):
+ """Broadcast a remote method invocation with no return.
+
+ This method will get invoked on all consumers that were set up with this
+ topic name and fanout=True.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ openstack.common.rpc.common.Connection.create_consumer()
+ and only applies when the consumer was created with
+ fanout=True.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().fanout_cast(CONF, context, topic, msg)
+
+
+def multicall(context, topic, msg, timeout=None, check_for_lock=False):
+ """Invoke a remote method and get back an iterator.
+
+ In this case, the remote method will be returning multiple values in
+ separate messages, so the return values can be processed as they come in via
+ an iterator.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the rpc message to. This correlates to the
+ topic argument of
+ openstack.common.rpc.common.Connection.create_consumer()
+ and only applies when the consumer was created with
+ fanout=False.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+ :param timeout: int, number of seconds to use for a response timeout.
+ If set, this overrides the rpc_response_timeout option.
+ :param check_for_lock: if True, a warning is emitted if an RPC call is made
+ with a lock held.
+
+ :returns: An iterator. The iterator will yield a tuple (N, X) where N is
+ an index that starts at 0 and increases by one for each value
+ returned and X is the Nth value that was returned by the remote
+ method.
+
+ :raises: openstack.common.rpc.common.Timeout if a complete response
+ is not received before the timeout is reached.
+ """
+ if check_for_lock:
+ _check_for_lock()
+ return _get_impl().multicall(CONF, context, topic, msg, timeout)
+
+
+def notify(context, topic, msg, envelope=False):
+ """Send notification event.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict of content of event.
+ :param envelope: Set to True to enable message envelope for notifications.
+
+ :returns: None
+ """
+ return _get_impl().notify(cfg.CONF, context, topic, msg, envelope)
+
+
+def cleanup():
+ """Clean up resoruces in use by implementation.
+
+ Clean up any resources that have been allocated by the RPC implementation.
+ This is typically open connections to a messaging service. This function
+ would get called before an application using this API exits to allow
+ connections to get torn down cleanly.
+
+ :returns: None
+ """
+ return _get_impl().cleanup()
+
+
+def cast_to_server(context, server_params, topic, msg):
+ """Invoke a remote method that does not return anything.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param server_params: Connection information
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().cast_to_server(CONF, context, server_params, topic,
+ msg)
+
+
+def fanout_cast_to_server(context, server_params, topic, msg):
+ """Broadcast to a remote method invocation with no return.
+
+ :param context: Information that identifies the user that has made this
+ request.
+ :param server_params: Connection information
+ :param topic: The topic to send the notification to.
+ :param msg: This is a dict in the form { "method" : "method_to_invoke",
+ "args" : dict_of_kwargs }
+
+ :returns: None
+ """
+ return _get_impl().fanout_cast_to_server(CONF, context, server_params,
+ topic, msg)
+
+
+def queue_get_for(context, topic, host):
+ """Get a queue name for a given topic + host.
+
+ This function only works if this naming convention is followed on the
+ consumer side, as well. For example, in nova, every instance of the
+ nova-foo service calls create_consumer() for two topics:
+
+ foo
+ foo.<host>
+
+ Messages sent to the 'foo' topic are distributed to exactly one instance of
+ the nova-foo service. The services are chosen in a round-robin fashion.
+ Messages sent to the 'foo.<host>' topic are sent to the nova-foo service on
+ <host>.
+ """
+ return '%s.%s' % (topic, host) if host else topic
+
+
+_RPCIMPL = None
+
+
+def _get_impl():
+ """Delay import of rpc_backend until configuration is loaded."""
+ global _RPCIMPL
+ if _RPCIMPL is None:
+ try:
+ _RPCIMPL = importutils.import_module(CONF.rpc_backend)
+ except ImportError:
+ # For backwards compatibility with older nova config.
+ impl = CONF.rpc_backend.replace('nova.rpc',
+ 'nova.openstack.common.rpc')
+ _RPCIMPL = importutils.import_module(impl)
+ return _RPCIMPL
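A hedged usage sketch of the public helpers above. msg is always a dict of the form {'method': ..., 'args': {...}}; the 'compute' topic and the 'echo' method are made up, and a real caller would normally pass a RequestContext rather than a plain dict.

    from keystone.openstack.common import rpc

    ctxt = {}   # stand-in for a RequestContext
    msg = {'method': 'echo', 'args': {'value': 42}}

    rpc.cast(ctxt, 'compute', msg)                       # fire and forget
    result = rpc.call(ctxt, 'compute', msg, timeout=10)  # wait for the reply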
diff --git a/keystone/openstack/common/rpc/amqp.py b/keystone/openstack/common/rpc/amqp.py
new file mode 100644
index 00000000..3bcedbdb
--- /dev/null
+++ b/keystone/openstack/common/rpc/amqp.py
@@ -0,0 +1,615 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 - 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Shared code between AMQP based openstack.common.rpc implementations.
+
+The code in this module is shared between the rpc implementations based on AMQP.
+Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses
+AMQP, but is deprecated and predates this code.
+"""
+
+import collections
+import inspect
+import sys
+import uuid
+
+from eventlet import greenpool
+from eventlet import pools
+from eventlet import queue
+from eventlet import semaphore
+from oslo.config import cfg
+
+from keystone.openstack.common import excutils
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import local
+from keystone.openstack.common import log as logging
+from keystone.openstack.common.rpc import common as rpc_common
+
+
+amqp_opts = [
+ cfg.BoolOpt('amqp_durable_queues',
+ default=False,
+ deprecated_name='rabbit_durable_queues',
+ deprecated_group='DEFAULT',
+ help='Use durable queues in amqp.'),
+ cfg.BoolOpt('amqp_auto_delete',
+ default=False,
+ help='Auto-delete queues in amqp.'),
+]
+
+cfg.CONF.register_opts(amqp_opts)
+
+UNIQUE_ID = '_unique_id'
+LOG = logging.getLogger(__name__)
+
+
+class Pool(pools.Pool):
+ """Class that implements a Pool of Connections."""
+ def __init__(self, conf, connection_cls, *args, **kwargs):
+ self.connection_cls = connection_cls
+ self.conf = conf
+ kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size)
+ kwargs.setdefault("order_as_stack", True)
+ super(Pool, self).__init__(*args, **kwargs)
+ self.reply_proxy = None
+
+ # TODO(comstud): Timeout connections not used in a while
+ def create(self):
+ LOG.debug(_('Pool creating new connection'))
+ return self.connection_cls(self.conf)
+
+ def empty(self):
+ while self.free_items:
+ self.get().close()
+ # Force a new connection pool to be created.
+ # Note that this was added due to failing unit test cases. The issue
+ # is the above "while loop" gets all the cached connections from the
+ # pool and closes them, but never returns them to the pool, a pool
+ # leak. The unit tests hang waiting for an item to be returned to the
+ # pool. The unit tests get here via the tearDown() method. In the run
+ # time code, it gets here via cleanup() and only appears in service.py
+ # just before doing a sys.exit(), so cleanup() only happens once and
+ # the leakage is not a problem.
+ self.connection_cls.pool = None
+
+
+_pool_create_sem = semaphore.Semaphore()
+
+
+def get_connection_pool(conf, connection_cls):
+ with _pool_create_sem:
+ # Make sure only one thread tries to create the connection pool.
+ if not connection_cls.pool:
+ connection_cls.pool = Pool(conf, connection_cls)
+ return connection_cls.pool
+
+
+class ConnectionContext(rpc_common.Connection):
+ """The class that is actually returned to the create_connection() caller.
+
+ This is essentially a wrapper around Connection that supports 'with'.
+ It can also return a new Connection, or one from a pool.
+
+ The class also catches the case where an instance of this class is about to
+ be deleted. With that we can return Connections to the pool on exceptions
+ and so forth without making the caller responsible for catching them.
+ If possible the class makes sure to return a connection to the pool.
+ """
+
+ def __init__(self, conf, connection_pool, pooled=True, server_params=None):
+ """Create a new connection, or get one from the pool."""
+ self.connection = None
+ self.conf = conf
+ self.connection_pool = connection_pool
+ if pooled:
+ self.connection = connection_pool.get()
+ else:
+ self.connection = connection_pool.connection_cls(
+ conf,
+ server_params=server_params)
+ self.pooled = pooled
+
+ def __enter__(self):
+ """When with ConnectionContext() is used, return self."""
+ return self
+
+ def _done(self):
+ """If the connection came from a pool, clean it up and put it back.
+ If it did not come from a pool, close it.
+ """
+ if self.connection:
+ if self.pooled:
+ # Reset the connection so it's ready for the next caller
+ # to grab from the pool
+ self.connection.reset()
+ self.connection_pool.put(self.connection)
+ else:
+ try:
+ self.connection.close()
+ except Exception:
+ pass
+ self.connection = None
+
+ def __exit__(self, exc_type, exc_value, tb):
+ """End of 'with' statement. We're done here."""
+ self._done()
+
+ def __del__(self):
+ """Caller is done with this connection. Make sure we cleaned up."""
+ self._done()
+
+ def close(self):
+ """Caller is done with this connection."""
+ self._done()
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ self.connection.create_consumer(topic, proxy, fanout)
+
+ def create_worker(self, topic, proxy, pool_name):
+ self.connection.create_worker(topic, proxy, pool_name)
+
+ def join_consumer_pool(self, callback, pool_name, topic, exchange_name,
+ ack_on_error=True):
+ self.connection.join_consumer_pool(callback,
+ pool_name,
+ topic,
+ exchange_name,
+ ack_on_error)
+
+ def consume_in_thread(self):
+ self.connection.consume_in_thread()
+
+ def __getattr__(self, key):
+ """Proxy all other calls to the Connection instance."""
+ if self.connection:
+ return getattr(self.connection, key)
+ else:
+ raise rpc_common.InvalidRPCConnectionReuse()
+
+
+class ReplyProxy(ConnectionContext):
+ """Connection class for RPC replies / callbacks."""
+ def __init__(self, conf, connection_pool):
+ self._call_waiters = {}
+ self._num_call_waiters = 0
+ self._num_call_waiters_wrn_threshhold = 10
+ self._reply_q = 'reply_' + uuid.uuid4().hex
+ super(ReplyProxy, self).__init__(conf, connection_pool, pooled=False)
+ self.declare_direct_consumer(self._reply_q, self._process_data)
+ self.consume_in_thread()
+
+ def _process_data(self, message_data):
+ msg_id = message_data.pop('_msg_id', None)
+ waiter = self._call_waiters.get(msg_id)
+ if not waiter:
+ LOG.warn(_('No calling threads waiting for msg_id : %(msg_id)s'
+ ', message : %(data)s'), {'msg_id': msg_id,
+ 'data': message_data})
+ LOG.warn(_('_call_waiters: %s') % str(self._call_waiters))
+ else:
+ waiter.put(message_data)
+
+ def add_call_waiter(self, waiter, msg_id):
+ self._num_call_waiters += 1
+ if self._num_call_waiters > self._num_call_waiters_wrn_threshhold:
+ LOG.warn(_('Number of call waiters is greater than warning '
+ 'threshold: %d. There could be a MulticallProxyWaiter '
+ 'leak.') % self._num_call_waiters_wrn_threshhold)
+ self._num_call_waiters_wrn_threshhold *= 2
+ self._call_waiters[msg_id] = waiter
+
+ def del_call_waiter(self, msg_id):
+ self._num_call_waiters -= 1
+ del self._call_waiters[msg_id]
+
+ def get_reply_q(self):
+ return self._reply_q
+
+
+def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None,
+ failure=None, ending=False, log_failure=True):
+ """Sends a reply or an error on the channel signified by msg_id.
+
+ Failure should be a sys.exc_info() tuple.
+
+ """
+ with ConnectionContext(conf, connection_pool) as conn:
+ if failure:
+ failure = rpc_common.serialize_remote_exception(failure,
+ log_failure)
+
+ msg = {'result': reply, 'failure': failure}
+ if ending:
+ msg['ending'] = True
+ _add_unique_id(msg)
+ # If a reply_q exists, add the msg_id to the reply and pass the
+ # reply_q to direct_send() to use it as the response queue.
+ # Otherwise use the msg_id for backward compatibility.
+ if reply_q:
+ msg['_msg_id'] = msg_id
+ conn.direct_send(reply_q, rpc_common.serialize_msg(msg))
+ else:
+ conn.direct_send(msg_id, rpc_common.serialize_msg(msg))
+
+
+class RpcContext(rpc_common.CommonRpcContext):
+ """Context that supports replying to a rpc.call."""
+ def __init__(self, **kwargs):
+ self.msg_id = kwargs.pop('msg_id', None)
+ self.reply_q = kwargs.pop('reply_q', None)
+ self.conf = kwargs.pop('conf')
+ super(RpcContext, self).__init__(**kwargs)
+
+ def deepcopy(self):
+ values = self.to_dict()
+ values['conf'] = self.conf
+ values['msg_id'] = self.msg_id
+ values['reply_q'] = self.reply_q
+ return self.__class__(**values)
+
+ def reply(self, reply=None, failure=None, ending=False,
+ connection_pool=None, log_failure=True):
+ if self.msg_id:
+ msg_reply(self.conf, self.msg_id, self.reply_q, connection_pool,
+ reply, failure, ending, log_failure)
+ if ending:
+ self.msg_id = None
+
+
+def unpack_context(conf, msg):
+ """Unpack context from msg."""
+ context_dict = {}
+ for key in list(msg.keys()):
+ # NOTE(vish): Some versions of python don't like unicode keys
+ # in kwargs.
+ key = str(key)
+ if key.startswith('_context_'):
+ value = msg.pop(key)
+ context_dict[key[9:]] = value
+ context_dict['msg_id'] = msg.pop('_msg_id', None)
+ context_dict['reply_q'] = msg.pop('_reply_q', None)
+ context_dict['conf'] = conf
+ ctx = RpcContext.from_dict(context_dict)
+ rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict())
+ return ctx
+
+
+def pack_context(msg, context):
+ """Pack context into msg.
+
+ Values for message keys need to be less than 255 chars, so we pull
+ context out into a bunch of separate keys. If we want to support
+ more arguments in rabbit messages, we may want to do the same
+ for args at some point.
+
+ """
+ if isinstance(context, dict):
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.iteritems()])
+ else:
+ context_d = dict([('_context_%s' % key, value)
+ for (key, value) in context.to_dict().iteritems()])
+
+ msg.update(context_d)
+
+
+class _MsgIdCache(object):
+ """This class checks any duplicate messages."""
+
+ # NOTE: This value could be made a configuration item, but
+ # it is rarely necessary to change it,
+ # so leave it as a static value for now.
+ DUP_MSG_CHECK_SIZE = 16
+
+ def __init__(self, **kwargs):
+ self.prev_msgids = collections.deque([],
+ maxlen=self.DUP_MSG_CHECK_SIZE)
+
+ def check_duplicate_message(self, message_data):
+ """AMQP consumers may read same message twice when exceptions occur
+ before ack is returned. This method prevents doing it.
+ """
+ if UNIQUE_ID in message_data:
+ msg_id = message_data[UNIQUE_ID]
+ if msg_id not in self.prev_msgids:
+ self.prev_msgids.append(msg_id)
+ else:
+ raise rpc_common.DuplicateMessageError(msg_id=msg_id)
+
+
+def _add_unique_id(msg):
+ """Add unique_id for checking duplicate messages."""
+ unique_id = uuid.uuid4().hex
+ msg.update({UNIQUE_ID: unique_id})
+ LOG.debug(_('UNIQUE_ID is %s.') % (unique_id))
+
+
+class _ThreadPoolWithWait(object):
+ """Base class for a delayed invocation manager.
+
+ Used by the Connection class to start up green threads
+ to handle incoming messages.
+ """
+
+ def __init__(self, conf, connection_pool):
+ self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size)
+ self.connection_pool = connection_pool
+ self.conf = conf
+
+ def wait(self):
+ """Wait for all callback threads to exit."""
+ self.pool.waitall()
+
+
+class CallbackWrapper(_ThreadPoolWithWait):
+ """Wraps a straight callback.
+
+ Allows it to be invoked in a green thread.
+ """
+
+ def __init__(self, conf, callback, connection_pool):
+ """Initiates CallbackWrapper object.
+
+ :param conf: cfg.CONF instance
+ :param callback: a callable (probably a function)
+ :param connection_pool: connection pool as returned by
+ get_connection_pool()
+ """
+ super(CallbackWrapper, self).__init__(
+ conf=conf,
+ connection_pool=connection_pool,
+ )
+ self.callback = callback
+
+ def __call__(self, message_data):
+ self.pool.spawn_n(self.callback, message_data)
+
+
+class ProxyCallback(_ThreadPoolWithWait):
+ """Calls methods on a proxy object based on method and args."""
+
+ def __init__(self, conf, proxy, connection_pool):
+ super(ProxyCallback, self).__init__(
+ conf=conf,
+ connection_pool=connection_pool,
+ )
+ self.proxy = proxy
+ self.msg_id_cache = _MsgIdCache()
+
+ def __call__(self, message_data):
+ """Consumer callback to call a method on a proxy object.
+
+ Parses the message for validity and fires off a thread to call the
+ proxy object method.
+
+ Message data should be a dictionary with two keys:
+ method: string representing the method to call
+ args: dictionary of arg: value
+
+ Example: {'method': 'echo', 'args': {'value': 42}}
+
+ """
+ # It is important to clear the context here, because at this point
+ # the previous context is stored in local.store.context
+ if hasattr(local.store, 'context'):
+ del local.store.context
+ rpc_common._safe_log(LOG.debug, _('received %s'), message_data)
+ self.msg_id_cache.check_duplicate_message(message_data)
+ ctxt = unpack_context(self.conf, message_data)
+ method = message_data.get('method')
+ args = message_data.get('args', {})
+ version = message_data.get('version')
+ namespace = message_data.get('namespace')
+ if not method:
+ LOG.warn(_('no method for message: %s') % message_data)
+ ctxt.reply(_('No method for message: %s') % message_data,
+ connection_pool=self.connection_pool)
+ return
+ self.pool.spawn_n(self._process_data, ctxt, version, method,
+ namespace, args)
+
+ def _process_data(self, ctxt, version, method, namespace, args):
+ """Process a message in a new thread.
+
+ If the proxy object we have has a dispatch method
+ (see rpc.dispatcher.RpcDispatcher), pass it the version,
+ method, and args and let it dispatch as appropriate. If not, use
+ the old behavior of magically calling the specified method on the
+ proxy we have here.
+ """
+ ctxt.update_store()
+ try:
+ rval = self.proxy.dispatch(ctxt, version, method, namespace,
+ **args)
+ # Check if the result was a generator
+ if inspect.isgenerator(rval):
+ for x in rval:
+ ctxt.reply(x, None, connection_pool=self.connection_pool)
+ else:
+ ctxt.reply(rval, None, connection_pool=self.connection_pool)
+ # This final None tells multicall that it is done.
+ ctxt.reply(ending=True, connection_pool=self.connection_pool)
+ except rpc_common.ClientException as e:
+ LOG.debug(_('Expected exception during message handling (%s)') %
+ e._exc_info[1])
+ ctxt.reply(None, e._exc_info,
+ connection_pool=self.connection_pool,
+ log_failure=False)
+ except Exception:
+ # sys.exc_info() is deleted by LOG.exception().
+ exc_info = sys.exc_info()
+ LOG.error(_('Exception during message handling'),
+ exc_info=exc_info)
+ ctxt.reply(None, exc_info, connection_pool=self.connection_pool)
+
+
+class MulticallProxyWaiter(object):
+ def __init__(self, conf, msg_id, timeout, connection_pool):
+ self._msg_id = msg_id
+ self._timeout = timeout or conf.rpc_response_timeout
+ self._reply_proxy = connection_pool.reply_proxy
+ self._done = False
+ self._got_ending = False
+ self._conf = conf
+ self._dataqueue = queue.LightQueue()
+ # Add this caller to the reply proxy's call_waiters
+ self._reply_proxy.add_call_waiter(self, self._msg_id)
+ self.msg_id_cache = _MsgIdCache()
+
+ def put(self, data):
+ self._dataqueue.put(data)
+
+ def done(self):
+ if self._done:
+ return
+ self._done = True
+ # Remove this caller from reply proxy's call_waiters
+ self._reply_proxy.del_call_waiter(self._msg_id)
+
+ def _process_data(self, data):
+ result = None
+ self.msg_id_cache.check_duplicate_message(data)
+ if data['failure']:
+ failure = data['failure']
+ result = rpc_common.deserialize_remote_exception(self._conf,
+ failure)
+ elif data.get('ending', False):
+ self._got_ending = True
+ else:
+ result = data['result']
+ return result
+
+ def __iter__(self):
+ """Return a result until we get a reply with an 'ending' flag."""
+ if self._done:
+ raise StopIteration
+ while True:
+ try:
+ data = self._dataqueue.get(timeout=self._timeout)
+ result = self._process_data(data)
+ except queue.Empty:
+ self.done()
+ raise rpc_common.Timeout()
+ except Exception:
+ with excutils.save_and_reraise_exception():
+ self.done()
+ if self._got_ending:
+ self.done()
+ raise StopIteration
+ if isinstance(result, Exception):
+ self.done()
+ raise result
+ yield result
+
+
+def create_connection(conf, new, connection_pool):
+ """Create a connection."""
+ return ConnectionContext(conf, connection_pool, pooled=not new)
+
+
+_reply_proxy_create_sem = semaphore.Semaphore()
+
+
+def multicall(conf, context, topic, msg, timeout, connection_pool):
+ """Make a call that returns multiple times."""
+ LOG.debug(_('Making synchronous call on %s ...'), topic)
+ msg_id = uuid.uuid4().hex
+ msg.update({'_msg_id': msg_id})
+ LOG.debug(_('MSG_ID is %s') % (msg_id))
+ _add_unique_id(msg)
+ pack_context(msg, context)
+
+ with _reply_proxy_create_sem:
+ if not connection_pool.reply_proxy:
+ connection_pool.reply_proxy = ReplyProxy(conf, connection_pool)
+ msg.update({'_reply_q': connection_pool.reply_proxy.get_reply_q()})
+ wait_msg = MulticallProxyWaiter(conf, msg_id, timeout, connection_pool)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout)
+ return wait_msg
+
+
+def call(conf, context, topic, msg, timeout, connection_pool):
+ """Sends a message on a topic and wait for a response."""
+ rv = multicall(conf, context, topic, msg, timeout, connection_pool)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
+
+
+def cast(conf, context, topic, msg, connection_pool):
+ """Sends a message on a topic without waiting for a response."""
+ LOG.debug(_('Making asynchronous cast on %s...'), topic)
+ _add_unique_id(msg)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.topic_send(topic, rpc_common.serialize_msg(msg))
+
+
+def fanout_cast(conf, context, topic, msg, connection_pool):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ LOG.debug(_('Making asynchronous fanout cast...'))
+ _add_unique_id(msg)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ conn.fanout_send(topic, rpc_common.serialize_msg(msg))
+
+
+def cast_to_server(conf, context, server_params, topic, msg, connection_pool):
+ """Sends a message on a topic to a specific server."""
+ _add_unique_id(msg)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool, pooled=False,
+ server_params=server_params) as conn:
+ conn.topic_send(topic, rpc_common.serialize_msg(msg))
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg,
+ connection_pool):
+ """Sends a message on a fanout exchange to a specific server."""
+ _add_unique_id(msg)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool, pooled=False,
+ server_params=server_params) as conn:
+ conn.fanout_send(topic, rpc_common.serialize_msg(msg))
+
+
+def notify(conf, context, topic, msg, connection_pool, envelope):
+ """Sends a notification event on a topic."""
+ LOG.debug(_('Sending %(event_type)s on %(topic)s'),
+ dict(event_type=msg.get('event_type'),
+ topic=topic))
+ _add_unique_id(msg)
+ pack_context(msg, context)
+ with ConnectionContext(conf, connection_pool) as conn:
+ if envelope:
+ msg = rpc_common.serialize_msg(msg)
+ conn.notify_send(topic, msg)
+
+
+def cleanup(connection_pool):
+ if connection_pool:
+ connection_pool.empty()
+
+
+def get_control_exchange(conf):
+ return conf.control_exchange
diff --git a/keystone/openstack/common/rpc/common.py b/keystone/openstack/common/rpc/common.py
new file mode 100644
index 00000000..3696f0fb
--- /dev/null
+++ b/keystone/openstack/common/rpc/common.py
@@ -0,0 +1,509 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import copy
+import sys
+import traceback
+
+from oslo.config import cfg
+import six
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import local
+from keystone.openstack.common import log as logging
+
+
+CONF = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+'''RPC Envelope Version.
+
+This version number applies to the top level structure of messages sent out.
+It does *not* apply to the message payload, which must be versioned
+independently. For example, when using rpc APIs, a version number is applied
+for changes to the API being exposed over rpc. This version number is handled
+in the rpc proxy and dispatcher modules.
+
+This version number applies to the message envelope that is used in the
+serialization done inside the rpc layer. See serialize_msg() and
+deserialize_msg().
+
+The current message format (version 2.0) is very simple. It is:
+
+ {
+ 'oslo.version': <RPC Envelope Version as a String>,
+ 'oslo.message': <Application Message Payload, JSON encoded>
+ }
+
+Message format version '1.0' is just considered to be the messages we sent
+without a message envelope.
+
+So, the current message envelope just includes the envelope version. It may
+eventually contain additional information, such as a signature for the message
+payload.
+
+We will JSON encode the application message payload. The message envelope,
+which includes the JSON encoded application message body, will be passed down
+to the messaging libraries as a dict.
+'''
+_RPC_ENVELOPE_VERSION = '2.0'
+
+_VERSION_KEY = 'oslo.version'
+_MESSAGE_KEY = 'oslo.message'
+
+_REMOTE_POSTFIX = '_Remote'
+
+
+class RPCException(Exception):
+ msg_fmt = _("An unknown RPC related exception occurred.")
+
+ def __init__(self, message=None, **kwargs):
+ self.kwargs = kwargs
+
+ if not message:
+ try:
+ message = self.msg_fmt % kwargs
+
+ except Exception:
+ # kwargs doesn't match a variable in the message
+ # log the issue and the kwargs
+ LOG.exception(_('Exception in string format operation'))
+ for name, value in kwargs.iteritems():
+ LOG.error("%s: %s" % (name, value))
+ # at least get the core message out if something happened
+ message = self.msg_fmt
+
+ super(RPCException, self).__init__(message)
+
+
+class RemoteError(RPCException):
+ """Signifies that a remote class has raised an exception.
+
+ Contains a string representation of the type of the original exception,
+ the value of the original exception, and the traceback. These are
+ sent to the parent as a joined string so printing the exception
+ contains all of the relevant info.
+
+ """
+ msg_fmt = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.")
+
+ def __init__(self, exc_type=None, value=None, traceback=None):
+ self.exc_type = exc_type
+ self.value = value
+ self.traceback = traceback
+ super(RemoteError, self).__init__(exc_type=exc_type,
+ value=value,
+ traceback=traceback)
+
+
+class Timeout(RPCException):
+ """Signifies that a timeout has occurred.
+
+ This exception is raised if the rpc_response_timeout is reached while
+ waiting for a response from the remote side.
+ """
+ msg_fmt = _('Timeout while waiting on RPC response - '
+ 'topic: "%(topic)s", RPC method: "%(method)s" '
+ 'info: "%(info)s"')
+
+ def __init__(self, info=None, topic=None, method=None):
+ """Initiates Timeout object.
+
+ :param info: Extra info to convey to the user
+ :param topic: The topic that the rpc call was sent to
+        :param method: The name of the rpc method being
+            called
+ """
+ self.info = info
+ self.topic = topic
+ self.method = method
+ super(Timeout, self).__init__(
+ None,
+ info=info or _('<unknown>'),
+ topic=topic or _('<unknown>'),
+ method=method or _('<unknown>'))
+
+
+class DuplicateMessageError(RPCException):
+ msg_fmt = _("Found duplicate message(%(msg_id)s). Skipping it.")
+
+
+class InvalidRPCConnectionReuse(RPCException):
+ msg_fmt = _("Invalid reuse of an RPC connection.")
+
+
+class UnsupportedRpcVersion(RPCException):
+ msg_fmt = _("Specified RPC version, %(version)s, not supported by "
+ "this endpoint.")
+
+
+class UnsupportedRpcEnvelopeVersion(RPCException):
+ msg_fmt = _("Specified RPC envelope version, %(version)s, "
+ "not supported by this endpoint.")
+
+
+class RpcVersionCapError(RPCException):
+ msg_fmt = _("Specified RPC version cap, %(version_cap)s, is too low")
+
+
+class Connection(object):
+ """A connection, returned by rpc.create_connection().
+
+ This class represents a connection to the message bus used for rpc.
+ An instance of this class should never be created by users of the rpc API.
+ Use rpc.create_connection() instead.
+ """
+ def close(self):
+ """Close the connection.
+
+ This method must be called when the connection will no longer be used.
+ It will ensure that any resources associated with the connection, such
+        as a network connection, are cleaned up.
+ """
+ raise NotImplementedError()
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ """Create a consumer on this connection.
+
+ A consumer is associated with a message queue on the backend message
+ bus. The consumer will read messages from the queue, unpack them, and
+ dispatch them to the proxy object. The contents of the message pulled
+ off of the queue will determine which method gets called on the proxy
+ object.
+
+ :param topic: This is a name associated with what to consume from.
+ Multiple instances of a service may consume from the same
+ topic. For example, all instances of nova-compute consume
+ from a queue called "compute". In that case, the
+ messages will get distributed amongst the consumers in a
+ round-robin fashion if fanout=False. If fanout=True,
+ every consumer associated with this topic will get a
+ copy of every message.
+ :param proxy: The object that will handle all incoming messages.
+ :param fanout: Whether or not this is a fanout topic. See the
+ documentation for the topic parameter for some
+ additional comments on this.
+ """
+ raise NotImplementedError()
+
+ def create_worker(self, topic, proxy, pool_name):
+ """Create a worker on this connection.
+
+ A worker is like a regular consumer of messages directed to a
+ topic, except that it is part of a set of such consumers (the
+ "pool") which may run in parallel. Every pool of workers will
+ receive a given message, but only one worker in the pool will
+ be asked to process it. Load is distributed across the members
+ of the pool in round-robin fashion.
+
+ :param topic: This is a name associated with what to consume from.
+ Multiple instances of a service may consume from the same
+ topic.
+ :param proxy: The object that will handle all incoming messages.
+ :param pool_name: String containing the name of the pool of workers
+ """
+ raise NotImplementedError()
+
+ def join_consumer_pool(self, callback, pool_name, topic, exchange_name):
+ """Register as a member of a group of consumers.
+
+        Uses the given topic from the specified exchange.
+ Exactly one member of a given pool will receive each message.
+
+        A message will be delivered to multiple pools if more than
+        one is created.
+
+ :param callback: Callable to be invoked for each message.
+ :type callback: callable accepting one argument
+ :param pool_name: The name of the consumer pool.
+ :type pool_name: str
+ :param topic: The routing topic for desired messages.
+ :type topic: str
+ :param exchange_name: The name of the message exchange where
+ the client should attach. Defaults to
+ the configured exchange.
+ :type exchange_name: str
+ """
+ raise NotImplementedError()
+
+ def consume_in_thread(self):
+ """Spawn a thread to handle incoming messages.
+
+ Spawn a thread that will be responsible for handling all incoming
+ messages for consumers that were set up on this connection.
+
+ Message dispatching inside of this is expected to be implemented in a
+ non-blocking manner. An example implementation would be having this
+ thread pull messages in for all of the consumers, but utilize a thread
+ pool for dispatching the messages to the proxy objects.
+ """
+ raise NotImplementedError()
+
+
+def _safe_log(log_func, msg, msg_data):
+ """Sanitizes the msg_data field before logging."""
+ SANITIZE = ['_context_auth_token', 'auth_token', 'new_pass']
+
+ def _fix_passwords(d):
+ """Sanitizes the password fields in the dictionary."""
+ for k in d.iterkeys():
+ if k.lower().find('password') != -1:
+ d[k] = '<SANITIZED>'
+ elif k.lower() in SANITIZE:
+ d[k] = '<SANITIZED>'
+ elif isinstance(d[k], dict):
+ _fix_passwords(d[k])
+ return d
+
+ return log_func(msg, _fix_passwords(copy.deepcopy(msg_data)))
+
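+# Illustrative example (sample data, not part of the rpc API): given
+#
+#   msg_data = {'method': 'create_user',
+#               'args': {'name': 'demo', 'password': 'secret',
+#                        '_context_auth_token': 'abc123'}}
+#
+# _safe_log(LOG.debug, 'sending: %s', msg_data) logs the nested dict with
+# 'password' and '_context_auth_token' replaced by '<SANITIZED>', while the
+# original msg_data is left untouched because only a deep copy is modified.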
+
+def serialize_remote_exception(failure_info, log_failure=True):
+ """Prepares exception data to be sent over rpc.
+
+ Failure_info should be a sys.exc_info() tuple.
+
+ """
+ tb = traceback.format_exception(*failure_info)
+ failure = failure_info[1]
+ if log_failure:
+ LOG.error(_("Returning exception %s to caller"),
+ six.text_type(failure))
+ LOG.error(tb)
+
+ kwargs = {}
+ if hasattr(failure, 'kwargs'):
+ kwargs = failure.kwargs
+
+ # NOTE(matiu): With cells, it's possible to re-raise remote, remote
+    # exceptions. Let's turn it back into the original exception type.
+ cls_name = str(failure.__class__.__name__)
+ mod_name = str(failure.__class__.__module__)
+ if (cls_name.endswith(_REMOTE_POSTFIX) and
+ mod_name.endswith(_REMOTE_POSTFIX)):
+ cls_name = cls_name[:-len(_REMOTE_POSTFIX)]
+ mod_name = mod_name[:-len(_REMOTE_POSTFIX)]
+
+ data = {
+ 'class': cls_name,
+ 'module': mod_name,
+ 'message': six.text_type(failure),
+ 'tb': tb,
+ 'args': failure.args,
+ 'kwargs': kwargs
+ }
+
+ json_data = jsonutils.dumps(data)
+
+ return json_data
+
+
+def deserialize_remote_exception(conf, data):
+ failure = jsonutils.loads(str(data))
+
+ trace = failure.get('tb', [])
+ message = failure.get('message', "") + "\n" + "\n".join(trace)
+ name = failure.get('class')
+ module = failure.get('module')
+
+ # NOTE(ameade): We DO NOT want to allow just any module to be imported, in
+ # order to prevent arbitrary code execution.
+ if module not in conf.allowed_rpc_exception_modules:
+ return RemoteError(name, failure.get('message'), trace)
+
+ try:
+ mod = importutils.import_module(module)
+ klass = getattr(mod, name)
+ if not issubclass(klass, Exception):
+ raise TypeError("Can only deserialize Exceptions")
+
+ failure = klass(*failure.get('args', []), **failure.get('kwargs', {}))
+ except (AttributeError, TypeError, ImportError):
+ return RemoteError(name, failure.get('message'), trace)
+
+ ex_type = type(failure)
+ str_override = lambda self: message
+ new_ex_type = type(ex_type.__name__ + _REMOTE_POSTFIX, (ex_type,),
+ {'__str__': str_override, '__unicode__': str_override})
+ new_ex_type.__module__ = '%s%s' % (module, _REMOTE_POSTFIX)
+ try:
+ # NOTE(ameade): Dynamically create a new exception type and swap it in
+ # as the new type for the exception. This only works on user defined
+ # Exceptions and not core python exceptions. This is important because
+ # we cannot necessarily change an exception message so we must override
+ # the __str__ method.
+ failure.__class__ = new_ex_type
+ except TypeError:
+ # NOTE(ameade): If a core exception then just add the traceback to the
+ # first exception argument.
+ failure.args = (message,) + failure.args[1:]
+ return failure
+
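+# For illustration (module and class names here are only examples): if the
+# serialized failure refers to class 'InstanceNotFound' in module
+# 'nova.exception' and that module is listed in
+# conf.allowed_rpc_exception_modules, deserialize_remote_exception() returns
+# an instance whose type is renamed to 'InstanceNotFound_Remote' in module
+# 'nova.exception_Remote', with the remote traceback appended to its str().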
+
+class CommonRpcContext(object):
+ def __init__(self, **kwargs):
+ self.values = kwargs
+
+ def __getattr__(self, key):
+ try:
+ return self.values[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def to_dict(self):
+ return copy.deepcopy(self.values)
+
+ @classmethod
+ def from_dict(cls, values):
+ return cls(**values)
+
+ def deepcopy(self):
+ return self.from_dict(self.to_dict())
+
+ def update_store(self):
+ local.store.context = self
+
+ def elevated(self, read_deleted=None, overwrite=False):
+ """Return a version of this context with admin flag set."""
+ # TODO(russellb) This method is a bit of a nova-ism. It makes
+ # some assumptions about the data in the request context sent
+ # across rpc, while the rest of this class does not. We could get
+ # rid of this if we changed the nova code that uses this to
+ # convert the RpcContext back to its native RequestContext doing
+ # something like nova.context.RequestContext.from_dict(ctxt.to_dict())
+
+ context = self.deepcopy()
+ context.values['is_admin'] = True
+
+ context.values.setdefault('roles', [])
+
+ if 'admin' not in context.values['roles']:
+ context.values['roles'].append('admin')
+
+ if read_deleted is not None:
+ context.values['read_deleted'] = read_deleted
+
+ return context
+
+
+class ClientException(Exception):
+ """Encapsulates actual exception expected to be hit by a RPC proxy object.
+
+ Merely instantiating it records the current exception information, which
+ will be passed back to the RPC client without exceptional logging.
+ """
+ def __init__(self):
+ self._exc_info = sys.exc_info()
+
+
+def catch_client_exception(exceptions, func, *args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ if type(e) in exceptions:
+ raise ClientException()
+ else:
+ raise
+
+
+def client_exceptions(*exceptions):
+ """Decorator for manager methods that raise expected exceptions.
+
+ Marking a Manager method with this decorator allows the declaration
+ of expected exceptions that the RPC layer should not consider fatal,
+ and not log as if they were generated in a real error scenario. Note
+ that this will cause listed exceptions to be wrapped in a
+ ClientException, which is used internally by the RPC layer.
+ """
+ def outer(func):
+ def inner(*args, **kwargs):
+ return catch_client_exception(exceptions, func, *args, **kwargs)
+ return inner
+ return outer
+
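+# Illustrative usage (get_instance and InstanceNotFound are hypothetical):
+#
+#   @client_exceptions(InstanceNotFound)
+#   def get_instance(self, context, instance_id):
+#       ...
+#
+# If get_instance() raises InstanceNotFound, it is re-raised wrapped in a
+# ClientException so the RPC layer does not log it as an unexpected error.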
+
+def version_is_compatible(imp_version, version):
+ """Determine whether versions are compatible.
+
+ :param imp_version: The version implemented
+ :param version: The version requested by an incoming message.
+ """
+ version_parts = version.split('.')
+ imp_version_parts = imp_version.split('.')
+ if int(version_parts[0]) != int(imp_version_parts[0]): # Major
+ return False
+ if int(version_parts[1]) > int(imp_version_parts[1]): # Minor
+ return False
+ return True
+
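+# For example, an endpoint implementing version '1.3' can handle requests
+# for '1.0' through '1.3' (same major version, equal or lower minor), so
+# version_is_compatible('1.3', '1.1') is True while '1.4' or '2.0' are not.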
+
+def serialize_msg(raw_msg):
+ # NOTE(russellb) See the docstring for _RPC_ENVELOPE_VERSION for more
+ # information about this format.
+ msg = {_VERSION_KEY: _RPC_ENVELOPE_VERSION,
+ _MESSAGE_KEY: jsonutils.dumps(raw_msg)}
+
+ return msg
+
+
+def deserialize_msg(msg):
+ # NOTE(russellb): Hang on to your hats, this road is about to
+ # get a little bumpy.
+ #
+ # Robustness Principle:
+ # "Be strict in what you send, liberal in what you accept."
+ #
+ # At this point we have to do a bit of guessing about what it
+ # is we just received. Here is the set of possibilities:
+ #
+ # 1) We received a dict. This could be 2 things:
+ #
+ # a) Inspect it to see if it looks like a standard message envelope.
+ # If so, great!
+ #
+ # b) If it doesn't look like a standard message envelope, it could either
+ # be a notification, or a message from before we added a message
+ # envelope (referred to as version 1.0).
+ # Just return the message as-is.
+ #
+ # 2) It's any other non-dict type. Just return it and hope for the best.
+ # This case covers return values from rpc.call() from before message
+ # envelopes were used. (messages to call a method were always a dict)
+
+ if not isinstance(msg, dict):
+ # See #2 above.
+ return msg
+
+ base_envelope_keys = (_VERSION_KEY, _MESSAGE_KEY)
+ if not all(map(lambda key: key in msg, base_envelope_keys)):
+ # See #1.b above.
+ return msg
+
+ # At this point we think we have the message envelope
+ # format we were expecting. (#1.a above)
+
+ if not version_is_compatible(_RPC_ENVELOPE_VERSION, msg[_VERSION_KEY]):
+ raise UnsupportedRpcEnvelopeVersion(version=msg[_VERSION_KEY])
+
+ raw_msg = jsonutils.loads(msg[_MESSAGE_KEY])
+
+ return raw_msg
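+
+# A quick sketch of the envelope round trip described above:
+#
+#   envelope = serialize_msg({'method': 'ping', 'args': {}})
+#   # -> {'oslo.version': '2.0', 'oslo.message': '{"method": "ping", ...}'}
+#   deserialize_msg(envelope)            # -> {'method': 'ping', 'args': {}}
+#   deserialize_msg({'method': 'ping'})  # pre-envelope (1.0) message,
+#                                        # returned as-is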
diff --git a/keystone/openstack/common/rpc/dispatcher.py b/keystone/openstack/common/rpc/dispatcher.py
new file mode 100644
index 00000000..d2fd5dc5
--- /dev/null
+++ b/keystone/openstack/common/rpc/dispatcher.py
@@ -0,0 +1,178 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Code for rpc message dispatching.
+
+Messages that come in have a version number associated with them. RPC API
+version numbers are in the form:
+
+ Major.Minor
+
+For a given message with version X.Y, the receiver must be marked as able to
+handle messages of version A.B, where:
+
+ A = X
+
+ B >= Y
+
+The Major version number would be incremented for an almost completely new API.
+The Minor version number would be incremented for backwards compatible changes
+to an existing API. A backwards compatible change could be something like
+adding a new method, adding an argument to an existing method (but not
+requiring it), or changing the type for an existing argument (but still
+handling the old type as well).
+
+The conversion over to a versioned API must be done on both the client side and
+server side of the API at the same time. However, as the code stands today,
+there can be both versioned and unversioned APIs implemented in the same code
+base.
+
+EXAMPLES
+========
+
+Nova was the first project to use versioned rpc APIs. Consider the compute rpc
+API as an example. The client side is in nova/compute/rpcapi.py and the server
+side is in nova/compute/manager.py.
+
+
+Example 1) Adding a new method.
+-------------------------------
+
+Adding a new method is a backwards compatible change. It should be added to
+nova/compute/manager.py, and RPC_API_VERSION should be bumped from X.Y to
+X.Y+1. On the client side, the new method in nova/compute/rpcapi.py should
+have a specific version specified to indicate the minimum API version that must
+be implemented for the method to be supported. For example::
+
+ def get_host_uptime(self, ctxt, host):
+ topic = _compute_topic(self.topic, ctxt, host, None)
+ return self.call(ctxt, self.make_msg('get_host_uptime'), topic,
+ version='1.1')
+
+In this case, version '1.1' is the first version that supported the
+get_host_uptime() method.
+
+
+Example 2) Adding a new parameter.
+----------------------------------
+
+Adding a new parameter to an rpc method can be made backwards compatible. The
+RPC_API_VERSION on the server side (nova/compute/manager.py) should be bumped.
+The implementation of the method must not expect the parameter to be present::
+
+ def some_remote_method(self, arg1, arg2, newarg=None):
+ # The code needs to deal with newarg=None for cases
+ # where an older client sends a message without it.
+ pass
+
+On the client side, the same changes should be made as in example 1. The
+minimum version that supports the new parameter should be specified.
+"""
+
+from keystone.openstack.common.rpc import common as rpc_common
+from keystone.openstack.common.rpc import serializer as rpc_serializer
+
+
+class RpcDispatcher(object):
+ """Dispatch rpc messages according to the requested API version.
+
+ This class can be used as the top level 'manager' for a service. It
+    contains a list of underlying managers that have an RPC_API_VERSION
+    attribute.
+ """
+
+ def __init__(self, callbacks, serializer=None):
+ """Initialize the rpc dispatcher.
+
+ :param callbacks: List of proxy objects that are an instance
+ of a class with rpc methods exposed. Each proxy
+ object should have an RPC_API_VERSION attribute.
+ :param serializer: The Serializer object that will be used to
+ deserialize arguments before the method call and
+ to serialize the result after it returns.
+ """
+ self.callbacks = callbacks
+ if serializer is None:
+ serializer = rpc_serializer.NoOpSerializer()
+ self.serializer = serializer
+ super(RpcDispatcher, self).__init__()
+
+ def _deserialize_args(self, context, kwargs):
+ """Helper method called to deserialize args before dispatch.
+
+ This calls our serializer on each argument, returning a new set of
+ args that have been deserialized.
+
+ :param context: The request context
+ :param kwargs: The arguments to be deserialized
+ :returns: A new set of deserialized args
+ """
+ new_kwargs = dict()
+ for argname, arg in kwargs.iteritems():
+ new_kwargs[argname] = self.serializer.deserialize_entity(context,
+ arg)
+ return new_kwargs
+
+ def dispatch(self, ctxt, version, method, namespace, **kwargs):
+ """Dispatch a message based on a requested version.
+
+ :param ctxt: The request context
+ :param version: The requested API version from the incoming message
+ :param method: The method requested to be called by the incoming
+ message.
+ :param namespace: The namespace for the requested method. If None,
+ the dispatcher will look for a method on a callback
+ object with no namespace set.
+ :param kwargs: A dict of keyword arguments to be passed to the method.
+
+ :returns: Whatever is returned by the underlying method that gets
+ called.
+ """
+ if not version:
+ version = '1.0'
+
+ had_compatible = False
+ for proxyobj in self.callbacks:
+ # Check for namespace compatibility
+ try:
+ cb_namespace = proxyobj.RPC_API_NAMESPACE
+ except AttributeError:
+ cb_namespace = None
+
+ if namespace != cb_namespace:
+ continue
+
+ # Check for version compatibility
+ try:
+ rpc_api_version = proxyobj.RPC_API_VERSION
+ except AttributeError:
+ rpc_api_version = '1.0'
+
+ is_compatible = rpc_common.version_is_compatible(rpc_api_version,
+ version)
+ had_compatible = had_compatible or is_compatible
+
+ if not hasattr(proxyobj, method):
+ continue
+ if is_compatible:
+ kwargs = self._deserialize_args(ctxt, kwargs)
+ result = getattr(proxyobj, method)(ctxt, **kwargs)
+ return self.serializer.serialize_entity(ctxt, result)
+
+ if had_compatible:
+ raise AttributeError("No such RPC function '%s'" % method)
+ else:
+ raise rpc_common.UnsupportedRpcVersion(version=version)
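+
+# Minimal sketch of how a dispatcher is typically wired up (ComputeManager
+# and uptime_for are hypothetical, shown only to illustrate the calling
+# convention):
+#
+#   class ComputeManager(object):
+#       RPC_API_VERSION = '1.1'
+#
+#       def get_host_uptime(self, ctxt, host):
+#           return uptime_for(host)
+#
+#   dispatcher = RpcDispatcher([ComputeManager()])
+#   dispatcher.dispatch(ctxt, '1.1', 'get_host_uptime', None, host='node1')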
diff --git a/keystone/openstack/common/rpc/impl_fake.py b/keystone/openstack/common/rpc/impl_fake.py
new file mode 100644
index 00000000..9479ab4d
--- /dev/null
+++ b/keystone/openstack/common/rpc/impl_fake.py
@@ -0,0 +1,195 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""Fake RPC implementation which calls proxy methods directly with no
+queues. Casts will block, but this is very useful for tests.
+"""
+
+import inspect
+# NOTE(russellb): We specifically want to use json, not our own jsonutils.
+# jsonutils has some extra logic to automatically convert objects to primitive
+# types so that they can be serialized. We want to catch all cases where
+# non-primitive types make it into this code and treat it as an error.
+import json
+import time
+
+import eventlet
+
+from keystone.openstack.common.rpc import common as rpc_common
+
+CONSUMERS = {}
+
+
+class RpcContext(rpc_common.CommonRpcContext):
+ def __init__(self, **kwargs):
+ super(RpcContext, self).__init__(**kwargs)
+ self._response = []
+ self._done = False
+
+ def deepcopy(self):
+ values = self.to_dict()
+ new_inst = self.__class__(**values)
+ new_inst._response = self._response
+ new_inst._done = self._done
+ return new_inst
+
+ def reply(self, reply=None, failure=None, ending=False):
+ if ending:
+ self._done = True
+ if not self._done:
+ self._response.append((reply, failure))
+
+
+class Consumer(object):
+ def __init__(self, topic, proxy):
+ self.topic = topic
+ self.proxy = proxy
+
+ def call(self, context, version, method, namespace, args, timeout):
+ done = eventlet.event.Event()
+
+ def _inner():
+ ctxt = RpcContext.from_dict(context.to_dict())
+ try:
+ rval = self.proxy.dispatch(context, version, method,
+ namespace, **args)
+ res = []
+ # Caller might have called ctxt.reply() manually
+ for (reply, failure) in ctxt._response:
+ if failure:
+ raise failure[0], failure[1], failure[2]
+ res.append(reply)
+                # If an 'ending' reply was not sent, we might have more
+                # data to return from the function itself.
+ if not ctxt._done:
+ if inspect.isgenerator(rval):
+ for val in rval:
+ res.append(val)
+ else:
+ res.append(rval)
+ done.send(res)
+ except rpc_common.ClientException as e:
+ done.send_exception(e._exc_info[1])
+ except Exception as e:
+ done.send_exception(e)
+
+ thread = eventlet.greenthread.spawn(_inner)
+
+ if timeout:
+ start_time = time.time()
+ while not done.ready():
+ eventlet.greenthread.sleep(1)
+ cur_time = time.time()
+ if (cur_time - start_time) > timeout:
+ thread.kill()
+ raise rpc_common.Timeout()
+
+ return done.wait()
+
+
+class Connection(object):
+ """Connection object."""
+
+ def __init__(self):
+ self.consumers = []
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ consumer = Consumer(topic, proxy)
+ self.consumers.append(consumer)
+ if topic not in CONSUMERS:
+ CONSUMERS[topic] = []
+ CONSUMERS[topic].append(consumer)
+
+ def close(self):
+ for consumer in self.consumers:
+ CONSUMERS[consumer.topic].remove(consumer)
+ self.consumers = []
+
+ def consume_in_thread(self):
+ pass
+
+
+def create_connection(conf, new=True):
+ """Create a connection."""
+ return Connection()
+
+
+def check_serialize(msg):
+ """Make sure a message intended for rpc can be serialized."""
+ json.dumps(msg)
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+
+ check_serialize(msg)
+
+ method = msg.get('method')
+ if not method:
+ return
+ args = msg.get('args', {})
+ version = msg.get('version', None)
+ namespace = msg.get('namespace', None)
+
+ try:
+ consumer = CONSUMERS[topic][0]
+ except (KeyError, IndexError):
+ return iter([None])
+ else:
+ return consumer.call(context, version, method, namespace, args,
+ timeout)
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ rv = multicall(conf, context, topic, msg, timeout)
+ # NOTE(vish): return the last result from the multicall
+ rv = list(rv)
+ if not rv:
+ return
+ return rv[-1]
+
+
+def cast(conf, context, topic, msg):
+ check_serialize(msg)
+ try:
+ call(conf, context, topic, msg)
+ except Exception:
+ pass
+
+
+def notify(conf, context, topic, msg, envelope):
+ check_serialize(msg)
+
+
+def cleanup():
+ pass
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Cast to all consumers of a topic."""
+ check_serialize(msg)
+ method = msg.get('method')
+ if not method:
+ return
+ args = msg.get('args', {})
+ version = msg.get('version', None)
+ namespace = msg.get('namespace', None)
+
+ for consumer in CONSUMERS.get(topic, []):
+ try:
+ consumer.call(context, version, method, namespace, args, None)
+ except Exception:
+ pass
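+
+
+# Rough usage sketch for tests (proxy is assumed to be an RpcDispatcher-style
+# object exposing dispatch(), and context a CommonRpcContext instance):
+#
+#   conn = create_connection(conf)
+#   conn.create_consumer('compute', proxy)
+#   result = call(conf, context, 'compute',
+#                 {'method': 'get_host_uptime', 'args': {'host': 'node1'}})
+#
+# Since there are no real queues, the proxy method runs in a greenthread and
+# its return value comes back directly to the caller.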
diff --git a/keystone/openstack/common/rpc/impl_kombu.py b/keystone/openstack/common/rpc/impl_kombu.py
new file mode 100644
index 00000000..0e641db0
--- /dev/null
+++ b/keystone/openstack/common/rpc/impl_kombu.py
@@ -0,0 +1,861 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import itertools
+import socket
+import ssl
+import time
+import uuid
+
+import eventlet
+import greenlet
+import kombu
+import kombu.connection
+import kombu.entity
+import kombu.messaging
+from oslo.config import cfg
+
+from keystone.openstack.common import excutils
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import network_utils
+from keystone.openstack.common.rpc import amqp as rpc_amqp
+from keystone.openstack.common.rpc import common as rpc_common
+from keystone.openstack.common import sslutils
+
+kombu_opts = [
+ cfg.StrOpt('kombu_ssl_version',
+ default='',
+ help='SSL version to use (valid only if SSL enabled). '
+                    'Valid values are TLSv1, SSLv23 and SSLv3. SSLv2 may '
+ 'be available on some distributions'
+ ),
+ cfg.StrOpt('kombu_ssl_keyfile',
+ default='',
+ help='SSL key file (valid only if SSL enabled)'),
+ cfg.StrOpt('kombu_ssl_certfile',
+ default='',
+ help='SSL cert file (valid only if SSL enabled)'),
+ cfg.StrOpt('kombu_ssl_ca_certs',
+ default='',
+ help=('SSL certification authority file '
+ '(valid only if SSL enabled)')),
+ cfg.StrOpt('rabbit_host',
+ default='localhost',
+ help='The RabbitMQ broker address where a single node is used'),
+ cfg.IntOpt('rabbit_port',
+ default=5672,
+ help='The RabbitMQ broker port where a single node is used'),
+ cfg.ListOpt('rabbit_hosts',
+ default=['$rabbit_host:$rabbit_port'],
+ help='RabbitMQ HA cluster host:port pairs'),
+ cfg.BoolOpt('rabbit_use_ssl',
+ default=False,
+ help='connect over SSL for RabbitMQ'),
+ cfg.StrOpt('rabbit_userid',
+ default='guest',
+ help='the RabbitMQ userid'),
+ cfg.StrOpt('rabbit_password',
+ default='guest',
+ help='the RabbitMQ password',
+ secret=True),
+ cfg.StrOpt('rabbit_virtual_host',
+ default='/',
+ help='the RabbitMQ virtual host'),
+ cfg.IntOpt('rabbit_retry_interval',
+ default=1,
+ help='how frequently to retry connecting with RabbitMQ'),
+ cfg.IntOpt('rabbit_retry_backoff',
+ default=2,
+ help='how long to backoff for between retries when connecting '
+ 'to RabbitMQ'),
+ cfg.IntOpt('rabbit_max_retries',
+ default=0,
+               help='maximum retries when trying to connect to RabbitMQ '
+ '(the default of 0 implies an infinite retry count)'),
+ cfg.BoolOpt('rabbit_ha_queues',
+ default=False,
+                help='use H/A queues in RabbitMQ (x-ha-policy: all). '
+ 'You need to wipe RabbitMQ database when '
+ 'changing this option.'),
+
+]
+
+cfg.CONF.register_opts(kombu_opts)
+
+LOG = rpc_common.LOG
+
+
+def _get_queue_arguments(conf):
+ """Construct the arguments for declaring a queue.
+
+ If the rabbit_ha_queues option is set, we declare a mirrored queue
+ as described here:
+
+ http://www.rabbitmq.com/ha.html
+
+ Setting x-ha-policy to all means that the queue will be mirrored
+ to all nodes in the cluster.
+ """
+ return {'x-ha-policy': 'all'} if conf.rabbit_ha_queues else {}
+
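+# For example, with rabbit_ha_queues=True every queue declared by the
+# consumers below is created with arguments {'x-ha-policy': 'all'}, so
+# RabbitMQ mirrors it across the cluster; otherwise no extra arguments are
+# added.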
+
+class ConsumerBase(object):
+ """Consumer base class."""
+
+ def __init__(self, channel, callback, tag, **kwargs):
+ """Declare a queue on an amqp channel.
+
+ 'channel' is the amqp channel to use
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ queue name, exchange name, and other kombu options are
+ passed in here as a dictionary.
+ """
+ self.callback = callback
+ self.tag = str(tag)
+ self.kwargs = kwargs
+ self.queue = None
+ self.ack_on_error = kwargs.get('ack_on_error', True)
+ self.reconnect(channel)
+
+ def reconnect(self, channel):
+ """Re-declare the queue after a rabbit reconnect."""
+ self.channel = channel
+ self.kwargs['channel'] = channel
+ self.queue = kombu.entity.Queue(**self.kwargs)
+ self.queue.declare()
+
+ def _callback_handler(self, message, callback):
+ """Call callback with deserialized message.
+
+ Messages that are processed without exception are ack'ed.
+
+ If the message processing generates an exception, it will be
+ ack'ed if ack_on_error=True. Otherwise it will be .reject()'ed.
+ Rejection is better than waiting for the message to timeout.
+ Rejected messages are immediately requeued.
+ """
+
+ ack_msg = False
+ try:
+ msg = rpc_common.deserialize_msg(message.payload)
+ callback(msg)
+ ack_msg = True
+ except Exception:
+ if self.ack_on_error:
+ ack_msg = True
+ LOG.exception(_("Failed to process message"
+ " ... skipping it."))
+ else:
+ LOG.exception(_("Failed to process message"
+ " ... will requeue."))
+ finally:
+ if ack_msg:
+ message.ack()
+ else:
+ message.reject()
+
+ def consume(self, *args, **kwargs):
+ """Actually declare the consumer on the amqp channel. This will
+ start the flow of messages from the queue. Using the
+ Connection.iterconsume() iterator will process the messages,
+ calling the appropriate callback.
+
+ If a callback is specified in kwargs, use that. Otherwise,
+ use the callback passed during __init__()
+
+        If kwargs['nowait'] is True, the consumer is declared without
+        waiting for a confirmation from the broker.
+
+ """
+
+ options = {'consumer_tag': self.tag}
+ options['nowait'] = kwargs.get('nowait', False)
+ callback = kwargs.get('callback', self.callback)
+ if not callback:
+ raise ValueError("No callback defined")
+
+ def _callback(raw_message):
+ message = self.channel.message_to_python(raw_message)
+ self._callback_handler(message, callback)
+
+ self.queue.consume(*args, callback=_callback, **options)
+
+ def cancel(self):
+ """Cancel the consuming from the queue, if it has started."""
+ try:
+ self.queue.cancel(self.tag)
+ except KeyError as e:
+            # NOTE(comstud): Kludge to get around an amqplib bug
+ if str(e) != "u'%s'" % self.tag:
+ raise
+ self.queue = None
+
+
+class DirectConsumer(ConsumerBase):
+ """Queue/consumer class for 'direct'."""
+
+ def __init__(self, conf, channel, msg_id, callback, tag, **kwargs):
+ """Init a 'direct' queue.
+
+ 'channel' is the amqp channel to use
+ 'msg_id' is the msg_id to listen on
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ Other kombu options may be passed
+ """
+ # Default options
+ options = {'durable': False,
+ 'queue_arguments': _get_queue_arguments(conf),
+ 'auto_delete': True,
+ 'exclusive': False}
+ options.update(kwargs)
+ exchange = kombu.entity.Exchange(name=msg_id,
+ type='direct',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(DirectConsumer, self).__init__(channel,
+ callback,
+ tag,
+ name=msg_id,
+ exchange=exchange,
+ routing_key=msg_id,
+ **options)
+
+
+class TopicConsumer(ConsumerBase):
+ """Consumer class for 'topic'."""
+
+ def __init__(self, conf, channel, topic, callback, tag, name=None,
+ exchange_name=None, **kwargs):
+ """Init a 'topic' queue.
+
+ :param channel: the amqp channel to use
+ :param topic: the topic to listen on
+ :paramtype topic: str
+ :param callback: the callback to call when messages are received
+ :param tag: a unique ID for the consumer on the channel
+ :param name: optional queue name, defaults to topic
+ :paramtype name: str
+
+ Other kombu options may be passed as keyword arguments
+ """
+ # Default options
+ options = {'durable': conf.amqp_durable_queues,
+ 'queue_arguments': _get_queue_arguments(conf),
+ 'auto_delete': conf.amqp_auto_delete,
+ 'exclusive': False}
+ options.update(kwargs)
+ exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf)
+ exchange = kombu.entity.Exchange(name=exchange_name,
+ type='topic',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(TopicConsumer, self).__init__(channel,
+ callback,
+ tag,
+ name=name or topic,
+ exchange=exchange,
+ routing_key=topic,
+ **options)
+
+
+class FanoutConsumer(ConsumerBase):
+ """Consumer class for 'fanout'."""
+
+ def __init__(self, conf, channel, topic, callback, tag, **kwargs):
+ """Init a 'fanout' queue.
+
+ 'channel' is the amqp channel to use
+ 'topic' is the topic to listen on
+ 'callback' is the callback to call when messages are received
+ 'tag' is a unique ID for the consumer on the channel
+
+ Other kombu options may be passed
+ """
+ unique = uuid.uuid4().hex
+ exchange_name = '%s_fanout' % topic
+ queue_name = '%s_fanout_%s' % (topic, unique)
+
+ # Default options
+ options = {'durable': False,
+ 'queue_arguments': _get_queue_arguments(conf),
+ 'auto_delete': True,
+ 'exclusive': False}
+ options.update(kwargs)
+ exchange = kombu.entity.Exchange(name=exchange_name, type='fanout',
+ durable=options['durable'],
+ auto_delete=options['auto_delete'])
+ super(FanoutConsumer, self).__init__(channel, callback, tag,
+ name=queue_name,
+ exchange=exchange,
+ routing_key=topic,
+ **options)
+
+
+class Publisher(object):
+ """Base Publisher class."""
+
+ def __init__(self, channel, exchange_name, routing_key, **kwargs):
+ """Init the Publisher class with the exchange_name, routing_key,
+ and other options
+ """
+ self.exchange_name = exchange_name
+ self.routing_key = routing_key
+ self.kwargs = kwargs
+ self.reconnect(channel)
+
+ def reconnect(self, channel):
+ """Re-establish the Producer after a rabbit reconnection."""
+ self.exchange = kombu.entity.Exchange(name=self.exchange_name,
+ **self.kwargs)
+ self.producer = kombu.messaging.Producer(exchange=self.exchange,
+ channel=channel,
+ routing_key=self.routing_key)
+
+ def send(self, msg, timeout=None):
+ """Send a message."""
+ if timeout:
+ #
+ # AMQP TTL is in milliseconds when set in the header.
+ #
+ self.producer.publish(msg, headers={'ttl': (timeout * 1000)})
+ else:
+ self.producer.publish(msg)
+
+
+class DirectPublisher(Publisher):
+ """Publisher class for 'direct'."""
+ def __init__(self, conf, channel, msg_id, **kwargs):
+ """init a 'direct' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': False}
+ options.update(kwargs)
+ super(DirectPublisher, self).__init__(channel, msg_id, msg_id,
+ type='direct', **options)
+
+
+class TopicPublisher(Publisher):
+ """Publisher class for 'topic'."""
+ def __init__(self, conf, channel, topic, **kwargs):
+ """init a 'topic' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+ options = {'durable': conf.amqp_durable_queues,
+ 'auto_delete': conf.amqp_auto_delete,
+ 'exclusive': False}
+ options.update(kwargs)
+ exchange_name = rpc_amqp.get_control_exchange(conf)
+ super(TopicPublisher, self).__init__(channel,
+ exchange_name,
+ topic,
+ type='topic',
+ **options)
+
+
+class FanoutPublisher(Publisher):
+ """Publisher class for 'fanout'."""
+ def __init__(self, conf, channel, topic, **kwargs):
+ """init a 'fanout' publisher.
+
+ Kombu options may be passed as keyword args to override defaults
+ """
+ options = {'durable': False,
+ 'auto_delete': True,
+ 'exclusive': False}
+ options.update(kwargs)
+ super(FanoutPublisher, self).__init__(channel, '%s_fanout' % topic,
+ None, type='fanout', **options)
+
+
+class NotifyPublisher(TopicPublisher):
+ """Publisher class for 'notify'."""
+
+ def __init__(self, conf, channel, topic, **kwargs):
+ self.durable = kwargs.pop('durable', conf.amqp_durable_queues)
+ self.queue_arguments = _get_queue_arguments(conf)
+ super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs)
+
+ def reconnect(self, channel):
+ super(NotifyPublisher, self).reconnect(channel)
+
+ # NOTE(jerdfelt): Normally the consumer would create the queue, but
+ # we do this to ensure that messages don't get dropped if the
+ # consumer is started after we do
+ queue = kombu.entity.Queue(channel=channel,
+ exchange=self.exchange,
+ durable=self.durable,
+ name=self.routing_key,
+ routing_key=self.routing_key,
+ queue_arguments=self.queue_arguments)
+ queue.declare()
+
+
+class Connection(object):
+ """Connection object."""
+
+ pool = None
+
+ def __init__(self, conf, server_params=None):
+ self.consumers = []
+ self.consumer_thread = None
+ self.proxy_callbacks = []
+ self.conf = conf
+ self.max_retries = self.conf.rabbit_max_retries
+ # Try forever?
+ if self.max_retries <= 0:
+ self.max_retries = None
+ self.interval_start = self.conf.rabbit_retry_interval
+ self.interval_stepping = self.conf.rabbit_retry_backoff
+ # max retry-interval = 30 seconds
+ self.interval_max = 30
+ self.memory_transport = False
+
+ if server_params is None:
+ server_params = {}
+ # Keys to translate from server_params to kombu params
+ server_params_to_kombu_params = {'username': 'userid'}
+
+ ssl_params = self._fetch_ssl_params()
+ params_list = []
+ for adr in self.conf.rabbit_hosts:
+ hostname, port = network_utils.parse_host_port(
+ adr, default_port=self.conf.rabbit_port)
+
+ params = {
+ 'hostname': hostname,
+ 'port': port,
+ 'userid': self.conf.rabbit_userid,
+ 'password': self.conf.rabbit_password,
+ 'virtual_host': self.conf.rabbit_virtual_host,
+ }
+
+ for sp_key, value in server_params.iteritems():
+ p_key = server_params_to_kombu_params.get(sp_key, sp_key)
+ params[p_key] = value
+
+ if self.conf.fake_rabbit:
+ params['transport'] = 'memory'
+ if self.conf.rabbit_use_ssl:
+ params['ssl'] = ssl_params
+
+ params_list.append(params)
+
+ self.params_list = params_list
+
+ self.memory_transport = self.conf.fake_rabbit
+
+ self.connection = None
+ self.reconnect()
+
+ def _fetch_ssl_params(self):
+ """Handles fetching what ssl params should be used for the connection
+ (if any).
+ """
+ ssl_params = dict()
+
+ # http://docs.python.org/library/ssl.html - ssl.wrap_socket
+ if self.conf.kombu_ssl_version:
+ ssl_params['ssl_version'] = sslutils.validate_ssl_version(
+ self.conf.kombu_ssl_version)
+ if self.conf.kombu_ssl_keyfile:
+ ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile
+ if self.conf.kombu_ssl_certfile:
+ ssl_params['certfile'] = self.conf.kombu_ssl_certfile
+ if self.conf.kombu_ssl_ca_certs:
+ ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs
+ # We might want to allow variations in the
+ # future with this?
+ ssl_params['cert_reqs'] = ssl.CERT_REQUIRED
+
+ # Return the extended behavior or just have the default behavior
+ return ssl_params or True
+
+ def _connect(self, params):
+ """Connect to rabbit. Re-establish any queues that may have
+ been declared before if we are reconnecting. Exceptions should
+ be handled by the caller.
+ """
+ if self.connection:
+ LOG.info(_("Reconnecting to AMQP server on "
+ "%(hostname)s:%(port)d") % params)
+ try:
+ self.connection.release()
+ except self.connection_errors:
+ pass
+ # Setting this in case the next statement fails, though
+ # it shouldn't be doing any network operations, yet.
+ self.connection = None
+ self.connection = kombu.connection.BrokerConnection(**params)
+ self.connection_errors = self.connection.connection_errors
+ if self.memory_transport:
+ # Kludge to speed up tests.
+ self.connection.transport.polling_interval = 0.0
+ self.consumer_num = itertools.count(1)
+ self.connection.connect()
+ self.channel = self.connection.channel()
+ # work around 'memory' transport bug in 1.1.3
+ if self.memory_transport:
+ self.channel._new_queue('ae.undeliver')
+ for consumer in self.consumers:
+ consumer.reconnect(self.channel)
+ LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') %
+ params)
+
+ def reconnect(self):
+ """Handles reconnecting and re-establishing queues.
+ Will retry up to self.max_retries number of times.
+ self.max_retries = 0 means to retry forever.
+ Sleep between tries, starting at self.interval_start
+ seconds, backing off self.interval_stepping number of seconds
+ each attempt.
+ """
+
+ attempt = 0
+ while True:
+ params = self.params_list[attempt % len(self.params_list)]
+ attempt += 1
+ try:
+ self._connect(params)
+ return
+ except (IOError, self.connection_errors) as e:
+ pass
+ except Exception as e:
+ # NOTE(comstud): Unfortunately it's possible for amqplib
+ # to return an error not covered by its transport
+ # connection_errors in the case of a timeout waiting for
+ # a protocol response. (See paste link in LP888621)
+ # So, we check all exceptions for 'timeout' in them
+ # and try to reconnect in this case.
+ if 'timeout' not in str(e):
+ raise
+
+ log_info = {}
+ log_info['err_str'] = str(e)
+ log_info['max_retries'] = self.max_retries
+ log_info.update(params)
+
+ if self.max_retries and attempt == self.max_retries:
+ msg = _('Unable to connect to AMQP server on '
+ '%(hostname)s:%(port)d after %(max_retries)d '
+ 'tries: %(err_str)s') % log_info
+ LOG.error(msg)
+ raise rpc_common.RPCException(msg)
+
+ if attempt == 1:
+ sleep_time = self.interval_start or 1
+ elif attempt > 1:
+ sleep_time += self.interval_stepping
+ if self.interval_max:
+ sleep_time = min(sleep_time, self.interval_max)
+
+ log_info['sleep_time'] = sleep_time
+ LOG.error(_('AMQP server on %(hostname)s:%(port)d is '
+ 'unreachable: %(err_str)s. Trying again in '
+ '%(sleep_time)d seconds.') % log_info)
+ time.sleep(sleep_time)
+
+ def ensure(self, error_callback, method, *args, **kwargs):
+ while True:
+ try:
+ return method(*args, **kwargs)
+ except (self.connection_errors, socket.timeout, IOError) as e:
+ if error_callback:
+ error_callback(e)
+ except Exception as e:
+ # NOTE(comstud): Unfortunately it's possible for amqplib
+ # to return an error not covered by its transport
+ # connection_errors in the case of a timeout waiting for
+ # a protocol response. (See paste link in LP888621)
+ # So, we check all exceptions for 'timeout' in them
+ # and try to reconnect in this case.
+ if 'timeout' not in str(e):
+ raise
+ if error_callback:
+ error_callback(e)
+ self.reconnect()
+
+ def get_channel(self):
+ """Convenience call for bin/clear_rabbit_queues."""
+ return self.channel
+
+ def close(self):
+ """Close/release this connection."""
+ self.cancel_consumer_thread()
+ self.wait_on_proxy_callbacks()
+ self.connection.release()
+ self.connection = None
+
+ def reset(self):
+ """Reset a connection so it can be used again."""
+ self.cancel_consumer_thread()
+ self.wait_on_proxy_callbacks()
+ self.channel.close()
+ self.channel = self.connection.channel()
+ # work around 'memory' transport bug in 1.1.3
+ if self.memory_transport:
+ self.channel._new_queue('ae.undeliver')
+ self.consumers = []
+
+ def declare_consumer(self, consumer_cls, topic, callback):
+ """Create a Consumer using the class that was passed in and
+ add it to our list of consumers
+ """
+
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.error(_("Failed to declare consumer for topic '%(topic)s': "
+ "%(err_str)s") % log_info)
+
+ def _declare_consumer():
+ consumer = consumer_cls(self.conf, self.channel, topic, callback,
+ self.consumer_num.next())
+ self.consumers.append(consumer)
+ return consumer
+
+ return self.ensure(_connect_error, _declare_consumer)
+
+ def iterconsume(self, limit=None, timeout=None):
+ """Return an iterator that will consume from all queues/consumers."""
+
+ info = {'do_consume': True}
+
+ def _error_callback(exc):
+ if isinstance(exc, socket.timeout):
+ LOG.debug(_('Timed out waiting for RPC response: %s') %
+ str(exc))
+ raise rpc_common.Timeout()
+ else:
+ LOG.exception(_('Failed to consume message from queue: %s') %
+ str(exc))
+ info['do_consume'] = True
+
+ def _consume():
+ if info['do_consume']:
+ queues_head = self.consumers[:-1] # not fanout.
+ queues_tail = self.consumers[-1] # fanout
+ for queue in queues_head:
+ queue.consume(nowait=True)
+ queues_tail.consume(nowait=False)
+ info['do_consume'] = False
+ return self.connection.drain_events(timeout=timeout)
+
+ for iteration in itertools.count(0):
+ if limit and iteration >= limit:
+ raise StopIteration
+ yield self.ensure(_error_callback, _consume)
+
+ def cancel_consumer_thread(self):
+ """Cancel a consumer thread."""
+ if self.consumer_thread is not None:
+ self.consumer_thread.kill()
+ try:
+ self.consumer_thread.wait()
+ except greenlet.GreenletExit:
+ pass
+ self.consumer_thread = None
+
+ def wait_on_proxy_callbacks(self):
+ """Wait for all proxy callback threads to exit."""
+ for proxy_cb in self.proxy_callbacks:
+ proxy_cb.wait()
+
+ def publisher_send(self, cls, topic, msg, timeout=None, **kwargs):
+ """Send to a publisher based on the publisher class."""
+
+ def _error_callback(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.exception(_("Failed to publish message to topic "
+ "'%(topic)s': %(err_str)s") % log_info)
+
+ def _publish():
+ publisher = cls(self.conf, self.channel, topic, **kwargs)
+ publisher.send(msg, timeout)
+
+ self.ensure(_error_callback, _publish)
+
+ def declare_direct_consumer(self, topic, callback):
+ """Create a 'direct' queue.
+ In nova's use, this is generally a msg_id queue used for
+        responses to call/multicall.
+ """
+ self.declare_consumer(DirectConsumer, topic, callback)
+
+ def declare_topic_consumer(self, topic, callback=None, queue_name=None,
+ exchange_name=None, ack_on_error=True):
+ """Create a 'topic' consumer."""
+ self.declare_consumer(functools.partial(TopicConsumer,
+ name=queue_name,
+ exchange_name=exchange_name,
+ ack_on_error=ack_on_error,
+ ),
+ topic, callback)
+
+ def declare_fanout_consumer(self, topic, callback):
+ """Create a 'fanout' consumer."""
+ self.declare_consumer(FanoutConsumer, topic, callback)
+
+ def direct_send(self, msg_id, msg):
+ """Send a 'direct' message."""
+ self.publisher_send(DirectPublisher, msg_id, msg)
+
+ def topic_send(self, topic, msg, timeout=None):
+ """Send a 'topic' message."""
+ self.publisher_send(TopicPublisher, topic, msg, timeout)
+
+ def fanout_send(self, topic, msg):
+ """Send a 'fanout' message."""
+ self.publisher_send(FanoutPublisher, topic, msg)
+
+ def notify_send(self, topic, msg, **kwargs):
+ """Send a notify message on a topic."""
+ self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
+
+ def consume(self, limit=None):
+ """Consume from all queues/consumers."""
+ it = self.iterconsume(limit=limit)
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+
+ def consume_in_thread(self):
+ """Consumer from all queues/consumers in a greenthread."""
+ @excutils.forever_retry_uncaught_exceptions
+ def _consumer_thread():
+ try:
+ self.consume()
+ except greenlet.GreenletExit:
+ return
+ if self.consumer_thread is None:
+ self.consumer_thread = eventlet.spawn(_consumer_thread)
+ return self.consumer_thread
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ """Create a consumer that calls a method in a proxy object."""
+ proxy_cb = rpc_amqp.ProxyCallback(
+ self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+ self.proxy_callbacks.append(proxy_cb)
+
+ if fanout:
+ self.declare_fanout_consumer(topic, proxy_cb)
+ else:
+ self.declare_topic_consumer(topic, proxy_cb)
+
+ def create_worker(self, topic, proxy, pool_name):
+ """Create a worker that calls a method in a proxy object."""
+ proxy_cb = rpc_amqp.ProxyCallback(
+ self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+ self.proxy_callbacks.append(proxy_cb)
+ self.declare_topic_consumer(topic, proxy_cb, pool_name)
+
+ def join_consumer_pool(self, callback, pool_name, topic,
+ exchange_name=None, ack_on_error=True):
+ """Register as a member of a group of consumers for a given topic from
+ the specified exchange.
+
+ Exactly one member of a given pool will receive each message.
+
+        A message will be delivered to multiple pools if more than
+        one is created.
+ """
+ callback_wrapper = rpc_amqp.CallbackWrapper(
+ conf=self.conf,
+ callback=callback,
+ connection_pool=rpc_amqp.get_connection_pool(self.conf,
+ Connection),
+ )
+ self.proxy_callbacks.append(callback_wrapper)
+ self.declare_topic_consumer(
+ queue_name=pool_name,
+ topic=topic,
+ exchange_name=exchange_name,
+ callback=callback_wrapper,
+ ack_on_error=ack_on_error,
+ )
+
+
+def create_connection(conf, new=True):
+ """Create a connection."""
+ return rpc_amqp.create_connection(
+ conf, new,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+ return rpc_amqp.multicall(
+ conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ return rpc_amqp.call(
+ conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast(conf, context, topic, msg):
+ """Sends a message on a topic without waiting for a response."""
+ return rpc_amqp.cast(
+ conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ return rpc_amqp.fanout_cast(
+ conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a topic to a specific server."""
+ return rpc_amqp.cast_to_server(
+ conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a fanout exchange to a specific server."""
+ return rpc_amqp.fanout_cast_to_server(
+ conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def notify(conf, context, topic, msg, envelope):
+ """Sends a notification event on a topic."""
+ return rpc_amqp.notify(
+ conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection),
+ envelope)
+
+
+def cleanup():
+ return rpc_amqp.cleanup(Connection.pool)
diff --git a/keystone/openstack/common/rpc/impl_qpid.py b/keystone/openstack/common/rpc/impl_qpid.py
new file mode 100644
index 00000000..7e67c81d
--- /dev/null
+++ b/keystone/openstack/common/rpc/impl_qpid.py
@@ -0,0 +1,739 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation
+# Copyright 2011 - 2012, Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import functools
+import itertools
+import time
+import uuid
+
+import eventlet
+import greenlet
+from oslo.config import cfg
+
+from keystone.openstack.common import excutils
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
+from keystone.openstack.common.rpc import amqp as rpc_amqp
+from keystone.openstack.common.rpc import common as rpc_common
+
+qpid_codec = importutils.try_import("qpid.codec010")
+qpid_messaging = importutils.try_import("qpid.messaging")
+qpid_exceptions = importutils.try_import("qpid.messaging.exceptions")
+
+LOG = logging.getLogger(__name__)
+
+qpid_opts = [
+ cfg.StrOpt('qpid_hostname',
+ default='localhost',
+ help='Qpid broker hostname'),
+ cfg.IntOpt('qpid_port',
+ default=5672,
+ help='Qpid broker port'),
+ cfg.ListOpt('qpid_hosts',
+ default=['$qpid_hostname:$qpid_port'],
+ help='Qpid HA cluster host:port pairs'),
+ cfg.StrOpt('qpid_username',
+ default='',
+ help='Username for qpid connection'),
+ cfg.StrOpt('qpid_password',
+ default='',
+ help='Password for qpid connection',
+ secret=True),
+ cfg.StrOpt('qpid_sasl_mechanisms',
+ default='',
+ help='Space separated list of SASL mechanisms to use for auth'),
+ cfg.IntOpt('qpid_heartbeat',
+ default=60,
+ help='Seconds between connection keepalive heartbeats'),
+ cfg.StrOpt('qpid_protocol',
+ default='tcp',
+ help="Transport to use, either 'tcp' or 'ssl'"),
+ cfg.BoolOpt('qpid_tcp_nodelay',
+ default=True,
+ help='Disable Nagle algorithm'),
+]
+
+cfg.CONF.register_opts(qpid_opts)
+
+JSON_CONTENT_TYPE = 'application/json; charset=utf8'
+
+
+class ConsumerBase(object):
+ """Consumer base class."""
+
+ def __init__(self, session, callback, node_name, node_opts,
+ link_name, link_opts):
+ """Declare a queue on an amqp session.
+
+ 'session' is the amqp session to use
+ 'callback' is the callback to call when messages are received
+ 'node_name' is the first part of the Qpid address string, before ';'
+ 'node_opts' will be applied to the "x-declare" section of "node"
+ in the address string.
+ 'link_name' goes into the "name" field of the "link" in the address
+ string
+ 'link_opts' will be applied to the "x-declare" section of "link"
+ in the address string.
+ """
+ self.callback = callback
+ self.receiver = None
+ self.session = None
+
+ addr_opts = {
+ "create": "always",
+ "node": {
+ "type": "topic",
+ "x-declare": {
+ "durable": True,
+ "auto-delete": True,
+ },
+ },
+ "link": {
+ "name": link_name,
+ "durable": True,
+ "x-declare": {
+ "durable": False,
+ "auto-delete": True,
+ "exclusive": False,
+ },
+ },
+ }
+ addr_opts["node"]["x-declare"].update(node_opts)
+ addr_opts["link"]["x-declare"].update(link_opts)
+
+ self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts))
+
+ self.connect(session)
+
+ def connect(self, session):
+ """Declare the reciever on connect."""
+ self._declare_receiver(session)
+
+ def reconnect(self, session):
+ """Re-declare the receiver after a qpid reconnect."""
+ self._declare_receiver(session)
+
+ def _declare_receiver(self, session):
+ self.session = session
+ self.receiver = session.receiver(self.address)
+ self.receiver.capacity = 1
+
+ def _unpack_json_msg(self, msg):
+ """Load the JSON data in msg if msg.content_type indicates that it
+ is necessary. Put the loaded data back into msg.content and
+ update msg.content_type appropriately.
+
+ A Qpid Message containing a dict will have a content_type of
+ 'amqp/map', whereas one containing a string that needs to be converted
+ back from JSON will have a content_type of JSON_CONTENT_TYPE.
+
+ :param msg: a Qpid Message object
+ :returns: None
+ """
+ if msg.content_type == JSON_CONTENT_TYPE:
+ msg.content = jsonutils.loads(msg.content)
+ msg.content_type = 'amqp/map'
+
+ def consume(self):
+ """Fetch the message and pass it to the callback object."""
+ message = self.receiver.fetch()
+ try:
+ self._unpack_json_msg(message)
+ msg = rpc_common.deserialize_msg(message.content)
+ self.callback(msg)
+ except Exception:
+ LOG.exception(_("Failed to process message... skipping it."))
+ finally:
+ # TODO(sandy): Need support for optional ack_on_error.
+ self.session.acknowledge(message)
+
+ def get_receiver(self):
+ return self.receiver
+
+ def get_node_name(self):
+ return self.address.split(';')[0]
+
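+# For illustration, each consumer below builds a Qpid address string of the
+# form '<node_name> ; <JSON-encoded addr_opts>'; e.g. a DirectConsumer
+# listening on msg_id 'abc' uses roughly:
+#
+#   'abc/abc ; {"create": "always", "node": {...}, "link": {...}}'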
+
+class DirectConsumer(ConsumerBase):
+ """Queue/consumer class for 'direct'."""
+
+ def __init__(self, conf, session, msg_id, callback):
+ """Init a 'direct' queue.
+
+ 'session' is the amqp session to use
+ 'msg_id' is the msg_id to listen on
+ 'callback' is the callback to call when messages are received
+ """
+
+ super(DirectConsumer, self).__init__(
+ session, callback,
+ "%s/%s" % (msg_id, msg_id),
+ {"type": "direct"},
+ msg_id,
+ {
+ "auto-delete": conf.amqp_auto_delete,
+ "exclusive": True,
+ "durable": conf.amqp_durable_queues,
+ })
+
+
+class TopicConsumer(ConsumerBase):
+ """Consumer class for 'topic'."""
+
+ def __init__(self, conf, session, topic, callback, name=None,
+ exchange_name=None):
+ """Init a 'topic' queue.
+
+ :param session: the amqp session to use
+ :param topic: is the topic to listen on
+ :paramtype topic: str
+ :param callback: the callback to call when messages are received
+ :param name: optional queue name, defaults to topic
+ """
+
+ exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf)
+ super(TopicConsumer, self).__init__(
+ session, callback,
+ "%s/%s" % (exchange_name, topic),
+ {}, name or topic,
+ {
+ "auto-delete": conf.amqp_auto_delete,
+ "durable": conf.amqp_durable_queues,
+ })
+
+
+class FanoutConsumer(ConsumerBase):
+ """Consumer class for 'fanout'."""
+
+ def __init__(self, conf, session, topic, callback):
+ """Init a 'fanout' queue.
+
+ 'session' is the amqp session to use
+ 'topic' is the topic to listen on
+ 'callback' is the callback to call when messages are received
+ """
+ self.conf = conf
+
+ super(FanoutConsumer, self).__init__(
+ session, callback,
+ "%s_fanout" % topic,
+ {"durable": False, "type": "fanout"},
+ "%s_fanout_%s" % (topic, uuid.uuid4().hex),
+ {"exclusive": True})
+
+ def reconnect(self, session):
+ topic = self.get_node_name().rpartition('_fanout')[0]
+ params = {
+ 'session': session,
+ 'topic': topic,
+ 'callback': self.callback,
+ }
+
+ self.__init__(conf=self.conf, **params)
+
+ super(FanoutConsumer, self).reconnect(session)
+
+
+class Publisher(object):
+ """Base Publisher class."""
+
+ def __init__(self, session, node_name, node_opts=None):
+ """Init the Publisher class with the exchange_name, routing_key,
+ and other options
+ """
+ self.sender = None
+ self.session = session
+
+ addr_opts = {
+ "create": "always",
+ "node": {
+ "type": "topic",
+ "x-declare": {
+ "durable": False,
+ # auto-delete isn't implemented for exchanges in qpid,
+ # but put in here anyway
+ "auto-delete": True,
+ },
+ },
+ }
+ if node_opts:
+ addr_opts["node"]["x-declare"].update(node_opts)
+
+ self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts))
+
+ self.reconnect(session)
+
+ def reconnect(self, session):
+ """Re-establish the Sender after a reconnection."""
+ self.sender = session.sender(self.address)
+
+ def _pack_json_msg(self, msg):
+ """Qpid cannot serialize dicts containing strings longer than 65535
+ characters. This function dumps the message content to a JSON
+ string, which Qpid is able to handle.
+
+ :param msg: May be either a Qpid Message object or a bare dict.
+ :returns: A Qpid Message with its content field JSON encoded.
+ """
+ try:
+ msg.content = jsonutils.dumps(msg.content)
+ except AttributeError:
+ # Need to have a Qpid message so we can set the content_type.
+ msg = qpid_messaging.Message(jsonutils.dumps(msg))
+ msg.content_type = JSON_CONTENT_TYPE
+ return msg
+
+ def send(self, msg):
+ """Send a message."""
+ try:
+ # Check if Qpid can encode the message
+ check_msg = msg
+ if not hasattr(check_msg, 'content_type'):
+ check_msg = qpid_messaging.Message(msg)
+ content_type = check_msg.content_type
+ enc, dec = qpid_messaging.message.get_codec(content_type)
+ enc(check_msg.content)
+ except qpid_codec.CodecException:
+ # This means the message couldn't be serialized as a dict.
+ msg = self._pack_json_msg(msg)
+ self.sender.send(msg)
+
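+# Rough shape of the Publisher.send() fallback (illustrative): qpid cannot
+# encode a dict holding a string longer than 65535 characters, so encoding
+# such content with the native codec raises CodecException and the payload
+# is re-sent as JSON instead:
+#
+#     msg = {'method': 'ping', 'args': {'blob': 'x' * 70000}}
+#     # _pack_json_msg(msg) -> qpid Message('{"method": "ping", ...}')
+#     #                        with content_type == JSON_CONTENT_TYPE
+#
+# ConsumerBase._unpack_json_msg() reverses this on the receiving side.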
+
+class DirectPublisher(Publisher):
+ """Publisher class for 'direct'."""
+ def __init__(self, conf, session, msg_id):
+ """Init a 'direct' publisher."""
+ super(DirectPublisher, self).__init__(session, msg_id,
+ {"type": "direct"})
+
+
+class TopicPublisher(Publisher):
+ """Publisher class for 'topic'."""
+ def __init__(self, conf, session, topic):
+ """init a 'topic' publisher.
+ """
+ exchange_name = rpc_amqp.get_control_exchange(conf)
+ super(TopicPublisher, self).__init__(session,
+ "%s/%s" % (exchange_name, topic))
+
+
+class FanoutPublisher(Publisher):
+ """Publisher class for 'fanout'."""
+ def __init__(self, conf, session, topic):
+ """init a 'fanout' publisher.
+ """
+ super(FanoutPublisher, self).__init__(
+ session,
+ "%s_fanout" % topic, {"type": "fanout"})
+
+
+class NotifyPublisher(Publisher):
+ """Publisher class for notifications."""
+ def __init__(self, conf, session, topic):
+ """init a 'topic' publisher.
+ """
+ exchange_name = rpc_amqp.get_control_exchange(conf)
+ super(NotifyPublisher, self).__init__(session,
+ "%s/%s" % (exchange_name, topic),
+ {"durable": True})
+
+
+class Connection(object):
+ """Connection object."""
+
+ pool = None
+
+ def __init__(self, conf, server_params=None):
+ if not qpid_messaging:
+ raise ImportError("Failed to import qpid.messaging")
+
+ self.session = None
+ self.consumers = {}
+ self.consumer_thread = None
+ self.proxy_callbacks = []
+ self.conf = conf
+
+ if server_params and 'hostname' in server_params:
+ # NOTE(russellb) This enables support for cast_to_server.
+ server_params['qpid_hosts'] = [
+ '%s:%d' % (server_params['hostname'],
+ server_params.get('port', 5672))
+ ]
+
+ params = {
+ 'qpid_hosts': self.conf.qpid_hosts,
+ 'username': self.conf.qpid_username,
+ 'password': self.conf.qpid_password,
+ }
+ params.update(server_params or {})
+
+ self.brokers = params['qpid_hosts']
+ self.username = params['username']
+ self.password = params['password']
+ self.connection_create(self.brokers[0])
+ self.reconnect()
+
+ def connection_create(self, broker):
+ # Create the connection - this does not open the connection
+ self.connection = qpid_messaging.Connection(broker)
+
+ # Check if flags are set and if so set them for the connection
+ # before we call open
+ self.connection.username = self.username
+ self.connection.password = self.password
+
+ self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms
+ # Reconnection is done by self.reconnect()
+ self.connection.reconnect = False
+ self.connection.heartbeat = self.conf.qpid_heartbeat
+ self.connection.transport = self.conf.qpid_protocol
+ self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay
+
+ def _register_consumer(self, consumer):
+ self.consumers[str(consumer.get_receiver())] = consumer
+
+ def _lookup_consumer(self, receiver):
+ return self.consumers[str(receiver)]
+
+ def reconnect(self):
+ """Handles reconnecting and re-establishing sessions and queues."""
+ attempt = 0
+ delay = 1
+ while True:
+ # Close the session if necessary
+ if self.connection.opened():
+ try:
+ self.connection.close()
+ except qpid_exceptions.ConnectionError:
+ pass
+
+ broker = self.brokers[attempt % len(self.brokers)]
+ attempt += 1
+
+ try:
+ self.connection_create(broker)
+ self.connection.open()
+ except qpid_exceptions.ConnectionError as e:
+ msg_dict = dict(e=e, delay=delay)
+ msg = _("Unable to connect to AMQP server: %(e)s. "
+ "Sleeping %(delay)s seconds") % msg_dict
+ LOG.error(msg)
+ time.sleep(delay)
+ delay = min(2 * delay, 60)
+ else:
+ LOG.info(_('Connected to AMQP server on %s'), broker)
+ break
+
+ self.session = self.connection.session()
+
+ if self.consumers:
+ consumers = self.consumers
+ self.consumers = {}
+
+ for consumer in consumers.itervalues():
+ consumer.reconnect(self.session)
+ self._register_consumer(consumer)
+
+ LOG.debug(_("Re-established AMQP queues"))
+
+ def ensure(self, error_callback, method, *args, **kwargs):
+ while True:
+ try:
+ return method(*args, **kwargs)
+ except (qpid_exceptions.Empty,
+ qpid_exceptions.ConnectionError) as e:
+ if error_callback:
+ error_callback(e)
+ self.reconnect()
+
+ def close(self):
+ """Close/release this connection."""
+ self.cancel_consumer_thread()
+ self.wait_on_proxy_callbacks()
+ try:
+ self.connection.close()
+ except Exception:
+ # NOTE(dripton) Logging exceptions that happen during cleanup just
+ # causes confusion; there's really nothing useful we can do with
+ # them.
+ pass
+ self.connection = None
+
+ def reset(self):
+ """Reset a connection so it can be used again."""
+ self.cancel_consumer_thread()
+ self.wait_on_proxy_callbacks()
+ self.session.close()
+ self.session = self.connection.session()
+ self.consumers = {}
+
+ def declare_consumer(self, consumer_cls, topic, callback):
+ """Create a Consumer using the class that was passed in and
+ add it to our list of consumers
+ """
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.error(_("Failed to declare consumer for topic '%(topic)s': "
+ "%(err_str)s") % log_info)
+
+ def _declare_consumer():
+ consumer = consumer_cls(self.conf, self.session, topic, callback)
+ self._register_consumer(consumer)
+ return consumer
+
+ return self.ensure(_connect_error, _declare_consumer)
+
+ def iterconsume(self, limit=None, timeout=None):
+ """Return an iterator that will consume from all queues/consumers."""
+
+ def _error_callback(exc):
+ if isinstance(exc, qpid_exceptions.Empty):
+ LOG.debug(_('Timed out waiting for RPC response: %s') %
+ str(exc))
+ raise rpc_common.Timeout()
+ else:
+ LOG.exception(_('Failed to consume message from queue: %s') %
+ str(exc))
+
+ def _consume():
+ nxt_receiver = self.session.next_receiver(timeout=timeout)
+ try:
+ self._lookup_consumer(nxt_receiver).consume()
+ except Exception:
+ LOG.exception(_("Error processing message. Skipping it."))
+
+ for iteration in itertools.count(0):
+ if limit and iteration >= limit:
+ raise StopIteration
+ yield self.ensure(_error_callback, _consume)
+
+ def cancel_consumer_thread(self):
+ """Cancel a consumer thread."""
+ if self.consumer_thread is not None:
+ self.consumer_thread.kill()
+ try:
+ self.consumer_thread.wait()
+ except greenlet.GreenletExit:
+ pass
+ self.consumer_thread = None
+
+ def wait_on_proxy_callbacks(self):
+ """Wait for all proxy callback threads to exit."""
+ for proxy_cb in self.proxy_callbacks:
+ proxy_cb.wait()
+
+ def publisher_send(self, cls, topic, msg):
+ """Send to a publisher based on the publisher class."""
+
+ def _connect_error(exc):
+ log_info = {'topic': topic, 'err_str': str(exc)}
+ LOG.exception(_("Failed to publish message to topic "
+ "'%(topic)s': %(err_str)s") % log_info)
+
+ def _publisher_send():
+ publisher = cls(self.conf, self.session, topic)
+ publisher.send(msg)
+
+ return self.ensure(_connect_error, _publisher_send)
+
+ def declare_direct_consumer(self, topic, callback):
+ """Create a 'direct' queue.
+ In nova's use, this is generally a msg_id queue used for
+ responses for call/multicall
+ """
+ self.declare_consumer(DirectConsumer, topic, callback)
+
+ def declare_topic_consumer(self, topic, callback=None, queue_name=None,
+ exchange_name=None):
+ """Create a 'topic' consumer."""
+ self.declare_consumer(functools.partial(TopicConsumer,
+ name=queue_name,
+ exchange_name=exchange_name,
+ ),
+ topic, callback)
+
+ def declare_fanout_consumer(self, topic, callback):
+ """Create a 'fanout' consumer."""
+ self.declare_consumer(FanoutConsumer, topic, callback)
+
+ def direct_send(self, msg_id, msg):
+ """Send a 'direct' message."""
+ self.publisher_send(DirectPublisher, msg_id, msg)
+
+ def topic_send(self, topic, msg, timeout=None):
+ """Send a 'topic' message."""
+ #
+ # We want to create a message with attributes, e.g. a TTL. We
+ # don't really need to keep 'msg' in its JSON format any longer
+ # so let's create an actual qpid message here and get some
+ # value-add on the go.
+ #
+ # WARNING: Request timeout happens to be in the same units as
+ # qpid's TTL (seconds). If this changes in the future, then this
+ # will need to be altered accordingly.
+ #
+ qpid_message = qpid_messaging.Message(content=msg, ttl=timeout)
+ self.publisher_send(TopicPublisher, topic, qpid_message)
+
+ def fanout_send(self, topic, msg):
+ """Send a 'fanout' message."""
+ self.publisher_send(FanoutPublisher, topic, msg)
+
+ def notify_send(self, topic, msg, **kwargs):
+ """Send a notify message on a topic."""
+ self.publisher_send(NotifyPublisher, topic, msg)
+
+ def consume(self, limit=None):
+ """Consume from all queues/consumers."""
+ it = self.iterconsume(limit=limit)
+ while True:
+ try:
+ it.next()
+ except StopIteration:
+ return
+
+ def consume_in_thread(self):
+ """Consumer from all queues/consumers in a greenthread."""
+ @excutils.forever_retry_uncaught_exceptions
+ def _consumer_thread():
+ try:
+ self.consume()
+ except greenlet.GreenletExit:
+ return
+ if self.consumer_thread is None:
+ self.consumer_thread = eventlet.spawn(_consumer_thread)
+ return self.consumer_thread
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ """Create a consumer that calls a method in a proxy object."""
+ proxy_cb = rpc_amqp.ProxyCallback(
+ self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+ self.proxy_callbacks.append(proxy_cb)
+
+ if fanout:
+ consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb)
+ else:
+ consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb)
+
+ self._register_consumer(consumer)
+
+ return consumer
+
+ def create_worker(self, topic, proxy, pool_name):
+ """Create a worker that calls a method in a proxy object."""
+ proxy_cb = rpc_amqp.ProxyCallback(
+ self.conf, proxy,
+ rpc_amqp.get_connection_pool(self.conf, Connection))
+ self.proxy_callbacks.append(proxy_cb)
+
+ consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb,
+ name=pool_name)
+
+ self._register_consumer(consumer)
+
+ return consumer
+
+ def join_consumer_pool(self, callback, pool_name, topic,
+ exchange_name=None, ack_on_error=True):
+ """Register as a member of a group of consumers for a given topic from
+ the specified exchange.
+
+ Exactly one member of a given pool will receive each message.
+
+ A message will be delivered to multiple pools, if more than
+ one is created.
+ """
+ callback_wrapper = rpc_amqp.CallbackWrapper(
+ conf=self.conf,
+ callback=callback,
+ connection_pool=rpc_amqp.get_connection_pool(self.conf,
+ Connection),
+ )
+ self.proxy_callbacks.append(callback_wrapper)
+
+ consumer = TopicConsumer(conf=self.conf,
+ session=self.session,
+ topic=topic,
+ callback=callback_wrapper,
+ name=pool_name,
+ exchange_name=exchange_name)
+
+ self._register_consumer(consumer)
+ return consumer
+
+
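+# Pool semantics sketch (illustrative; connection and callback names are
+# hypothetical): consumers sharing a pool_name share one queue, so each
+# message on the topic reaches exactly one of them, while a second pool
+# receives its own copy:
+#
+#     conn_a.join_consumer_pool(cb_a, pool_name='workers', topic='jobs')
+#     conn_b.join_consumer_pool(cb_b, pool_name='workers', topic='jobs')
+#     conn_c.join_consumer_pool(cb_c, pool_name='audit', topic='jobs')
+#     # a 'jobs' message -> exactly one of cb_a/cb_b, plus cb_c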
+def create_connection(conf, new=True):
+ """Create a connection."""
+ return rpc_amqp.create_connection(
+ conf, new,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def multicall(conf, context, topic, msg, timeout=None):
+ """Make a call that returns multiple times."""
+ return rpc_amqp.multicall(
+ conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def call(conf, context, topic, msg, timeout=None):
+ """Sends a message on a topic and wait for a response."""
+ return rpc_amqp.call(
+ conf, context, topic, msg, timeout,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast(conf, context, topic, msg):
+ """Sends a message on a topic without waiting for a response."""
+ return rpc_amqp.cast(
+ conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast(conf, context, topic, msg):
+ """Sends a message on a fanout exchange without waiting for a response."""
+ return rpc_amqp.fanout_cast(
+ conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a topic to a specific server."""
+ return rpc_amqp.cast_to_server(
+ conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def fanout_cast_to_server(conf, context, server_params, topic, msg):
+ """Sends a message on a fanout exchange to a specific server."""
+ return rpc_amqp.fanout_cast_to_server(
+ conf, context, server_params, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection))
+
+
+def notify(conf, context, topic, msg, envelope):
+ """Sends a notification event on a topic."""
+ return rpc_amqp.notify(conf, context, topic, msg,
+ rpc_amqp.get_connection_pool(conf, Connection),
+ envelope)
+
+
+def cleanup():
+ return rpc_amqp.cleanup(Connection.pool)
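+
+# Module-level usage sketch (illustrative; 'conf', 'ctxt' and 'my_callback'
+# are assumed to be provided by the caller):
+#
+#     conn = create_connection(conf)
+#     conn.declare_topic_consumer('compute', my_callback)
+#     conn.consume_in_thread()
+#     cast(conf, ctxt, 'compute', {'method': 'ping', 'args': {}})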
diff --git a/keystone/openstack/common/rpc/impl_zmq.py b/keystone/openstack/common/rpc/impl_zmq.py
new file mode 100644
index 00000000..1aaf8575
--- /dev/null
+++ b/keystone/openstack/common/rpc/impl_zmq.py
@@ -0,0 +1,817 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import pprint
+import re
+import socket
+import sys
+import types
+import uuid
+
+import eventlet
+import greenlet
+from oslo.config import cfg
+
+from keystone.openstack.common import excutils
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common.rpc import common as rpc_common
+
+zmq = importutils.try_import('eventlet.green.zmq')
+
+# Aliased here for convenience; these are not modified.
+pformat = pprint.pformat
+Timeout = eventlet.timeout.Timeout
+LOG = rpc_common.LOG
+RemoteError = rpc_common.RemoteError
+RPCException = rpc_common.RPCException
+
+zmq_opts = [
+ cfg.StrOpt('rpc_zmq_bind_address', default='*',
+ help='ZeroMQ bind address. Should be a wildcard (*), '
+ 'an ethernet interface, or IP. '
+ 'The "host" option should point or resolve to this '
+ 'address.'),
+
+ # The module.Class to use for matchmaking.
+ cfg.StrOpt(
+ 'rpc_zmq_matchmaker',
+ default=('keystone.openstack.common.rpc.'
+ 'matchmaker.MatchMakerLocalhost'),
+ help='MatchMaker driver',
+ ),
+
+ # The following port is unassigned by IANA as of 2012-05-21
+ cfg.IntOpt('rpc_zmq_port', default=9501,
+ help='ZeroMQ receiver listening port'),
+
+ cfg.IntOpt('rpc_zmq_contexts', default=1,
+ help='Number of ZeroMQ contexts, defaults to 1'),
+
+ cfg.IntOpt('rpc_zmq_topic_backlog', default=None,
+ help='Maximum number of ingress messages to locally buffer '
+ 'per topic. Default is unlimited.'),
+
+ cfg.StrOpt('rpc_zmq_ipc_dir', default='/var/run/openstack',
+ help='Directory for holding IPC sockets'),
+
+ cfg.StrOpt('rpc_zmq_host', default=socket.gethostname(),
+ help='Name of this node. Must be a valid hostname, FQDN, or '
+ 'IP address. Must match "host" option, if running Nova.')
+]
+
+
+CONF = cfg.CONF
+CONF.register_opts(zmq_opts)
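+# Illustrative only: a config fragment exercising these options might read
+#
+#     [DEFAULT]
+#     rpc_zmq_host = node1.example.org
+#     rpc_zmq_ipc_dir = /var/run/openstack
+#     rpc_zmq_matchmaker = keystone.openstack.common.rpc.matchmaker_ring.MatchMakerRing
+#
+# The host and matchmaker values are hypothetical; defaults come from
+# zmq_opts above.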
+
+ZMQ_CTX = None # ZeroMQ Context, must be global.
+matchmaker = None # memoized matchmaker object
+
+
+def _serialize(data):
+ """Serialization wrapper.
+
+ We prefer using JSON, but it cannot encode all types.
+ Error if a developer passes us bad data.
+ """
+ try:
+ return jsonutils.dumps(data, ensure_ascii=True)
+ except TypeError:
+ with excutils.save_and_reraise_exception():
+ LOG.error(_("JSON serialization failed."))
+
+
+def _deserialize(data):
+ """Deserialization wrapper."""
+ LOG.debug(_("Deserializing: %s"), data)
+ return jsonutils.loads(data)
+
+
+class ZmqSocket(object):
+ """A tiny wrapper around ZeroMQ.
+
+ Simplifies the send/recv protocol and connection management.
+ Can be used as a Context (supports the 'with' statement).
+ """
+
+ def __init__(self, addr, zmq_type, bind=True, subscribe=None):
+ self.sock = _get_ctxt().socket(zmq_type)
+ self.addr = addr
+ self.type = zmq_type
+ self.subscriptions = []
+
+ # Support failures on sending/receiving on wrong socket type.
+ self.can_recv = zmq_type in (zmq.PULL, zmq.SUB)
+ self.can_send = zmq_type in (zmq.PUSH, zmq.PUB)
+ self.can_sub = zmq_type in (zmq.SUB, )
+
+ # Support list, str, & None for subscribe arg (cast to list)
+ do_sub = {
+ list: subscribe,
+ str: [subscribe],
+ type(None): []
+ }[type(subscribe)]
+
+ for f in do_sub:
+ self.subscribe(f)
+
+ str_data = {'addr': addr, 'type': self.socket_s(),
+ 'subscribe': subscribe, 'bind': bind}
+
+ LOG.debug(_("Connecting to %(addr)s with %(type)s"), str_data)
+ LOG.debug(_("-> Subscribed to %(subscribe)s"), str_data)
+ LOG.debug(_("-> bind: %(bind)s"), str_data)
+
+ try:
+ if bind:
+ self.sock.bind(addr)
+ else:
+ self.sock.connect(addr)
+ except Exception:
+ raise RPCException(_("Could not open socket."))
+
+ def socket_s(self):
+ """Get socket type as string."""
+ t_enum = ('PUSH', 'PULL', 'PUB', 'SUB', 'REP', 'REQ', 'ROUTER',
+ 'DEALER')
+ return dict(map(lambda t: (getattr(zmq, t), t), t_enum))[self.type]
+
+ def subscribe(self, msg_filter):
+ """Subscribe."""
+ if not self.can_sub:
+ raise RPCException("Cannot subscribe on this socket.")
+ LOG.debug(_("Subscribing to %s"), msg_filter)
+
+ try:
+ self.sock.setsockopt(zmq.SUBSCRIBE, msg_filter)
+ except Exception:
+ return
+
+ self.subscriptions.append(msg_filter)
+
+ def unsubscribe(self, msg_filter):
+ """Unsubscribe."""
+ if msg_filter not in self.subscriptions:
+ return
+ self.sock.setsockopt(zmq.UNSUBSCRIBE, msg_filter)
+ self.subscriptions.remove(msg_filter)
+
+ def close(self):
+ if self.sock is None or self.sock.closed:
+ return
+
+ # We must unsubscribe, or we'll leak descriptors.
+ if self.subscriptions:
+ for f in self.subscriptions:
+ try:
+ self.sock.setsockopt(zmq.UNSUBSCRIBE, f)
+ except Exception:
+ pass
+ self.subscriptions = []
+
+ try:
+ # Default is to linger
+ self.sock.close()
+ except Exception:
+ # While this is a bad thing to happen,
+ # it would be much worse if some of the code calling this
+ # were to fail. For now, let's log, and later evaluate
+ # if we can safely raise here.
+ LOG.error("ZeroMQ socket could not be closed.")
+ self.sock = None
+
+ def recv(self, **kwargs):
+ if not self.can_recv:
+ raise RPCException(_("You cannot recv on this socket."))
+ return self.sock.recv_multipart(**kwargs)
+
+ def send(self, data, **kwargs):
+ if not self.can_send:
+ raise RPCException(_("You cannot send on this socket."))
+ self.sock.send_multipart(data, **kwargs)
+
+
+class ZmqClient(object):
+ """Client for ZMQ sockets."""
+
+ def __init__(self, addr):
+ self.outq = ZmqSocket(addr, zmq.PUSH, bind=False)
+
+ def cast(self, msg_id, topic, data, envelope):
+ msg_id = msg_id or 0
+
+ if not envelope:
+ self.outq.send(map(bytes,
+ (msg_id, topic, 'cast', _serialize(data))))
+ return
+
+ rpc_envelope = rpc_common.serialize_msg(data[1], envelope)
+ zmq_msg = reduce(lambda x, y: x + y, rpc_envelope.items())
+ self.outq.send(map(bytes,
+ (msg_id, topic, 'impl_zmq_v2', data[0]) + zmq_msg))
+
+ def close(self):
+ self.outq.close()
+
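+# Wire format sketch (illustrative): the multipart frames produced by
+# ZmqClient.cast() are roughly
+#
+#     legacy:   [msg_id, topic, 'cast', _serialize([marshalled_ctx, msg])]
+#     envelope: [msg_id, topic, 'impl_zmq_v2', marshalled_ctx, k1, v1, ...]
+#
+# where the trailing key/value frames are the flattened rpc envelope that
+# unflatten_envelope() below rebuilds into a dict.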
+
+class RpcContext(rpc_common.CommonRpcContext):
+ """Context that supports replying to a rpc.call."""
+ def __init__(self, **kwargs):
+ self.replies = []
+ super(RpcContext, self).__init__(**kwargs)
+
+ def deepcopy(self):
+ values = self.to_dict()
+ values['replies'] = self.replies
+ return self.__class__(**values)
+
+ def reply(self, reply=None, failure=None, ending=False):
+ if ending:
+ return
+ self.replies.append(reply)
+
+ @classmethod
+ def marshal(self, ctx):
+ ctx_data = ctx.to_dict()
+ return _serialize(ctx_data)
+
+ @classmethod
+ def unmarshal(self, data):
+ return RpcContext.from_dict(_deserialize(data))
+
+
+class InternalContext(object):
+ """Used by ConsumerBase as a private context for - methods."""
+
+ def __init__(self, proxy):
+ self.proxy = proxy
+ self.msg_waiter = None
+
+ def _get_response(self, ctx, proxy, topic, data):
+ """Process a curried message and cast the result to topic."""
+ LOG.debug(_("Running func with context: %s"), ctx.to_dict())
+ data.setdefault('version', None)
+ data.setdefault('args', {})
+
+ try:
+ result = proxy.dispatch(
+ ctx, data['version'], data['method'],
+ data.get('namespace'), **data['args'])
+ return ConsumerBase.normalize_reply(result, ctx.replies)
+ except greenlet.GreenletExit:
+ # ignore these since they are just from shutdowns
+ pass
+ except rpc_common.ClientException as e:
+ LOG.debug(_("Expected exception during message handling (%s)") %
+ e._exc_info[1])
+ return {'exc':
+ rpc_common.serialize_remote_exception(e._exc_info,
+ log_failure=False)}
+ except Exception:
+ LOG.error(_("Exception during message handling"))
+ return {'exc':
+ rpc_common.serialize_remote_exception(sys.exc_info())}
+
+ def reply(self, ctx, proxy,
+ msg_id=None, context=None, topic=None, msg=None):
+ """Reply to a casted call."""
+ # NOTE(ewindisch): context kwarg exists for Grizzly compat.
+ # this may be able to be removed earlier than
+ # 'I' if ConsumerBase.process were refactored.
+ if type(msg) is list:
+ payload = msg[-1]
+ else:
+ payload = msg
+
+ response = ConsumerBase.normalize_reply(
+ self._get_response(ctx, proxy, topic, payload),
+ ctx.replies)
+
+ LOG.debug(_("Sending reply"))
+ _multi_send(_cast, ctx, topic, {
+ 'method': '-process_reply',
+ 'args': {
+ 'msg_id': msg_id, # Include for Folsom compat.
+ 'response': response
+ }
+ }, _msg_id=msg_id)
+
+
+class ConsumerBase(object):
+ """Base Consumer."""
+
+ def __init__(self):
+ self.private_ctx = InternalContext(None)
+
+ @classmethod
+ def normalize_reply(self, result, replies):
+ #TODO(ewindisch): re-evaluate and document this method.
+ if isinstance(result, types.GeneratorType):
+ return list(result)
+ elif replies:
+ return replies
+ else:
+ return [result]
+
+ def process(self, proxy, ctx, data):
+ data.setdefault('version', None)
+ data.setdefault('args', {})
+
+ # Methods starting with '-' are
+ # processed internally ('-' is not a valid method name).
+ method = data.get('method')
+ if not method:
+ LOG.error(_("RPC message did not include method."))
+ return
+
+ # Internal method
+ # uses internal context for safety.
+ if method == '-reply':
+ self.private_ctx.reply(ctx, proxy, **data['args'])
+ return
+
+ proxy.dispatch(ctx, data['version'],
+ data['method'], data.get('namespace'), **data['args'])
+
+
+class ZmqBaseReactor(ConsumerBase):
+ """A consumer class implementing a centralized casting broker (PULL-PUSH).
+
+ Used for RoundRobin requests.
+ """
+
+ def __init__(self, conf):
+ super(ZmqBaseReactor, self).__init__()
+
+ self.proxies = {}
+ self.threads = []
+ self.sockets = []
+ self.subscribe = {}
+
+ self.pool = eventlet.greenpool.GreenPool(conf.rpc_thread_pool_size)
+
+ def register(self, proxy, in_addr, zmq_type_in,
+ in_bind=True, subscribe=None):
+
+ LOG.info(_("Registering reactor"))
+
+ if zmq_type_in not in (zmq.PULL, zmq.SUB):
+ raise RPCException("Bad input socktype")
+
+ # Items push in.
+ inq = ZmqSocket(in_addr, zmq_type_in, bind=in_bind,
+ subscribe=subscribe)
+
+ self.proxies[inq] = proxy
+ self.sockets.append(inq)
+
+ LOG.info(_("In reactor registered"))
+
+ def consume_in_thread(self):
+ def _consume(sock):
+ LOG.info(_("Consuming socket"))
+ while True:
+ self.consume(sock)
+
+ for k in self.proxies.keys():
+ self.threads.append(
+ self.pool.spawn(_consume, k)
+ )
+
+ def wait(self):
+ for t in self.threads:
+ t.wait()
+
+ def close(self):
+ for s in self.sockets:
+ s.close()
+
+ for t in self.threads:
+ t.kill()
+
+
+class ZmqProxy(ZmqBaseReactor):
+ """A consumer class implementing a topic-based proxy.
+
+ Forwards to IPC sockets.
+ """
+
+ def __init__(self, conf):
+ super(ZmqProxy, self).__init__(conf)
+ pathsep = set((os.path.sep or '', os.path.altsep or '', '/', '\\'))
+ self.badchars = re.compile(r'[%s]' % re.escape(''.join(pathsep)))
+
+ self.topic_proxy = {}
+
+ def consume(self, sock):
+ ipc_dir = CONF.rpc_zmq_ipc_dir
+
+ data = sock.recv(copy=False)
+ topic = data[1].bytes
+
+ if topic.startswith('fanout~'):
+ sock_type = zmq.PUB
+ topic = topic.split('.', 1)[0]
+ elif topic.startswith('zmq_replies'):
+ sock_type = zmq.PUB
+ else:
+ sock_type = zmq.PUSH
+
+ if topic not in self.topic_proxy:
+ def publisher(waiter):
+ LOG.info(_("Creating proxy for topic: %s"), topic)
+
+ try:
+ # The topic is received over the network,
+ # don't trust this input.
+ if self.badchars.search(topic) is not None:
+ emsg = _("Topic contained dangerous characters.")
+ LOG.warn(emsg)
+ raise RPCException(emsg)
+
+ out_sock = ZmqSocket("ipc://%s/zmq_topic_%s" %
+ (ipc_dir, topic),
+ sock_type, bind=True)
+ except RPCException:
+ waiter.send_exception(*sys.exc_info())
+ return
+
+ self.topic_proxy[topic] = eventlet.queue.LightQueue(
+ CONF.rpc_zmq_topic_backlog)
+ self.sockets.append(out_sock)
+
+ # It takes some time for a pub socket to open,
+ # before we can have any faith in doing a send() to it.
+ if sock_type == zmq.PUB:
+ eventlet.sleep(.5)
+
+ waiter.send(True)
+
+ while(True):
+ data = self.topic_proxy[topic].get()
+ out_sock.send(data, copy=False)
+
+ wait_sock_creation = eventlet.event.Event()
+ eventlet.spawn(publisher, wait_sock_creation)
+
+ try:
+ wait_sock_creation.wait()
+ except RPCException:
+ LOG.error(_("Topic socket file creation failed."))
+ return
+
+ try:
+ self.topic_proxy[topic].put_nowait(data)
+ except eventlet.queue.Full:
+ LOG.error(_("Local per-topic backlog buffer full for topic "
+ "%(topic)s. Dropping message.") % {'topic': topic})
+
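+ # Routing sketch (illustrative; host names are hypothetical): the proxy
+ # picks the outbound IPC socket type from the topic prefix it received:
+ #
+ #     'fanout~scheduler'  -> PUB  on ipc://<ipc_dir>/zmq_topic_fanout~scheduler
+ #     'zmq_replies.host1' -> PUB  on ipc://<ipc_dir>/zmq_topic_zmq_replies.host1
+ #     'scheduler.host1'   -> PUSH on ipc://<ipc_dir>/zmq_topic_scheduler.host1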
+ def consume_in_thread(self):
+ """Runs the ZmqProxy service."""
+ ipc_dir = CONF.rpc_zmq_ipc_dir
+ consume_in = "tcp://%s:%s" % \
+ (CONF.rpc_zmq_bind_address,
+ CONF.rpc_zmq_port)
+ consumption_proxy = InternalContext(None)
+
+ try:
+ os.makedirs(ipc_dir)
+ except os.error:
+ if not os.path.isdir(ipc_dir):
+ with excutils.save_and_reraise_exception():
+ LOG.error(_("Required IPC directory does not exist at"
+ " %s") % (ipc_dir, ))
+ try:
+ self.register(consumption_proxy,
+ consume_in,
+ zmq.PULL)
+ except zmq.ZMQError:
+ if os.access(ipc_dir, os.X_OK):
+ with excutils.save_and_reraise_exception():
+ LOG.error(_("Permission denied to IPC directory at"
+ " %s") % (ipc_dir, ))
+ with excutils.save_and_reraise_exception():
+ LOG.error(_("Could not create ZeroMQ receiver daemon. "
+ "Socket may already be in use."))
+
+ super(ZmqProxy, self).consume_in_thread()
+
+
+def unflatten_envelope(packenv):
+ """Unflattens the RPC envelope.
+
+ Takes a list and returns a dictionary.
+ i.e. [1,2,3,4] => {1: 2, 3: 4}
+ """
+ i = iter(packenv)
+ h = {}
+ try:
+ while True:
+ k = i.next()
+ h[k] = i.next()
+ except StopIteration:
+ return h
+
+
+class ZmqReactor(ZmqBaseReactor):
+ """A consumer class implementing a consumer for messages.
+
+ Can also be used as a 1:1 proxy
+ """
+
+ def __init__(self, conf):
+ super(ZmqReactor, self).__init__(conf)
+
+ def consume(self, sock):
+ #TODO(ewindisch): use zero-copy (i.e. references, not copying)
+ data = sock.recv()
+ LOG.debug(_("CONSUMER RECEIVED DATA: %s"), data)
+
+ proxy = self.proxies[sock]
+
+ if data[2] == 'cast': # Legacy protocol
+ packenv = data[3]
+
+ ctx, msg = _deserialize(packenv)
+ request = rpc_common.deserialize_msg(msg)
+ ctx = RpcContext.unmarshal(ctx)
+ elif data[2] == 'impl_zmq_v2':
+ packenv = data[4:]
+
+ msg = unflatten_envelope(packenv)
+ request = rpc_common.deserialize_msg(msg)
+
+ # Unmarshal only after verifying the message.
+ ctx = RpcContext.unmarshal(data[3])
+ else:
+ LOG.error(_("ZMQ Envelope version unsupported or unknown."))
+ return
+
+ self.pool.spawn_n(self.process, proxy, ctx, request)
+
+
+class Connection(rpc_common.Connection):
+ """Manages connections and threads."""
+
+ def __init__(self, conf):
+ self.topics = []
+ self.reactor = ZmqReactor(conf)
+
+ def create_consumer(self, topic, proxy, fanout=False):
+ # Register with matchmaker.
+ _get_matchmaker().register(topic, CONF.rpc_zmq_host)
+
+ # Subscription scenarios
+ if fanout:
+ sock_type = zmq.SUB
+ subscribe = ('', fanout)[type(fanout) == str]
+ topic = 'fanout~' + topic.split('.', 1)[0]
+ else:
+ sock_type = zmq.PULL
+ subscribe = None
+ topic = '.'.join((topic.split('.', 1)[0], CONF.rpc_zmq_host))
+
+ if topic in self.topics:
+ LOG.info(_("Skipping topic registration. Already registered."))
+ return
+
+ # Receive messages from (local) proxy
+ inaddr = "ipc://%s/zmq_topic_%s" % \
+ (CONF.rpc_zmq_ipc_dir, topic)
+
+ LOG.debug(_("Consumer is a zmq.%s"),
+ ['PULL', 'SUB'][sock_type == zmq.SUB])
+
+ self.reactor.register(proxy, inaddr, sock_type,
+ subscribe=subscribe, in_bind=False)
+ self.topics.append(topic)
+
+ def close(self):
+ _get_matchmaker().stop_heartbeat()
+ for topic in self.topics:
+ _get_matchmaker().unregister(topic, CONF.rpc_zmq_host)
+
+ self.reactor.close()
+ self.topics = []
+
+ def wait(self):
+ self.reactor.wait()
+
+ def consume_in_thread(self):
+ _get_matchmaker().start_heartbeat()
+ self.reactor.consume_in_thread()
+
+
+def _cast(addr, context, topic, msg, timeout=None, envelope=False,
+ _msg_id=None):
+ timeout_cast = timeout or CONF.rpc_cast_timeout
+ payload = [RpcContext.marshal(context), msg]
+
+ with Timeout(timeout_cast, exception=rpc_common.Timeout):
+ try:
+ conn = ZmqClient(addr)
+
+ # assumes cast can't return an exception
+ conn.cast(_msg_id, topic, payload, envelope)
+ except zmq.ZMQError:
+ raise RPCException("Cast failed. ZMQ Socket Exception")
+ finally:
+ if 'conn' in vars():
+ conn.close()
+
+
+def _call(addr, context, topic, msg, timeout=None,
+ envelope=False):
+ # timeout_response is how long we wait for a response
+ timeout = timeout or CONF.rpc_response_timeout
+
+ # The msg_id is used to track replies.
+ msg_id = uuid.uuid4().hex
+
+ # Replies always come into the reply service.
+ reply_topic = "zmq_replies.%s" % CONF.rpc_zmq_host
+
+ LOG.debug(_("Creating payload"))
+ # Curry the original request into a reply method.
+ mcontext = RpcContext.marshal(context)
+ payload = {
+ 'method': '-reply',
+ 'args': {
+ 'msg_id': msg_id,
+ 'topic': reply_topic,
+ # TODO(ewindisch): safe to remove mcontext in I.
+ 'msg': [mcontext, msg]
+ }
+ }
+
+ LOG.debug(_("Creating queue socket for reply waiter"))
+
+ # Messages arriving async.
+ # TODO(ewindisch): have reply consumer with dynamic subscription mgmt
+ with Timeout(timeout, exception=rpc_common.Timeout):
+ try:
+ msg_waiter = ZmqSocket(
+ "ipc://%s/zmq_topic_zmq_replies.%s" %
+ (CONF.rpc_zmq_ipc_dir,
+ CONF.rpc_zmq_host),
+ zmq.SUB, subscribe=msg_id, bind=False
+ )
+
+ LOG.debug(_("Sending cast"))
+ _cast(addr, context, topic, payload, envelope)
+
+ LOG.debug(_("Cast sent; Waiting reply"))
+ # Blocks until a reply is received
+ msg = msg_waiter.recv()
+ LOG.debug(_("Received message: %s"), msg)
+ LOG.debug(_("Unpacking response"))
+
+ if msg[2] == 'cast': # Legacy version
+ raw_msg = _deserialize(msg[-1])[-1]
+ elif msg[2] == 'impl_zmq_v2':
+ rpc_envelope = unflatten_envelope(msg[4:])
+ raw_msg = rpc_common.deserialize_msg(rpc_envelope)
+ else:
+ raise rpc_common.UnsupportedRpcEnvelopeVersion(
+ _("Unsupported or unknown ZMQ envelope returned."))
+
+ responses = raw_msg['args']['response']
+ # ZMQError trumps the Timeout error.
+ except zmq.ZMQError:
+ raise RPCException("ZMQ Socket Error")
+ except (IndexError, KeyError):
+ raise RPCException(_("RPC Message Invalid."))
+ finally:
+ if 'msg_waiter' in vars():
+ msg_waiter.close()
+
+ # It seems we don't need to do all of the following,
+ # but perhaps it would be useful for multicall?
+ # One effect of this is that we're checking all
+ # responses for Exceptions.
+ for resp in responses:
+ if isinstance(resp, types.DictType) and 'exc' in resp:
+ raise rpc_common.deserialize_remote_exception(CONF, resp['exc'])
+
+ return responses[-1]
+
+
+def _multi_send(method, context, topic, msg, timeout=None,
+ envelope=False, _msg_id=None):
+ """Wraps the sending of messages.
+
+ Dispatches to the matchmaker and sends message to all relevant hosts.
+ """
+ conf = CONF
+ LOG.debug(_("%(msg)s") % {'msg': ' '.join(map(pformat, (topic, msg)))})
+
+ queues = _get_matchmaker().queues(topic)
+ LOG.debug(_("Sending message(s) to: %s"), queues)
+
+ # Don't stack if we have no matchmaker results
+ if not queues:
+ LOG.warn(_("No matchmaker results. Not casting."))
+ # While not strictly a timeout, callers know how to handle
+ # this exception and a timeout isn't too big a lie.
+ raise rpc_common.Timeout(_("No match from matchmaker."))
+
+ # This supports brokerless fanout (addresses > 1)
+ for queue in queues:
+ (_topic, ip_addr) = queue
+ _addr = "tcp://%s:%s" % (ip_addr, conf.rpc_zmq_port)
+
+ if method.__name__ == '_cast':
+ eventlet.spawn_n(method, _addr, context,
+ _topic, msg, timeout, envelope,
+ _msg_id)
+ return
+ return method(_addr, context, _topic, msg, timeout,
+ envelope)
+
+
+def create_connection(conf, new=True):
+ return Connection(conf)
+
+
+def multicall(conf, *args, **kwargs):
+ """Multiple calls."""
+ return _multi_send(_call, *args, **kwargs)
+
+
+def call(conf, *args, **kwargs):
+ """Send a message, expect a response."""
+ data = _multi_send(_call, *args, **kwargs)
+ return data[-1]
+
+
+def cast(conf, *args, **kwargs):
+ """Send a message expecting no reply."""
+ _multi_send(_cast, *args, **kwargs)
+
+
+def fanout_cast(conf, context, topic, msg, **kwargs):
+ """Send a message to all listening and expect no reply."""
+ # NOTE(ewindisch): fanout~ is used because it avoids splitting on .
+ # and acts as a non-subtle hint to the matchmaker and ZmqProxy.
+ _multi_send(_cast, context, 'fanout~' + str(topic), msg, **kwargs)
+
+
+def notify(conf, context, topic, msg, envelope):
+ """Send notification event.
+
+ Notifications are sent to topic-priority.
+ This differs from the AMQP drivers which send to topic.priority.
+ """
+ # NOTE(ewindisch): dot-priority in rpc notifier does not
+ # work with our assumptions.
+ topic = topic.replace('.', '-')
+ cast(conf, context, topic, msg, envelope=envelope)
+
+
+def cleanup():
+ """Clean up resources in use by implementation."""
+ global ZMQ_CTX
+ if ZMQ_CTX:
+ ZMQ_CTX.term()
+ ZMQ_CTX = None
+
+ global matchmaker
+ matchmaker = None
+
+
+def _get_ctxt():
+ if not zmq:
+ raise ImportError("Failed to import eventlet.green.zmq")
+
+ global ZMQ_CTX
+ if not ZMQ_CTX:
+ ZMQ_CTX = zmq.Context(CONF.rpc_zmq_contexts)
+ return ZMQ_CTX
+
+
+def _get_matchmaker(*args, **kwargs):
+ global matchmaker
+ if not matchmaker:
+ mm = CONF.rpc_zmq_matchmaker
+ if mm.endswith('matchmaker.MatchMakerRing'):
+ mm = mm.replace('matchmaker', 'matchmaker_ring')
+ LOG.warn(_('rpc_zmq_matchmaker = %(orig)s is deprecated; use'
+ ' %(new)s instead') % dict(
+ orig=CONF.rpc_zmq_matchmaker, new=mm))
+ matchmaker = importutils.import_object(mm, *args, **kwargs)
+ return matchmaker
diff --git a/keystone/openstack/common/rpc/matchmaker.py b/keystone/openstack/common/rpc/matchmaker.py
new file mode 100644
index 00000000..ff3fcbc7
--- /dev/null
+++ b/keystone/openstack/common/rpc/matchmaker.py
@@ -0,0 +1,324 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+The MatchMaker classes should accept a Topic or Fanout exchange key and
+return keys for direct exchanges, per (approximate) AMQP parlance.
+"""
+
+import contextlib
+
+import eventlet
+from oslo.config import cfg
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+
+
+matchmaker_opts = [
+ cfg.IntOpt('matchmaker_heartbeat_freq',
+ default=300,
+ help='Heartbeat frequency'),
+ cfg.IntOpt('matchmaker_heartbeat_ttl',
+ default=600,
+ help='Heartbeat time-to-live.'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(matchmaker_opts)
+LOG = logging.getLogger(__name__)
+contextmanager = contextlib.contextmanager
+
+
+class MatchMakerException(Exception):
+ """Signified a match could not be found."""
+ message = _("Match not found by MatchMaker.")
+
+
+class Exchange(object):
+ """Implements lookups.
+
+ Subclass this to support hashtables, dns, etc.
+ """
+ def __init__(self):
+ pass
+
+ def run(self, key):
+ raise NotImplementedError()
+
+
+class Binding(object):
+ """A binding on which to perform a lookup."""
+ def __init__(self):
+ pass
+
+ def test(self, key):
+ raise NotImplementedError()
+
+
+class MatchMakerBase(object):
+ """Match Maker Base Class.
+
+ Build off HeartbeatMatchMakerBase if building a heartbeat-capable
+ MatchMaker.
+ """
+ def __init__(self):
+ # Array of tuples. Index [2] toggles negation, [3] is last-if-true
+ self.bindings = []
+
+ self.no_heartbeat_msg = _('Matchmaker does not implement '
+ 'registration or heartbeat.')
+
+ def register(self, key, host):
+ """Register a host on a backend.
+
+ Heartbeats, if applicable, may keepalive registration.
+ """
+ pass
+
+ def ack_alive(self, key, host):
+ """Acknowledge that a key.host is alive.
+
+ Used internally for updating heartbeats, but may also be used
+ publicly to acknowledge a system is alive (i.e. rpc message
+ successfully sent to host)
+ """
+ pass
+
+ def is_alive(self, topic, host):
+ """Checks if a host is alive."""
+ pass
+
+ def expire(self, topic, host):
+ """Explicitly expire a host's registration."""
+ pass
+
+ def send_heartbeats(self):
+ """Send all heartbeats.
+
+ Use start_heartbeat to spawn a heartbeat greenthread,
+ which loops this method.
+ """
+ pass
+
+ def unregister(self, key, host):
+ """Unregister a topic."""
+ pass
+
+ def start_heartbeat(self):
+ """Spawn heartbeat greenthread."""
+ pass
+
+ def stop_heartbeat(self):
+ """Destroys the heartbeat greenthread."""
+ pass
+
+ def add_binding(self, binding, rule, last=True):
+ self.bindings.append((binding, rule, False, last))
+
+ #NOTE(ewindisch): kept the following method in case we implement the
+ # underlying support.
+ #def add_negate_binding(self, binding, rule, last=True):
+ # self.bindings.append((binding, rule, True, last))
+
+ def queues(self, key):
+ workers = []
+
+ # bit is for negate bindings - if we choose to implement it.
+ # last stops processing rules if this matches.
+ for (binding, exchange, bit, last) in self.bindings:
+ if binding.test(key):
+ workers.extend(exchange.run(key))
+
+ # Support last.
+ if last:
+ return workers
+ return workers
+
+
+class HeartbeatMatchMakerBase(MatchMakerBase):
+ """Base for a heart-beat capable MatchMaker.
+
+ Provides common methods for registering, unregistering, and maintaining
+ heartbeats.
+ """
+ def __init__(self):
+ self.hosts = set()
+ self._heart = None
+ self.host_topic = {}
+
+ super(HeartbeatMatchMakerBase, self).__init__()
+
+ def send_heartbeats(self):
+ """Send all heartbeats.
+
+ Use start_heartbeat to spawn a heartbeat greenthread,
+ which loops this method.
+ """
+ for key, host in self.host_topic:
+ self.ack_alive(key, host)
+
+ def ack_alive(self, key, host):
+ """Acknowledge that a host.topic is alive.
+
+ Used internally for updating heartbeats, but may also be used
+ publicly to acknowledge a system is alive (i.e. rpc message
+ successfully sent to host)
+ """
+ raise NotImplementedError("Must implement ack_alive")
+
+ def backend_register(self, key, host):
+ """Implements registration logic.
+
+ Called by register(self,key,host)
+ """
+ raise NotImplementedError("Must implement backend_register")
+
+ def backend_unregister(self, key, key_host):
+ """Implements de-registration logic.
+
+ Called by unregister(self,key,host)
+ """
+ raise NotImplementedError("Must implement backend_unregister")
+
+ def register(self, key, host):
+ """Register a host on a backend.
+
+ Heartbeats, if applicable, may keepalive registration.
+ """
+ self.hosts.add(host)
+ self.host_topic[(key, host)] = host
+ key_host = '.'.join((key, host))
+
+ self.backend_register(key, key_host)
+
+ self.ack_alive(key, host)
+
+ def unregister(self, key, host):
+ """Unregister a topic."""
+ if (key, host) in self.host_topic:
+ del self.host_topic[(key, host)]
+
+ self.hosts.discard(host)
+ self.backend_unregister(key, '.'.join((key, host)))
+
+ LOG.info(_("Matchmaker unregistered: %(key)s, %(host)s"),
+ {'key': key, 'host': host})
+
+ def start_heartbeat(self):
+ """Implementation of MatchMakerBase.start_heartbeat.
+
+ Launches greenthread looping send_heartbeats(),
+ yielding for CONF.matchmaker_heartbeat_freq seconds
+ between iterations.
+ """
+ if not self.hosts:
+ raise MatchMakerException(
+ _("Register before starting heartbeat."))
+
+ def do_heartbeat():
+ while True:
+ self.send_heartbeats()
+ eventlet.sleep(CONF.matchmaker_heartbeat_freq)
+
+ self._heart = eventlet.spawn(do_heartbeat)
+
+ def stop_heartbeat(self):
+ """Destroys the heartbeat greenthread."""
+ if self._heart:
+ self._heart.kill()
+
+
+class DirectBinding(Binding):
+ """Specifies a host in the key via a '.' character.
+
+ Although dots are used in the key, the behavior here is
+ that it maps directly to a host, thus direct.
+ """
+ def test(self, key):
+ return '.' in key
+
+
+class TopicBinding(Binding):
+ """Where a 'bare' key without dots.
+
+ AMQP generally considers topic exchanges to be those *with* dots,
+ but we deviate here in terminology as the behavior here matches
+ that of a topic exchange (whereas where there are dots, behavior
+ matches that of a direct exchange).
+ """
+ def test(self, key):
+ return '.' not in key
+
+
+class FanoutBinding(Binding):
+ """Match on fanout keys, where key starts with 'fanout.' string."""
+ def test(self, key):
+ return key.startswith('fanout~')
+
+
+class StubExchange(Exchange):
+ """Exchange that does nothing."""
+ def run(self, key):
+ return [(key, None)]
+
+
+class LocalhostExchange(Exchange):
+ """Exchange where all direct topics are local."""
+ def __init__(self, host='localhost'):
+ self.host = host
+ super(LocalhostExchange, self).__init__()
+
+ def run(self, key):
+ return [('.'.join((key.split('.')[0], self.host)), self.host)]
+
+
+class DirectExchange(Exchange):
+ """Exchange where all topic keys are split, sending to second half.
+
+ i.e. "compute.host" sends a message to "compute.host" running on "host"
+ """
+ def __init__(self):
+ super(DirectExchange, self).__init__()
+
+ def run(self, key):
+ e = key.split('.', 1)[1]
+ return [(key, e)]
+
+
+class MatchMakerLocalhost(MatchMakerBase):
+ """Match Maker where all bare topics resolve to localhost.
+
+ Useful for testing.
+ """
+ def __init__(self, host='localhost'):
+ super(MatchMakerLocalhost, self).__init__()
+ self.add_binding(FanoutBinding(), LocalhostExchange(host))
+ self.add_binding(DirectBinding(), DirectExchange())
+ self.add_binding(TopicBinding(), LocalhostExchange(host))
+
+
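+# Resolution sketch (illustrative; 'node1'/'node2' are hypothetical hosts)
+# for MatchMakerLocalhost(host='node1'):
+#
+#     mm.queues('scheduler')         => [('scheduler.node1', 'node1')]
+#     mm.queues('scheduler.node2')   => [('scheduler.node2', 'node2')]
+#     mm.queues('fanout~scheduler')  => [('fanout~scheduler.node1', 'node1')]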
+class MatchMakerStub(MatchMakerBase):
+ """Match Maker where topics are untouched.
+
+ Useful for testing, or for AMQP/brokered queues.
+ Will not work where knowledge of hosts is known (i.e. zeromq)
+ """
+ def __init__(self):
+ super(MatchMakerStub, self).__init__()
+
+ self.add_binding(FanoutBinding(), StubExchange())
+ self.add_binding(DirectBinding(), StubExchange())
+ self.add_binding(TopicBinding(), StubExchange())
diff --git a/keystone/openstack/common/rpc/matchmaker_redis.py b/keystone/openstack/common/rpc/matchmaker_redis.py
new file mode 100644
index 00000000..20006f68
--- /dev/null
+++ b/keystone/openstack/common/rpc/matchmaker_redis.py
@@ -0,0 +1,145 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+The MatchMaker classes should accept a Topic or Fanout exchange key and
+return keys for direct exchanges, per (approximate) AMQP parlance.
+"""
+
+from oslo.config import cfg
+
+from keystone.openstack.common import importutils
+from keystone.openstack.common import log as logging
+from keystone.openstack.common.rpc import matchmaker as mm_common
+
+redis = importutils.try_import('redis')
+
+
+matchmaker_redis_opts = [
+ cfg.StrOpt('host',
+ default='127.0.0.1',
+ help='Host to locate redis'),
+ cfg.IntOpt('port',
+ default=6379,
+ help='Use this port to connect to redis host.'),
+ cfg.StrOpt('password',
+ default=None,
+ help='Password for Redis server. (optional)'),
+]
+
+CONF = cfg.CONF
+opt_group = cfg.OptGroup(name='matchmaker_redis',
+ title='Options for Redis-based MatchMaker')
+CONF.register_group(opt_group)
+CONF.register_opts(matchmaker_redis_opts, opt_group)
+LOG = logging.getLogger(__name__)
+
+
+class RedisExchange(mm_common.Exchange):
+ def __init__(self, matchmaker):
+ self.matchmaker = matchmaker
+ self.redis = matchmaker.redis
+ super(RedisExchange, self).__init__()
+
+
+class RedisTopicExchange(RedisExchange):
+ """Exchange where all topic keys are split, sending to second half.
+
+ i.e. "compute.host" sends a message to "compute" running on "host"
+ """
+ def run(self, topic):
+ while True:
+ member_name = self.redis.srandmember(topic)
+
+ if not member_name:
+ # If this happens, there are no
+ # longer any members.
+ break
+
+ if not self.matchmaker.is_alive(topic, member_name):
+ continue
+
+ host = member_name.split('.', 1)[1]
+ return [(member_name, host)]
+ return []
+
+
+class RedisFanoutExchange(RedisExchange):
+ """Return a list of all hosts."""
+ def run(self, topic):
+ topic = topic.split('~', 1)[1]
+ hosts = self.redis.smembers(topic)
+ good_hosts = filter(
+ lambda host: self.matchmaker.is_alive(topic, host), hosts)
+
+ return [(x, x.split('.', 1)[1]) for x in good_hosts]
+
+
+class MatchMakerRedis(mm_common.HeartbeatMatchMakerBase):
+ """MatchMaker registering and looking-up hosts with a Redis server."""
+ def __init__(self):
+ super(MatchMakerRedis, self).__init__()
+
+ if not redis:
+ raise ImportError("Failed to import module redis.")
+
+ self.redis = redis.StrictRedis(
+ host=CONF.matchmaker_redis.host,
+ port=CONF.matchmaker_redis.port,
+ password=CONF.matchmaker_redis.password)
+
+ self.add_binding(mm_common.FanoutBinding(), RedisFanoutExchange(self))
+ self.add_binding(mm_common.DirectBinding(), mm_common.DirectExchange())
+ self.add_binding(mm_common.TopicBinding(), RedisTopicExchange(self))
+
+ def ack_alive(self, key, host):
+ topic = "%s.%s" % (key, host)
+ if not self.redis.expire(topic, CONF.matchmaker_heartbeat_ttl):
+ # If we could not update the expiration, the key
+ # might have been pruned. Re-register, creating a new
+ # key in Redis.
+ self.register(key, host)
+
+ def is_alive(self, topic, host):
+ if self.redis.ttl(host) == -1:
+ self.expire(topic, host)
+ return False
+ return True
+
+ def expire(self, topic, host):
+ with self.redis.pipeline() as pipe:
+ pipe.multi()
+ pipe.delete(host)
+ pipe.srem(topic, host)
+ pipe.execute()
+
+ def backend_register(self, key, key_host):
+ with self.redis.pipeline() as pipe:
+ pipe.multi()
+ pipe.sadd(key, key_host)
+
+ # No value is needed, we just
+ # care if it exists. Sets aren't viable
+ # because only keys can expire.
+ pipe.set(key_host, '')
+
+ pipe.execute()
+
+ def backend_unregister(self, key, key_host):
+ with self.redis.pipeline() as pipe:
+ pipe.multi()
+ pipe.srem(key, key_host)
+ pipe.delete(key_host)
+ pipe.execute()
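+
+# Redis layout sketch (illustrative; 'scheduler'/'host1' are hypothetical):
+# register('scheduler', 'host1') leaves the backend with
+#
+#     SADD scheduler scheduler.host1     # members of the topic set
+#     SET scheduler.host1 ""             # presence key ...
+#     EXPIRE scheduler.host1 <matchmaker_heartbeat_ttl>   # ... refreshed by ack_alive()
+#
+# is_alive() treats a presence key without a TTL as dead and prunes it from
+# the topic set via expire().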
diff --git a/keystone/openstack/common/rpc/matchmaker_ring.py b/keystone/openstack/common/rpc/matchmaker_ring.py
new file mode 100644
index 00000000..91417f0a
--- /dev/null
+++ b/keystone/openstack/common/rpc/matchmaker_ring.py
@@ -0,0 +1,108 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011-2013 Cloudscaling Group, Inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+The MatchMaker classes should accept a Topic or Fanout exchange key and
+return keys for direct exchanges, per (approximate) AMQP parlance.
+"""
+
+import itertools
+import json
+
+from oslo.config import cfg
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+from keystone.openstack.common.rpc import matchmaker as mm
+
+
+matchmaker_opts = [
+ # Matchmaker ring file
+ cfg.StrOpt('ringfile',
+ deprecated_name='matchmaker_ringfile',
+ deprecated_group='DEFAULT',
+ default='/etc/oslo/matchmaker_ring.json',
+ help='Matchmaker ring file (JSON)'),
+]
+
+CONF = cfg.CONF
+CONF.register_opts(matchmaker_opts, 'matchmaker_ring')
+LOG = logging.getLogger(__name__)
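+# Ring file sketch (illustrative; topics and hosts are hypothetical): the
+# JSON loaded by RingExchange maps a bare topic to the hosts serving it,
+#
+#     {
+#         "scheduler": ["host1", "host2"],
+#         "conductor": ["host1"]
+#     }
+#
+# and RoundRobinRingExchange cycles through those hosts on each lookup.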
+
+
+class RingExchange(mm.Exchange):
+ """Match Maker where hosts are loaded from a static JSON formatted file.
+
+ __init__ takes optional ring dictionary argument, otherwise
+ loads the ringfile from CONF.mathcmaker_ringfile.
+ """
+ def __init__(self, ring=None):
+ super(RingExchange, self).__init__()
+
+ if ring:
+ self.ring = ring
+ else:
+ fh = open(CONF.matchmaker_ring.ringfile, 'r')
+ self.ring = json.load(fh)
+ fh.close()
+
+ self.ring0 = {}
+ for k in self.ring.keys():
+ self.ring0[k] = itertools.cycle(self.ring[k])
+
+ def _ring_has(self, key):
+ return key in self.ring0
+
+
+class RoundRobinRingExchange(RingExchange):
+ """A Topic Exchange based on a hashmap."""
+ def __init__(self, ring=None):
+ super(RoundRobinRingExchange, self).__init__(ring)
+
+ def run(self, key):
+ if not self._ring_has(key):
+ LOG.warn(
+ _("No key defining hosts for topic '%s', "
+ "see ringfile") % (key, )
+ )
+ return []
+ host = next(self.ring0[key])
+ return [(key + '.' + host, host)]
+
+
+class FanoutRingExchange(RingExchange):
+ """Fanout Exchange based on a hashmap."""
+ def __init__(self, ring=None):
+ super(FanoutRingExchange, self).__init__(ring)
+
+ def run(self, key):
+ # Assume starts with "fanout~", strip it for lookup.
+ nkey = key.split('fanout~')[1:][0]
+ if not self._ring_has(nkey):
+ LOG.warn(
+ _("No key defining hosts for topic '%s', "
+ "see ringfile") % (nkey, )
+ )
+ return []
+ return map(lambda x: (key + '.' + x, x), self.ring[nkey])
+
+
+class MatchMakerRing(mm.MatchMakerBase):
+ """Match Maker where hosts are loaded from a static hashmap."""
+ def __init__(self, ring=None):
+ super(MatchMakerRing, self).__init__()
+ self.add_binding(mm.FanoutBinding(), FanoutRingExchange(ring))
+ self.add_binding(mm.DirectBinding(), mm.DirectExchange())
+ self.add_binding(mm.TopicBinding(), RoundRobinRingExchange(ring))
diff --git a/keystone/openstack/common/rpc/proxy.py b/keystone/openstack/common/rpc/proxy.py
new file mode 100644
index 00000000..9cc61920
--- /dev/null
+++ b/keystone/openstack/common/rpc/proxy.py
@@ -0,0 +1,226 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012-2013 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+A helper class for proxy objects to remote APIs.
+
+For more information about rpc API version numbers, see:
+ rpc/dispatcher.py
+"""
+
+
+from keystone.openstack.common import rpc
+from keystone.openstack.common.rpc import common as rpc_common
+from keystone.openstack.common.rpc import serializer as rpc_serializer
+
+
+class RpcProxy(object):
+ """A helper class for rpc clients.
+
+ This class is a wrapper around the RPC client API. It allows you to
+ specify the topic and API version in a single place. This is intended to
+ be used as a base class for a class that implements the client side of an
+ rpc API.
+ """
+
+ # The default namespace, which can be overridden in a subclass.
+ RPC_API_NAMESPACE = None
+
+ def __init__(self, topic, default_version, version_cap=None,
+ serializer=None):
+ """Initialize an RpcProxy.
+
+ :param topic: The topic to use for all messages.
+ :param default_version: The default API version to request in all
+ outgoing messages. This can be overridden on a per-message
+ basis.
+ :param version_cap: Optionally cap the maximum version used for sent
+ messages.
+ :param serializer: Optionally (de-)serialize entities with a
+ provided helper.
+ """
+ self.topic = topic
+ self.default_version = default_version
+ self.version_cap = version_cap
+ if serializer is None:
+ serializer = rpc_serializer.NoOpSerializer()
+ self.serializer = serializer
+ super(RpcProxy, self).__init__()
+
+ def _set_version(self, msg, vers):
+ """Helper method to set the version in a message.
+
+ :param msg: The message having a version added to it.
+ :param vers: The version number to add to the message.
+ """
+ v = vers if vers else self.default_version
+ if (self.version_cap and not
+ rpc_common.version_is_compatible(self.version_cap, v)):
+ raise rpc_common.RpcVersionCapError(version_cap=self.version_cap)
+ msg['version'] = v
+
+ def _get_topic(self, topic):
+ """Return the topic to use for a message."""
+ return topic if topic else self.topic
+
+ def can_send_version(self, version):
+ """Check to see if a version is compatible with the version cap."""
+ return (not self.version_cap or
+ rpc_common.version_is_compatible(self.version_cap, version))
+
+ @staticmethod
+ def make_namespaced_msg(method, namespace, **kwargs):
+ return {'method': method, 'namespace': namespace, 'args': kwargs}
+
+ def make_msg(self, method, **kwargs):
+ return self.make_namespaced_msg(method, self.RPC_API_NAMESPACE,
+ **kwargs)
+
+ def _serialize_msg_args(self, context, kwargs):
+ """Helper method called to serialize message arguments.
+
+ This calls our serializer on each argument, returning a new
+ set of args that have been serialized.
+
+ :param context: The request context
+ :param kwargs: The arguments to serialize
+ :returns: A new set of serialized arguments
+ """
+ new_kwargs = dict()
+ for argname, arg in kwargs.iteritems():
+ new_kwargs[argname] = self.serializer.serialize_entity(context,
+ arg)
+ return new_kwargs
+
+ def call(self, context, msg, topic=None, version=None, timeout=None):
+ """rpc.call() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+ :param timeout: (Optional) A timeout to use when waiting for the
+ response. If no timeout is specified, a default timeout will be
+ used that is usually sufficient.
+
+ :returns: The return value from the remote method.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ real_topic = self._get_topic(topic)
+ try:
+ result = rpc.call(context, real_topic, msg, timeout)
+ return self.serializer.deserialize_entity(context, result)
+ except rpc.common.Timeout as exc:
+ raise rpc.common.Timeout(
+ exc.info, real_topic, msg.get('method'))
+
+ def multicall(self, context, msg, topic=None, version=None, timeout=None):
+ """rpc.multicall() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+ :param timeout: (Optional) A timeout to use when waiting for the
+ response. If no timeout is specified, a default timeout will be
+ used that is usually sufficient.
+
+ :returns: An iterator that lets you process each of the returned values
+ from the remote method as they arrive.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ real_topic = self._get_topic(topic)
+ try:
+ result = rpc.multicall(context, real_topic, msg, timeout)
+ return self.serializer.deserialize_entity(context, result)
+ except rpc.common.Timeout as exc:
+ raise rpc.common.Timeout(
+ exc.info, real_topic, msg.get('method'))
+
+ def cast(self, context, msg, topic=None, version=None):
+ """rpc.cast() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.cast() does not wait on any return value from the
+ remote method.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ rpc.cast(context, self._get_topic(topic), msg)
+
+ def fanout_cast(self, context, msg, topic=None, version=None):
+ """rpc.fanout_cast() a remote method.
+
+ :param context: The request context
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.fanout_cast() does not wait on any return value
+ from the remote method.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ rpc.fanout_cast(context, self._get_topic(topic), msg)
+
+ def cast_to_server(self, context, server_params, msg, topic=None,
+ version=None):
+ """rpc.cast_to_server() a remote method.
+
+ :param context: The request context
+ :param server_params: Server parameters. See rpc.cast_to_server() for
+ details.
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.cast_to_server() does not wait on any
+ return values.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ rpc.cast_to_server(context, server_params, self._get_topic(topic), msg)
+
+ def fanout_cast_to_server(self, context, server_params, msg, topic=None,
+ version=None):
+ """rpc.fanout_cast_to_server() a remote method.
+
+ :param context: The request context
+ :param server_params: Server parameters. See rpc.cast_to_server() for
+ details.
+ :param msg: The message to send, including the method and args.
+ :param topic: Override the topic for this message.
+ :param version: (Optional) Override the requested API version in this
+ message.
+
+ :returns: None. rpc.fanout_cast_to_server() does not wait on any
+ return values.
+ """
+ self._set_version(msg, version)
+ msg['args'] = self._serialize_msg_args(context, msg['args'])
+ rpc.fanout_cast_to_server(context, server_params,
+ self._get_topic(topic), msg)
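
A minimal sketch of a client built on RpcProxy; the topic, version, and method
names below are hypothetical, not part of this commit:

from keystone.openstack.common.rpc import proxy

class ExampleAPI(proxy.RpcProxy):
    BASE_RPC_API_VERSION = '1.0'

    def __init__(self):
        super(ExampleAPI, self).__init__(
            topic='example', default_version=self.BASE_RPC_API_VERSION)

    def ping(self, context, arg):
        # Blocks until the remote 'ping' method returns its result.
        return self.call(context, self.make_msg('ping', arg=arg))

    def notify(self, context, arg):
        # Fire-and-forget; no return value is awaited.
        self.cast(context, self.make_msg('notify', arg=arg))
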
diff --git a/keystone/openstack/common/rpc/securemessage.py b/keystone/openstack/common/rpc/securemessage.py
new file mode 100644
index 00000000..ef5e191c
--- /dev/null
+++ b/keystone/openstack/common/rpc/securemessage.py
@@ -0,0 +1,521 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import base64
+import collections
+import os
+import struct
+import time
+
+import requests
+
+from oslo.config import cfg
+
+from keystone.openstack.common.crypto import utils as cryptoutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
+
+secure_message_opts = [
+ cfg.BoolOpt('enabled', default=True,
+ help='Whether Secure Messaging (Signing) is enabled,'
+ ' defaults to enabled'),
+ cfg.BoolOpt('enforced', default=False,
+ help='Whether Secure Messaging (Signing) is enforced,'
+ ' defaults to not enforced'),
+ cfg.BoolOpt('encrypt', default=False,
+ help='Whether Secure Messaging (Encryption) is enabled,'
+ ' defaults to not enabled'),
+ cfg.StrOpt('secret_keys_file',
+ help='Path to the file containing the keys, takes precedence'
+ ' over secret_key'),
+ cfg.MultiStrOpt('secret_key',
+ help='A list of keys: (ex: name:<base64 encoded key>),'
+ ' ignored if secret_keys_file is set'),
+ cfg.StrOpt('kds_endpoint',
+ help='KDS endpoint (ex: http://kds.example.com:35357/v3)'),
+]
+secure_message_group = cfg.OptGroup('secure_messages',
+ title='Secure Messaging options')
+
+LOG = logging.getLogger(__name__)
+
+
+class SecureMessageException(Exception):
+ """Generic Exception for Secure Messages."""
+
+ msg = "An unknown Secure Message related exception occurred."
+
+ def __init__(self, msg=None):
+ if msg is None:
+ msg = self.msg
+ super(SecureMessageException, self).__init__(msg)
+
+
+class SharedKeyNotFound(SecureMessageException):
+ """No shared key was found and no other external authentication mechanism
+ is available.
+ """
+
+ msg = "Shared Key for [%s] Not Found. (%s)"
+
+ def __init__(self, name, errmsg):
+ super(SharedKeyNotFound, self).__init__(self.msg % (name, errmsg))
+
+
+class InvalidMetadata(SecureMessageException):
+ """The metadata is invalid."""
+
+ msg = "Invalid metadata: %s"
+
+ def __init__(self, err):
+ super(InvalidMetadata, self).__init__(self.msg % err)
+
+
+class InvalidSignature(SecureMessageException):
+ """Signature validation failed."""
+
+ msg = "Failed to validate signature (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidSignature, self).__init__(self.msg % (src, dst))
+
+
+class UnknownDestinationName(SecureMessageException):
+ """The Destination name is unknown to us."""
+
+ msg = "Invalid destination name (%s)"
+
+ def __init__(self, name):
+ super(UnknownDestinationName, self).__init__(self.msg % name)
+
+
+class InvalidEncryptedTicket(SecureMessageException):
+ """The Encrypted Ticket could not be successfully handled."""
+
+ msg = "Invalid Ticket (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidEncryptedTicket, self).__init__(self.msg % (src, dst))
+
+
+class InvalidExpiredTicket(SecureMessageException):
+ """The ticket received is already expired."""
+
+ msg = "Expired ticket (source=%s, destination=%s)"
+
+ def __init__(self, src, dst):
+ super(InvalidExpiredTicket, self).__init__(self.msg % (src, dst))
+
+
+class CommunicationError(SecureMessageException):
+ """The Communication with the KDS failed."""
+
+ msg = "Communication Error (target=%s): %s"
+
+ def __init__(self, target, errmsg):
+ super(CommunicationError, self).__init__(self.msg % (target, errmsg))
+
+
+class InvalidArgument(SecureMessageException):
+ """Bad initialization argument."""
+
+ msg = "Invalid argument: %s"
+
+ def __init__(self, errmsg):
+ super(InvalidArgument, self).__init__(self.msg % errmsg)
+
+
+Ticket = collections.namedtuple('Ticket', ['skey', 'ekey', 'esek'])
+
+
+class KeyStore(object):
+ """A storage class for Signing and Encryption Keys.
+
+ This class creates an object that holds Generic Keys like Signing
+ Keys, Encryption Keys, Encrypted SEK Tickets ...
+ """
+
+ def __init__(self):
+ self._kvps = dict()
+
+ def _get_key_name(self, source, target, ktype):
+ return (source, target, ktype)
+
+ def _put(self, src, dst, ktype, expiration, data):
+ name = self._get_key_name(src, dst, ktype)
+ self._kvps[name] = (expiration, data)
+
+ def _get(self, src, dst, ktype):
+ name = self._get_key_name(src, dst, ktype)
+ if name in self._kvps:
+ expiration, data = self._kvps[name]
+ if expiration > time.time():
+ return data
+ else:
+ del self._kvps[name]
+
+ return None
+
+ def clear(self):
+ """Wipes the store clear of all data."""
+ self._kvps.clear()
+
+ def put_ticket(self, source, target, skey, ekey, esek, expiration):
+ """Puts a sek pair in the cache.
+
+ :param source: Client name
+ :param target: Target name
+ :param skey: The Signing Key
+ :param ekey: The Encryption Key
+ :param esek: The token encrypted with the target key
+ :param expiration: Expiration time in seconds since Epoch
+ """
+ keys = Ticket(skey, ekey, esek)
+ self._put(source, target, 'ticket', expiration, keys)
+
+ def get_ticket(self, source, target):
+ """Returns a Ticket (skey, ekey, esek) namedtuple for the
+ source/target pair.
+ """
+ return self._get(source, target, 'ticket')
+
+
+_KEY_STORE = KeyStore()
+
+
+class _KDSClient(object):
+
+ USER_AGENT = 'oslo-incubator/rpc'
+
+ def __init__(self, endpoint=None, timeout=None):
+ """A KDS Client class."""
+
+ self._endpoint = endpoint
+ if timeout is not None:
+ self.timeout = float(timeout)
+ else:
+ self.timeout = None
+
+ def _do_get(self, url, request):
+ req_kwargs = dict()
+ req_kwargs['headers'] = dict()
+ req_kwargs['headers']['User-Agent'] = self.USER_AGENT
+ req_kwargs['headers']['Content-Type'] = 'application/json'
+ req_kwargs['data'] = jsonutils.dumps({'request': request})
+ if self.timeout is not None:
+ req_kwargs['timeout'] = self.timeout
+
+ try:
+ resp = requests.get(url, **req_kwargs)
+ except requests.ConnectionError as e:
+ err = "Unable to establish connection. %s" % e
+ raise CommunicationError(url, err)
+
+ return resp
+
+ def _get_reply(self, url, resp):
+ if resp.text:
+ try:
+ body = jsonutils.loads(resp.text)
+ reply = body['reply']
+ except (KeyError, TypeError, ValueError):
+ msg = "Failed to decode reply: %s" % resp.text
+ raise CommunicationError(url, msg)
+ else:
+ msg = "No reply data was returned."
+ raise CommunicationError(url, msg)
+
+ return reply
+
+ def _get_ticket(self, request, url=None, redirects=10):
+ """Send an HTTP request.
+
+ Wraps around 'requests' to handle redirects and common errors.
+ """
+ if url is None:
+ if not self._endpoint:
+ raise CommunicationError(url, 'Endpoint not configured')
+ url = self._endpoint + '/kds/ticket'
+
+ while redirects:
+ resp = self._do_get(url, request)
+ if resp.status_code in (301, 302, 305):
+ # Redirected. Reissue the request to the new location.
+ url = resp.headers['location']
+ redirects -= 1
+ continue
+ elif resp.status_code != 200:
+ msg = "Request returned failure status: %s (%s)"
+ err = msg % (resp.status_code, resp.text)
+ raise CommunicationError(url, err)
+
+ return self._get_reply(url, resp)
+
+ raise CommunicationError(url, "Too many redirections, giving up!")
+
+ def get_ticket(self, source, target, crypto, key):
+
+ # prepare metadata
+ md = {'requestor': source,
+ 'target': target,
+ 'timestamp': time.time(),
+ 'nonce': struct.unpack('Q', os.urandom(8))[0]}
+ metadata = base64.b64encode(jsonutils.dumps(md))
+
+ # sign metadata
+ signature = crypto.sign(key, metadata)
+
+ # HTTP request
+ reply = self._get_ticket({'metadata': metadata,
+ 'signature': signature})
+
+ # verify reply
+ signature = crypto.sign(key, (reply['metadata'] + reply['ticket']))
+ if signature != reply['signature']:
+ raise InvalidEncryptedTicket(source, target)
+ md = jsonutils.loads(base64.b64decode(reply['metadata']))
+ if ((md['source'] != source or
+ md['destination'] != target or
+ md['expiration'] < time.time())):
+ raise InvalidEncryptedTicket(md['source'], md['destination'])
+
+ # return ticket data
+ tkt = jsonutils.loads(crypto.decrypt(key, reply['ticket']))
+
+ return tkt, md['expiration']
+
+
+# We need to keep a global nonce, as this value should never repeat no
+# matter how many SecureMessage objects we create.
+_NONCE = None
+
+
+def _get_nonce():
+ """We keep a single counter per instance, as it is so huge we can't
+ possibly cycle through within 1/100 of a second anyway.
+ """
+
+ global _NONCE
+ # Lazy initialize, for now get a random value, multiply by 2^32 and
+ # use it as the nonce base. The counter itself will rotate after
+ # 2^32 increments.
+ if _NONCE is None:
+ _NONCE = [struct.unpack('I', os.urandom(4))[0], 0]
+
+ # Increment counter and wrap at 2^32
+ _NONCE[1] += 1
+ if _NONCE[1] > 0xffffffff:
+ _NONCE[1] = 0
+
+ # Return base + counter
+ return long((_NONCE[0] * 0xffffffff)) + _NONCE[1]
+
+
+class SecureMessage(object):
+ """A Secure Message object.
+
+ This class creates a signing/encryption facility for RPC messages.
+ It encapsulates all the necessary crypto primitives to insulate
+ regular code from the intricacies of message authentication, validation
+ and optionally encryption.
+
+ :param topic: The topic name of the queue
+ :param host: The server name, together with the topic it forms a unique
+ name that is used to source signing keys, and verify
+ incoming messages.
+ :param conf: a ConfigOpts object
+ :param key: (optional) explicitly pass in endpoint private key.
+ If not provided it will be sourced from the service config
+ :param key_store: (optional) Storage class for local caching
+ :param encrypt: (defaults to False) Whether to encrypt messages
+ :param enctype: (defaults to AES) Cipher to use
+ :param hashtype: (defaults to SHA256) Hash function to use for signatures
+ """
+
+ def __init__(self, topic, host, conf, key=None, key_store=None,
+ encrypt=None, enctype='AES', hashtype='SHA256'):
+
+ conf.register_group(secure_message_group)
+ conf.register_opts(secure_message_opts, group='secure_messages')
+
+ self._name = '%s.%s' % (topic, host)
+ self._key = key
+ self._conf = conf.secure_messages
+ self._encrypt = self._conf.encrypt if (encrypt is None) else encrypt
+ self._crypto = cryptoutils.SymmetricCrypto(enctype, hashtype)
+ self._hkdf = cryptoutils.HKDF(hashtype)
+ self._kds = _KDSClient(self._conf.kds_endpoint)
+
+ if self._key is None:
+ self._key = self._init_key(topic, self._name)
+ if self._key is None:
+ err = "Secret Key (or key file) is missing or malformed"
+ raise SharedKeyNotFound(self._name, err)
+
+ self._key_store = key_store or _KEY_STORE
+
+ def _init_key(self, topic, name):
+ keys = None
+ if self._conf.secret_keys_file:
+ with open(self._conf.secret_keys_file, 'r') as f:
+ keys = f.readlines()
+ elif self._conf.secret_key:
+ keys = self._conf.secret_key
+
+ if keys is None:
+ return None
+
+ for k in keys:
+ if k[0] == '#':
+ continue
+ if ':' not in k:
+ break
+ svc, key = k.split(':', 1)
+ if svc == topic or svc == name:
+ return base64.b64decode(key)
+
+ return None
+
+ def _split_key(self, key, size):
+ sig_key = key[:size]
+ enc_key = key[size:]
+ return sig_key, enc_key
+
+ def _decode_esek(self, key, source, target, timestamp, esek):
+ """This function decrypts the esek buffer passed in and returns a
+ KeyStore to be used to check and decrypt the received message.
+
+ :param key: The key to use to decrypt the ticket (esek)
+ :param source: The name of the source service
+ :param target: The name of the target service
+ :param timestamp: The incoming message timestamp
+ :param esek: a base64 encoded encrypted block containing a JSON string
+ """
+ rkey = None
+
+ try:
+ s = self._crypto.decrypt(key, esek)
+ j = jsonutils.loads(s)
+
+ rkey = base64.b64decode(j['key'])
+ expiration = j['timestamp'] + j['ttl']
+ if j['timestamp'] > timestamp or timestamp > expiration:
+ raise InvalidExpiredTicket(source, target)
+
+ except Exception:
+ raise InvalidEncryptedTicket(source, target)
+
+ info = '%s,%s,%s' % (source, target, str(j['timestamp']))
+
+ sek = self._hkdf.expand(rkey, info, len(key) * 2)
+
+ return self._split_key(sek, len(key))
+
+ def _get_ticket(self, target):
+ """This function will check if we already have a SEK for the specified
+ target in the cache, or will go and try to fetch a new SEK from the key
+ server.
+
+ :param target: The name of the target service
+ """
+ ticket = self._key_store.get_ticket(self._name, target)
+
+ if ticket is not None:
+ return ticket
+
+ tkt, expiration = self._kds.get_ticket(self._name, target,
+ self._crypto, self._key)
+
+ self._key_store.put_ticket(self._name, target,
+ base64.b64decode(tkt['skey']),
+ base64.b64decode(tkt['ekey']),
+ tkt['esek'], expiration)
+ return self._key_store.get_ticket(self._name, target)
+
+ def encode(self, version, target, json_msg):
+ """This is the main encoding function.
+
+ It takes a target and a message and returns a tuple consisting of a
+ JSON serialized metadata object, a JSON serialized (and optionally
+ encrypted) message, and a signature.
+
+ :param version: the current envelope version
+ :param target: The name of the target service (usually with hostname)
+ :param json_msg: a serialized json message object
+ """
+ ticket = self._get_ticket(target)
+
+ metadata = jsonutils.dumps({'source': self._name,
+ 'destination': target,
+ 'timestamp': time.time(),
+ 'nonce': _get_nonce(),
+ 'esek': ticket.esek,
+ 'encryption': self._encrypt})
+
+ message = json_msg
+ if self._encrypt:
+ message = self._crypto.encrypt(ticket.ekey, message)
+
+ signature = self._crypto.sign(ticket.skey,
+ version + metadata + message)
+
+ return (metadata, message, signature)
+
+ def decode(self, version, metadata, message, signature):
+ """This is the main decoding function.
+
+ It takes a version, metadata, message and signature strings and
+ returns a tuple with a (decrypted) message and metadata or raises
+ an exception in case of error.
+
+ :param version: the current envelope version
+ :param metadata: a JSON serialized object with metadata for validation
+ :param message: a JSON serialized (base64 encoded encrypted) message
+ :param signature: a base64 encoded signature
+ """
+ md = jsonutils.loads(metadata)
+
+ check_args = ('source', 'destination', 'timestamp',
+ 'nonce', 'esek', 'encryption')
+ for arg in check_args:
+ if arg not in md:
+ raise InvalidMetadata('Missing metadata "%s"' % arg)
+
+ if md['destination'] != self._name:
+ # TODO(simo) handle group keys by checking target
+ raise UnknownDestinationName(md['destination'])
+
+ try:
+ skey, ekey = self._decode_esek(self._key,
+ md['source'], md['destination'],
+ md['timestamp'], md['esek'])
+ except InvalidExpiredTicket:
+ raise
+ except Exception:
+ raise InvalidMetadata('Failed to decode ESEK for %s/%s' % (
+ md['source'], md['destination']))
+
+ sig = self._crypto.sign(skey, version + metadata + message)
+
+ if sig != signature:
+ raise InvalidSignature(md['source'], md['destination'])
+
+ if md['encryption'] is True:
+ msg = self._crypto.decrypt(ekey, message)
+ else:
+ msg = message
+
+ return (md, msg)
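
A minimal sketch of the in-memory KeyStore cache defined above; all key
material and service names are hypothetical placeholders, and the class is
assumed importable from keystone.openstack.common.rpc.securemessage:

import time

store = KeyStore()
store.put_ticket('scheduler.host1', 'compute.host2',
                 skey=b'signing-key', ekey=b'encryption-key',
                 esek='encrypted-sek-blob', expiration=time.time() + 300)
ticket = store.get_ticket('scheduler.host1', 'compute.host2')
# ticket.skey == b'signing-key'; once the expiration passes, the entry is
# dropped and get_ticket() returns None.
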
diff --git a/keystone/openstack/common/rpc/serializer.py b/keystone/openstack/common/rpc/serializer.py
new file mode 100644
index 00000000..76c68310
--- /dev/null
+++ b/keystone/openstack/common/rpc/serializer.py
@@ -0,0 +1,52 @@
+# Copyright 2013 IBM Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Provides the definition of an RPC serialization handler"""
+
+import abc
+
+
+class Serializer(object):
+ """Generic (de-)serialization definition base class."""
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def serialize_entity(self, context, entity):
+ """Serialize something to primitive form.
+
+ :param context: Security context
+ :param entity: Entity to be serialized
+ :returns: Serialized form of entity
+ """
+ pass
+
+ @abc.abstractmethod
+ def deserialize_entity(self, context, entity):
+ """Deserialize something from primitive form.
+
+ :param context: Security context
+ :param entity: Primitive to be deserialized
+ :returns: Deserialized form of entity
+ """
+ pass
+
+
+class NoOpSerializer(Serializer):
+ """A serializer that does nothing."""
+
+ def serialize_entity(self, context, entity):
+ return entity
+
+ def deserialize_entity(self, context, entity):
+ return entity
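
A minimal sketch of a custom Serializer; the datetime-to-string conversion is
a hypothetical example of reducing arguments to primitive form:

import datetime

from keystone.openstack.common.rpc import serializer

class DatetimeSerializer(serializer.Serializer):
    def serialize_entity(self, context, entity):
        # Convert datetimes to ISO strings before they hit the wire.
        if isinstance(entity, datetime.datetime):
            return entity.isoformat()
        return entity

    def deserialize_entity(self, context, entity):
        return entity

# Plugged into an RpcProxy, every outgoing argument passes through it:
#   proxy.RpcProxy(topic='example', default_version='1.0',
#                  serializer=DatetimeSerializer())
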
diff --git a/keystone/openstack/common/rpc/service.py b/keystone/openstack/common/rpc/service.py
new file mode 100644
index 00000000..34eacb60
--- /dev/null
+++ b/keystone/openstack/common/rpc/service.py
@@ -0,0 +1,78 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+# Copyright 2011 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import rpc
+from keystone.openstack.common.rpc import dispatcher as rpc_dispatcher
+from keystone.openstack.common import service
+
+
+LOG = logging.getLogger(__name__)
+
+
+class Service(service.Service):
+ """Service object for binaries running on hosts.
+
+ A service enables rpc by listening to queues based on topic and host.
+ """
+ def __init__(self, host, topic, manager=None, serializer=None):
+ super(Service, self).__init__()
+ self.host = host
+ self.topic = topic
+ self.serializer = serializer
+ if manager is None:
+ self.manager = self
+ else:
+ self.manager = manager
+
+ def start(self):
+ super(Service, self).start()
+
+ self.conn = rpc.create_connection(new=True)
+ LOG.debug(_("Creating Consumer connection for Service %s") %
+ self.topic)
+
+ dispatcher = rpc_dispatcher.RpcDispatcher([self.manager],
+ self.serializer)
+
+ # Share this same connection for these Consumers
+ self.conn.create_consumer(self.topic, dispatcher, fanout=False)
+
+ node_topic = '%s.%s' % (self.topic, self.host)
+ self.conn.create_consumer(node_topic, dispatcher, fanout=False)
+
+ self.conn.create_consumer(self.topic, dispatcher, fanout=True)
+
+ # Hook to allow the manager to do other initializations after
+ # the rpc connection is created.
+ if callable(getattr(self.manager, 'initialize_service_hook', None)):
+ self.manager.initialize_service_hook(self)
+
+ # Consume from all consumers in a thread
+ self.conn.consume_in_thread()
+
+ def stop(self):
+ # Try to shut the connection down, but if we get any sort of
+ # errors, go ahead and ignore them.. as we're shutting down anyway
+ try:
+ self.conn.close()
+ except Exception:
+ pass
+ super(Service, self).stop()
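
A minimal sketch of wiring this Service to a manager whose public methods
become remotely callable; the class and method names are hypothetical:

from keystone.openstack.common.rpc import service as rpc_service

class ExampleManager(object):
    RPC_API_VERSION = '1.0'

    def ping(self, context, arg):
        return arg

svc = rpc_service.Service(host='host1', topic='example',
                          manager=ExampleManager())
# Starting it consumes from 'example', 'example.host1' and the fanout queue;
# a launcher from keystone.openstack.common.service typically drives it:
#   launcher = service.launch(svc)
#   launcher.wait()
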
diff --git a/keystone/openstack/common/rpc/zmq_receiver.py b/keystone/openstack/common/rpc/zmq_receiver.py
new file mode 100755
index 00000000..2f095f10
--- /dev/null
+++ b/keystone/openstack/common/rpc/zmq_receiver.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import eventlet
+eventlet.monkey_patch()
+
+import contextlib
+import sys
+
+from oslo.config import cfg
+
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import rpc
+from keystone.openstack.common.rpc import impl_zmq
+
+CONF = cfg.CONF
+CONF.register_opts(rpc.rpc_opts)
+CONF.register_opts(impl_zmq.zmq_opts)
+
+
+def main():
+ CONF(sys.argv[1:], project='oslo')
+ logging.setup("oslo")
+
+ with contextlib.closing(impl_zmq.ZmqProxy(CONF)) as reactor:
+ reactor.consume_in_thread()
+ reactor.wait()
diff --git a/keystone/openstack/common/service.py b/keystone/openstack/common/service.py
new file mode 100644
index 00000000..8418e2a8
--- /dev/null
+++ b/keystone/openstack/common/service.py
@@ -0,0 +1,450 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2011 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Generic Node base class for all workers that run on hosts."""
+
+import errno
+import os
+import random
+import signal
+import sys
+import time
+
+import eventlet
+from eventlet import event
+import logging as std_logging
+from oslo.config import cfg
+
+from keystone.openstack.common import eventlet_backdoor
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import threadgroup
+
+
+rpc = importutils.try_import('keystone.openstack.common.rpc')
+CONF = cfg.CONF
+LOG = logging.getLogger(__name__)
+
+
+class Launcher(object):
+ """Launch one or more services and wait for them to complete."""
+
+ def __init__(self):
+ """Initialize the service launcher.
+
+ :returns: None
+
+ """
+ self.services = Services()
+ self.backdoor_port = eventlet_backdoor.initialize_if_enabled()
+
+ def launch_service(self, service):
+ """Load and start the given service.
+
+ :param service: The service you would like to start.
+ :returns: None
+
+ """
+ service.backdoor_port = self.backdoor_port
+ self.services.add(service)
+
+ def stop(self):
+ """Stop all services which are currently running.
+
+ :returns: None
+
+ """
+ self.services.stop()
+
+ def wait(self):
+ """Waits until all services have been stopped, and then returns.
+
+ :returns: None
+
+ """
+ self.services.wait()
+
+ def restart(self):
+ """Reload config files and restart service.
+
+ :returns: None
+
+ """
+ cfg.CONF.reload_config_files()
+ self.services.restart()
+
+
+class SignalExit(SystemExit):
+ def __init__(self, signo, exccode=1):
+ super(SignalExit, self).__init__(exccode)
+ self.signo = signo
+
+
+class ServiceLauncher(Launcher):
+ def _handle_signal(self, signo, frame):
+ # Allow the process to be killed again and die from natural causes
+ signal.signal(signal.SIGTERM, signal.SIG_DFL)
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
+
+ raise SignalExit(signo)
+
+ def handle_signal(self):
+ signal.signal(signal.SIGTERM, self._handle_signal)
+ signal.signal(signal.SIGINT, self._handle_signal)
+ signal.signal(signal.SIGHUP, self._handle_signal)
+
+ def _wait_for_exit_or_signal(self):
+ status = None
+ signo = 0
+
+ LOG.debug(_('Full set of CONF:'))
+ CONF.log_opt_values(LOG, std_logging.DEBUG)
+
+ try:
+ super(ServiceLauncher, self).wait()
+ except SignalExit as exc:
+ signame = {signal.SIGTERM: 'SIGTERM',
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[exc.signo]
+ LOG.info(_('Caught %s, exiting'), signame)
+ status = exc.code
+ signo = exc.signo
+ except SystemExit as exc:
+ status = exc.code
+ finally:
+ self.stop()
+ if rpc:
+ try:
+ rpc.cleanup()
+ except Exception:
+ # We're shutting down, so it doesn't matter at this point.
+ LOG.exception(_('Exception during rpc cleanup.'))
+
+ return status, signo
+
+ def wait(self):
+ while True:
+ self.handle_signal()
+ status, signo = self._wait_for_exit_or_signal()
+ if signo != signal.SIGHUP:
+ return status
+ self.restart()
+
+
+class ServiceWrapper(object):
+ def __init__(self, service, workers):
+ self.service = service
+ self.workers = workers
+ self.children = set()
+ self.forktimes = []
+
+
+class ProcessLauncher(object):
+ def __init__(self):
+ self.children = {}
+ self.sigcaught = None
+ self.running = True
+ rfd, self.writepipe = os.pipe()
+ self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r')
+ self.handle_signal()
+
+ def handle_signal(self):
+ signal.signal(signal.SIGTERM, self._handle_signal)
+ signal.signal(signal.SIGINT, self._handle_signal)
+ signal.signal(signal.SIGHUP, self._handle_signal)
+
+ def _handle_signal(self, signo, frame):
+ self.sigcaught = signo
+ self.running = False
+
+ # Allow the process to be killed again and die from natural causes
+ signal.signal(signal.SIGTERM, signal.SIG_DFL)
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
+
+ def _pipe_watcher(self):
+ # This will block until the write end is closed when the parent
+ # dies unexpectedly
+ self.readpipe.read()
+
+ LOG.info(_('Parent process has died unexpectedly, exiting'))
+
+ sys.exit(1)
+
+ def _child_process_handle_signal(self):
+ # Setup child signal handlers differently
+ def _sigterm(*args):
+ signal.signal(signal.SIGTERM, signal.SIG_DFL)
+ raise SignalExit(signal.SIGTERM)
+
+ def _sighup(*args):
+ signal.signal(signal.SIGHUP, signal.SIG_DFL)
+ raise SignalExit(signal.SIGHUP)
+
+ signal.signal(signal.SIGTERM, _sigterm)
+ signal.signal(signal.SIGHUP, _sighup)
+ # Block SIGINT and let the parent send us a SIGTERM
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+ def _child_wait_for_exit_or_signal(self, launcher):
+ status = None
+ signo = 0
+
+ try:
+ launcher.wait()
+ except SignalExit as exc:
+ signame = {signal.SIGTERM: 'SIGTERM',
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[exc.signo]
+ LOG.info(_('Caught %s, exiting'), signame)
+ status = exc.code
+ signo = exc.signo
+ except SystemExit as exc:
+ status = exc.code
+ except BaseException:
+ LOG.exception(_('Unhandled exception'))
+ status = 2
+ finally:
+ launcher.stop()
+
+ return status, signo
+
+ def _child_process(self, service):
+ self._child_process_handle_signal()
+
+ # Reopen the eventlet hub to make sure we don't share an epoll
+ # fd with parent and/or siblings, which would be bad
+ eventlet.hubs.use_hub()
+
+ # Close write to ensure only parent has it open
+ os.close(self.writepipe)
+ # Create greenthread to watch for parent to close pipe
+ eventlet.spawn_n(self._pipe_watcher)
+
+ # Reseed random number generator
+ random.seed()
+
+ launcher = Launcher()
+ launcher.launch_service(service)
+ return launcher
+
+ def _start_child(self, wrap):
+ if len(wrap.forktimes) > wrap.workers:
+ # Limit ourselves to one fork per second, measured over a window of
+ # (number of workers) seconds. This lets workers start up quickly
+ # but keeps us from repeatedly re-forking children that die
+ # instantly.
+ if time.time() - wrap.forktimes[0] < wrap.workers:
+ LOG.info(_('Forking too fast, sleeping'))
+ time.sleep(1)
+
+ wrap.forktimes.pop(0)
+
+ wrap.forktimes.append(time.time())
+
+ pid = os.fork()
+ if pid == 0:
+ # NOTE(johannes): All exceptions are caught to ensure this
+ # doesn't fallback into the loop spawning children. It would
+ # be bad for a child to spawn more children.
+ launcher = self._child_process(wrap.service)
+ while True:
+ self._child_process_handle_signal()
+ status, signo = self._child_wait_for_exit_or_signal(launcher)
+ if signo != signal.SIGHUP:
+ break
+ launcher.restart()
+
+ os._exit(status)
+
+ LOG.info(_('Started child %d'), pid)
+
+ wrap.children.add(pid)
+ self.children[pid] = wrap
+
+ return pid
+
+ def launch_service(self, service, workers=1):
+ wrap = ServiceWrapper(service, workers)
+
+ LOG.info(_('Starting %d workers'), wrap.workers)
+ while self.running and len(wrap.children) < wrap.workers:
+ self._start_child(wrap)
+
+ def _wait_child(self):
+ try:
+ # Don't block if no child processes have exited
+ pid, status = os.waitpid(0, os.WNOHANG)
+ if not pid:
+ return None
+ except OSError as exc:
+ if exc.errno not in (errno.EINTR, errno.ECHILD):
+ raise
+ return None
+
+ if os.WIFSIGNALED(status):
+ sig = os.WTERMSIG(status)
+ LOG.info(_('Child %(pid)d killed by signal %(sig)d'),
+ dict(pid=pid, sig=sig))
+ else:
+ code = os.WEXITSTATUS(status)
+ LOG.info(_('Child %(pid)s exited with status %(code)d'),
+ dict(pid=pid, code=code))
+
+ if pid not in self.children:
+ LOG.warning(_('pid %d not in child list'), pid)
+ return None
+
+ wrap = self.children.pop(pid)
+ wrap.children.remove(pid)
+ return wrap
+
+ def _respawn_children(self):
+ while self.running:
+ wrap = self._wait_child()
+ if not wrap:
+ # Yield to other threads if no children have exited
+ # Sleep for a short time to avoid excessive CPU usage
+ # (see bug #1095346)
+ eventlet.greenthread.sleep(.01)
+ continue
+ while self.running and len(wrap.children) < wrap.workers:
+ self._start_child(wrap)
+
+ def wait(self):
+ """Loop waiting on children to die and respawning as necessary."""
+
+ LOG.debug(_('Full set of CONF:'))
+ CONF.log_opt_values(LOG, std_logging.DEBUG)
+
+ while True:
+ self.handle_signal()
+ self._respawn_children()
+ if self.sigcaught:
+ signame = {signal.SIGTERM: 'SIGTERM',
+ signal.SIGINT: 'SIGINT',
+ signal.SIGHUP: 'SIGHUP'}[self.sigcaught]
+ LOG.info(_('Caught %s, stopping children'), signame)
+ if self.sigcaught != signal.SIGHUP:
+ break
+
+ for pid in self.children:
+ os.kill(pid, signal.SIGHUP)
+ self.running = True
+ self.sigcaught = None
+
+ for pid in self.children:
+ try:
+ os.kill(pid, signal.SIGTERM)
+ except OSError as exc:
+ if exc.errno != errno.ESRCH:
+ raise
+
+ # Wait for children to die
+ if self.children:
+ LOG.info(_('Waiting on %d children to exit'), len(self.children))
+ while self.children:
+ self._wait_child()
+
+
+class Service(object):
+ """Service object for binaries running on hosts."""
+
+ def __init__(self, threads=1000):
+ self.tg = threadgroup.ThreadGroup(threads)
+
+ # signal that the service is done shutting itself down:
+ self._done = event.Event()
+
+ def reset(self):
+ # NOTE(Fengqian): docs for Event.reset() recommend against using it
+ self._done = event.Event()
+
+ def start(self):
+ pass
+
+ def stop(self):
+ self.tg.stop()
+ self.tg.wait()
+ # Signal that service cleanup is done:
+ if not self._done.ready():
+ self._done.send()
+
+ def wait(self):
+ self._done.wait()
+
+
+class Services(object):
+
+ def __init__(self):
+ self.services = []
+ self.tg = threadgroup.ThreadGroup()
+ self.done = event.Event()
+
+ def add(self, service):
+ self.services.append(service)
+ self.tg.add_thread(self.run_service, service, self.done)
+
+ def stop(self):
+ # wait for graceful shutdown of services:
+ for service in self.services:
+ service.stop()
+ service.wait()
+
+ # Each service has performed cleanup, now signal that the run_service
+ # wrapper threads can now die:
+ if not self.done.ready():
+ self.done.send()
+
+ # reap threads:
+ self.tg.stop()
+
+ def wait(self):
+ self.tg.wait()
+
+ def restart(self):
+ self.stop()
+ self.done = event.Event()
+ for restart_service in self.services:
+ restart_service.reset()
+ self.tg.add_thread(self.run_service, restart_service, self.done)
+
+ @staticmethod
+ def run_service(service, done):
+ """Service start wrapper.
+
+ :param service: service to run
+ :param done: event to wait on until a shutdown is triggered
+ :returns: None
+
+ """
+ service.start()
+ done.wait()
+
+
+def launch(service, workers=None):
+ if workers:
+ launcher = ProcessLauncher()
+ launcher.launch_service(service, workers=workers)
+ else:
+ launcher = ServiceLauncher()
+ launcher.launch_service(service)
+ return launcher
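
A minimal sketch of the single-process versus multi-process launch paths;
ExampleService and its timer callback are hypothetical:

from keystone.openstack.common import service

class ExampleService(service.Service):
    def start(self):
        super(ExampleService, self).start()
        # Periodic work runs on the service's thread group.
        self.tg.add_timer(60, lambda: None)

launcher = service.launch(ExampleService())               # run in this process
# launcher = service.launch(ExampleService(), workers=4)  # fork four workers
launcher.wait()                                           # block until stopped
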
diff --git a/keystone/openstack/common/sslutils.py b/keystone/openstack/common/sslutils.py
new file mode 100644
index 00000000..3aa975b8
--- /dev/null
+++ b/keystone/openstack/common/sslutils.py
@@ -0,0 +1,100 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 IBM Corp.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import ssl
+
+from oslo.config import cfg
+
+from keystone.openstack.common.gettextutils import _ # noqa
+
+
+ssl_opts = [
+ cfg.StrOpt('ca_file',
+ default=None,
+ help="CA certificate file to use to verify "
+ "connecting clients"),
+ cfg.StrOpt('cert_file',
+ default=None,
+ help="Certificate file to use when starting "
+ "the server securely"),
+ cfg.StrOpt('key_file',
+ default=None,
+ help="Private key file to use when starting "
+ "the server securely"),
+]
+
+
+CONF = cfg.CONF
+CONF.register_opts(ssl_opts, "ssl")
+
+
+def is_enabled():
+ cert_file = CONF.ssl.cert_file
+ key_file = CONF.ssl.key_file
+ ca_file = CONF.ssl.ca_file
+ use_ssl = cert_file or key_file
+
+ if cert_file and not os.path.exists(cert_file):
+ raise RuntimeError(_("Unable to find cert_file : %s") % cert_file)
+
+ if ca_file and not os.path.exists(ca_file):
+ raise RuntimeError(_("Unable to find ca_file : %s") % ca_file)
+
+ if key_file and not os.path.exists(key_file):
+ raise RuntimeError(_("Unable to find key_file : %s") % key_file)
+
+ if use_ssl and (not cert_file or not key_file):
+ raise RuntimeError(_("When running server in SSL mode, you must "
+ "specify both a cert_file and key_file "
+ "option value in your configuration file"))
+
+ return use_ssl
+
+
+def wrap(sock):
+ ssl_kwargs = {
+ 'server_side': True,
+ 'certfile': CONF.ssl.cert_file,
+ 'keyfile': CONF.ssl.key_file,
+ 'cert_reqs': ssl.CERT_NONE,
+ }
+
+ if CONF.ssl.ca_file:
+ ssl_kwargs['ca_certs'] = CONF.ssl.ca_file
+ ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED
+
+ return ssl.wrap_socket(sock, **ssl_kwargs)
+
+
+_SSL_PROTOCOLS = {
+ "tlsv1": ssl.PROTOCOL_TLSv1,
+ "sslv23": ssl.PROTOCOL_SSLv23,
+ "sslv3": ssl.PROTOCOL_SSLv3
+}
+
+try:
+ _SSL_PROTOCOLS["sslv2"] = ssl.PROTOCOL_SSLv2
+except AttributeError:
+ pass
+
+
+def validate_ssl_version(version):
+ key = version.lower()
+ try:
+ return _SSL_PROTOCOLS[key]
+ except KeyError:
+ raise RuntimeError(_("Invalid SSL version : %s") % version)
diff --git a/keystone/openstack/common/threadgroup.py b/keystone/openstack/common/threadgroup.py
new file mode 100644
index 00000000..cde9fc7f
--- /dev/null
+++ b/keystone/openstack/common/threadgroup.py
@@ -0,0 +1,121 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 Red Hat, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import eventlet
+from eventlet import greenpool
+from eventlet import greenthread
+
+from keystone.openstack.common import log as logging
+from keystone.openstack.common import loopingcall
+
+
+LOG = logging.getLogger(__name__)
+
+
+def _thread_done(gt, *args, **kwargs):
+ """Callback function to be passed to GreenThread.link() when we spawn()
+ Calls the :class:`ThreadGroup` to notify if.
+
+ """
+ kwargs['group'].thread_done(kwargs['thread'])
+
+
+class Thread(object):
+ """Wrapper around a greenthread, that holds a reference to the
+ :class:`ThreadGroup`. The Thread will notify the :class:`ThreadGroup` when
+ it has finished so it can be removed from the threads list.
+ """
+ def __init__(self, thread, group):
+ self.thread = thread
+ self.thread.link(_thread_done, group=group, thread=self)
+
+ def stop(self):
+ self.thread.kill()
+
+ def wait(self):
+ return self.thread.wait()
+
+
+class ThreadGroup(object):
+ """The point of the ThreadGroup classis to:
+
+ * keep track of timers and greenthreads (making it easier to stop them
+ when need be).
+ * provide an easy API to add timers.
+ """
+ def __init__(self, thread_pool_size=10):
+ self.pool = greenpool.GreenPool(thread_pool_size)
+ self.threads = []
+ self.timers = []
+
+ def add_dynamic_timer(self, callback, initial_delay=None,
+ periodic_interval_max=None, *args, **kwargs):
+ timer = loopingcall.DynamicLoopingCall(callback, *args, **kwargs)
+ timer.start(initial_delay=initial_delay,
+ periodic_interval_max=periodic_interval_max)
+ self.timers.append(timer)
+
+ def add_timer(self, interval, callback, initial_delay=None,
+ *args, **kwargs):
+ pulse = loopingcall.FixedIntervalLoopingCall(callback, *args, **kwargs)
+ pulse.start(interval=interval,
+ initial_delay=initial_delay)
+ self.timers.append(pulse)
+
+ def add_thread(self, callback, *args, **kwargs):
+ gt = self.pool.spawn(callback, *args, **kwargs)
+ th = Thread(gt, self)
+ self.threads.append(th)
+
+ def thread_done(self, thread):
+ self.threads.remove(thread)
+
+ def stop(self):
+ current = greenthread.getcurrent()
+ for x in self.threads:
+ if x is current:
+ # don't kill the current thread.
+ continue
+ try:
+ x.stop()
+ except Exception as ex:
+ LOG.exception(ex)
+
+ for x in self.timers:
+ try:
+ x.stop()
+ except Exception as ex:
+ LOG.exception(ex)
+ self.timers = []
+
+ def wait(self):
+ for x in self.timers:
+ try:
+ x.wait()
+ except eventlet.greenlet.GreenletExit:
+ pass
+ except Exception as ex:
+ LOG.exception(ex)
+ current = greenthread.getcurrent()
+ for x in self.threads:
+ if x is current:
+ continue
+ try:
+ x.wait()
+ except eventlet.greenlet.GreenletExit:
+ pass
+ except Exception as ex:
+ LOG.exception(ex)
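
A minimal sketch of a ThreadGroup driving one background greenthread and one
periodic timer; the worker function is hypothetical:

from keystone.openstack.common import threadgroup

def worker(name):
    print('%s finished' % name)

tg = threadgroup.ThreadGroup(thread_pool_size=10)
tg.add_thread(worker, 'job-1')     # one-off greenthread from the pool
tg.add_timer(30, lambda: None)     # FixedIntervalLoopingCall, every 30 seconds
tg.stop()                          # kill timers and any remaining threads
tg.wait()                          # reap them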
diff --git a/keystone/openstack/common/timeutils.py b/keystone/openstack/common/timeutils.py
index 60943659..aa9f7080 100644
--- a/keystone/openstack/common/timeutils.py
+++ b/keystone/openstack/common/timeutils.py
@@ -23,6 +23,7 @@ import calendar
import datetime
import iso8601
+import six
# ISO 8601 extended time format with microseconds
@@ -32,7 +33,7 @@ PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND
def isotime(at=None, subsecond=False):
- """Stringify time in ISO 8601 format"""
+ """Stringify time in ISO 8601 format."""
if not at:
at = utcnow()
st = at.strftime(_ISO8601_TIME_FORMAT
@@ -44,13 +45,13 @@ def isotime(at=None, subsecond=False):
def parse_isotime(timestr):
- """Parse time from ISO 8601 format"""
+ """Parse time from ISO 8601 format."""
try:
return iso8601.parse_date(timestr)
except iso8601.ParseError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
except TypeError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
def strtime(at=None, fmt=PERFECT_TIME_FORMAT):
@@ -66,7 +67,7 @@ def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT):
def normalize_time(timestamp):
- """Normalize time in arbitrary timezone to UTC naive object"""
+ """Normalize time in arbitrary timezone to UTC naive object."""
offset = timestamp.utcoffset()
if offset is None:
return timestamp
@@ -75,14 +76,14 @@ def normalize_time(timestamp):
def is_older_than(before, seconds):
"""Return True if before is older than seconds."""
- if isinstance(before, basestring):
+ if isinstance(before, six.string_types):
before = parse_strtime(before).replace(tzinfo=None)
return utcnow() - before > datetime.timedelta(seconds=seconds)
def is_newer_than(after, seconds):
"""Return True if after is newer than seconds."""
- if isinstance(after, basestring):
+ if isinstance(after, six.string_types):
after = parse_strtime(after).replace(tzinfo=None)
return after - utcnow() > datetime.timedelta(seconds=seconds)
@@ -103,7 +104,7 @@ def utcnow():
def iso8601_from_timestamp(timestamp):
- """Returns a iso8601 formated date from timestamp"""
+ """Returns a iso8601 formated date from timestamp."""
return isotime(datetime.datetime.utcfromtimestamp(timestamp))
@@ -111,9 +112,9 @@ utcnow.override_time = None
def set_time_override(override_time=datetime.datetime.utcnow()):
- """
- Override utils.utcnow to return a constant time or a list thereof,
- one at a time.
+ """Overrides utils.utcnow.
+
+ Make it return a constant time or a list thereof, one at a time.
"""
utcnow.override_time = override_time
@@ -141,7 +142,8 @@ def clear_time_override():
def marshall_now(now=None):
"""Make an rpc-safe datetime with microseconds.
- Note: tzinfo is stripped, but not required for relative times."""
+ Note: tzinfo is stripped, but not required for relative times.
+ """
if not now:
now = utcnow()
return dict(day=now.day, month=now.month, year=now.year, hour=now.hour,
@@ -161,7 +163,8 @@ def unmarshall_time(tyme):
def delta_seconds(before, after):
- """
+ """Return the difference between two timing objects.
+
Compute the difference in seconds between two date, time, or
datetime objects (as a float, to microsecond resolution).
"""
@@ -174,8 +177,7 @@ def delta_seconds(before, after):
def is_soon(dt, window):
- """
- Determines if time is going to happen in the next window seconds.
+ """Determines if time is going to happen in the next window seconds.
:params dt: the time
:params window: minimum seconds to remain to consider the time not soon
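
A minimal sketch of the round trip these helpers provide (the timestamp in the
comment is only illustrative):

from keystone.openstack.common import timeutils

stamp = timeutils.isotime()              # e.g. '2013-07-01T12:00:00Z'
parsed = timeutils.parse_isotime(stamp)  # timezone-aware datetime
naive = timeutils.normalize_time(parsed) # naive UTC datetime
timeutils.is_older_than(naive, 60)       # False immediately after creation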
diff --git a/keystone/openstack/common/uuidutils.py b/keystone/openstack/common/uuidutils.py
new file mode 100644
index 00000000..7608acb9
--- /dev/null
+++ b/keystone/openstack/common/uuidutils.py
@@ -0,0 +1,39 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2012 Intel Corporation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+UUID related utilities and helper functions.
+"""
+
+import uuid
+
+
+def generate_uuid():
+ return str(uuid.uuid4())
+
+
+def is_uuid_like(val):
+ """Returns validation of a value as a UUID.
+
+ For our purposes, a UUID is a canonical form string:
+ aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa
+
+ """
+ try:
+ return str(uuid.UUID(val)) == val
+ except (TypeError, ValueError, AttributeError):
+ return False
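
A quick illustration of the helpers above:

from keystone.openstack.common import uuidutils

new_id = uuidutils.generate_uuid()      # random UUID in canonical string form
uuidutils.is_uuid_like(new_id)          # True
uuidutils.is_uuid_like('not-a-uuid')    # False
uuidutils.is_uuid_like(12345)           # False (non-string values are rejected)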
diff --git a/keystone/policy/backends/rules.py b/keystone/policy/backends/rules.py
index 63110e69..31a26d88 100644
--- a/keystone/policy/backends/rules.py
+++ b/keystone/policy/backends/rules.py
@@ -19,10 +19,10 @@
import os.path
-from keystone.common import logging
from keystone.common import utils
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone.openstack.common import policy as common_policy
from keystone import policy
diff --git a/keystone/service.py b/keystone/service.py
index ce64aba8..f2c95f78 100644
--- a/keystone/service.py
+++ b/keystone/service.py
@@ -14,19 +14,20 @@
# License for the specific language governing permissions and limitations
# under the License.
+import functools
import routes
from keystone import assignment
from keystone import auth
from keystone import catalog
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import wsgi
from keystone import config
from keystone.contrib import ec2
from keystone import controllers
from keystone import credential
from keystone import identity
+from keystone.openstack.common import log as logging
from keystone import policy
from keystone import routers
from keystone import token
@@ -56,7 +57,23 @@ DRIVERS = dict(
dependency.resolve_future_dependencies()
-@logging.fail_gracefully
+def fail_gracefully(f):
+ """Logs exceptions and aborts."""
+ @functools.wraps(f)
+ def wrapper(*args, **kw):
+ try:
+ return f(*args, **kw)
+ except Exception as e:
+ LOG.debug(e, exc_info=True)
+
+ # exception message is printed to all logs
+ LOG.critical(e)
+
+ exit(1)
+ return wrapper
+
+
+@fail_gracefully
def public_app_factory(global_conf, **local_conf):
controllers.register_version('v2.0')
conf = global_conf.copy()
@@ -68,7 +85,7 @@ def public_app_factory(global_conf, **local_conf):
routers.Extension(False)])
-@logging.fail_gracefully
+@fail_gracefully
def admin_app_factory(global_conf, **local_conf):
conf = global_conf.copy()
conf.update(local_conf)
@@ -79,7 +96,7 @@ def admin_app_factory(global_conf, **local_conf):
routers.Extension()])
-@logging.fail_gracefully
+@fail_gracefully
def public_version_app_factory(global_conf, **local_conf):
conf = global_conf.copy()
conf.update(local_conf)
@@ -87,7 +104,7 @@ def public_version_app_factory(global_conf, **local_conf):
[routers.Versions('public')])
-@logging.fail_gracefully
+@fail_gracefully
def admin_version_app_factory(global_conf, **local_conf):
conf = global_conf.copy()
conf.update(local_conf)
@@ -95,7 +112,7 @@ def admin_version_app_factory(global_conf, **local_conf):
[routers.Versions('admin')])
-@logging.fail_gracefully
+@fail_gracefully
def v3_app_factory(global_conf, **local_conf):
controllers.register_version('v3')
conf = global_conf.copy()
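
A minimal sketch of what the new fail_gracefully decorator does for an app
factory; the failing factory below is hypothetical:

from keystone import service

@service.fail_gracefully
def example_app_factory(global_conf, **local_conf):
    raise RuntimeError('cannot build pipeline')

# Calling example_app_factory({}) logs the traceback at DEBUG, the message at
# CRITICAL, and then exits the process with status 1 instead of propagating
# the error to the WSGI loader.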
diff --git a/keystone/tests/__init__.py b/keystone/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/tests/__init__.py
diff --git a/tests/_ldap_livetest.py b/keystone/tests/_ldap_livetest.py
index ead54ea7..4562ccb6 100644
--- a/tests/_ldap_livetest.py
+++ b/keystone/tests/_ldap_livetest.py
@@ -22,7 +22,7 @@ from keystone.common import ldap as ldap_common
from keystone import config
from keystone import exception
from keystone.identity.backends import ldap as identity_ldap
-from keystone import test
+from keystone.tests import core as test
import test_backend_ldap
@@ -87,9 +87,6 @@ class LiveLDAPIdentity(test_backend_ldap.LDAPIdentity):
def tearDown(self):
test.TestCase.tearDown(self)
- def test_user_enable_attribute_mask(self):
- self.skipTest('Test is for Active Directory Only')
-
def test_ldap_dereferencing(self):
alt_users_ldif = {'objectclass': ['top', 'organizationalUnit'],
'ou': 'alt_users'}
@@ -158,3 +155,11 @@ class LiveLDAPIdentity(test_backend_ldap.LDAPIdentity):
alias_dereferencing=deref)
self.assertEqual(ldap.DEREF_SEARCHING,
ldap_wrapper.conn.get_option(ldap.OPT_DEREF))
+
+ def test_user_enable_attribute_mask(self):
+ CONF.ldap.user_enabled_emulation = False
+ CONF.ldap.user_enabled_attribute = 'employeeType'
+ super(LiveLDAPIdentity, self).test_user_enable_attribute_mask()
+
+ def test_create_unicode_user_name(self):
+ self.skipTest('Addressed by bug #1172106')
diff --git a/tests/_ldap_tls_livetest.py b/keystone/tests/_ldap_tls_livetest.py
index f52b6360..f1c43453 100644
--- a/tests/_ldap_tls_livetest.py
+++ b/keystone/tests/_ldap_tls_livetest.py
@@ -21,7 +21,7 @@ import ldap.modlist
from keystone import config
from keystone import exception
from keystone import identity
-from keystone import test
+from keystone.tests import core as test
import _ldap_livetest
diff --git a/tests/_sql_livetest.py b/keystone/tests/_sql_livetest.py
index a271ce7c..a271ce7c 100644
--- a/tests/_sql_livetest.py
+++ b/keystone/tests/_sql_livetest.py
diff --git a/tests/_test_import_auth_token.py b/keystone/tests/_test_import_auth_token.py
index 4e16f9a4..4e16f9a4 100644
--- a/tests/_test_import_auth_token.py
+++ b/keystone/tests/_test_import_auth_token.py
diff --git a/tests/auth_plugin_external_disabled.conf b/keystone/tests/auth_plugin_external_disabled.conf
index fed281d4..fed281d4 100644
--- a/tests/auth_plugin_external_disabled.conf
+++ b/keystone/tests/auth_plugin_external_disabled.conf
diff --git a/tests/auth_plugin_external_domain.conf b/keystone/tests/auth_plugin_external_domain.conf
index b7be122f..b7be122f 100644
--- a/tests/auth_plugin_external_domain.conf
+++ b/keystone/tests/auth_plugin_external_domain.conf
diff --git a/tests/backend_db2.conf b/keystone/tests/backend_db2.conf
index 44032255..44032255 100644
--- a/tests/backend_db2.conf
+++ b/keystone/tests/backend_db2.conf
diff --git a/tests/backend_ldap.conf b/keystone/tests/backend_ldap.conf
index 6b3f8a75..6b3f8a75 100644
--- a/tests/backend_ldap.conf
+++ b/keystone/tests/backend_ldap.conf
diff --git a/tests/backend_ldap_sql.conf b/keystone/tests/backend_ldap_sql.conf
index 8dcfa40d..5579e75d 100644
--- a/tests/backend_ldap_sql.conf
+++ b/keystone/tests/backend_ldap_sql.conf
@@ -34,3 +34,4 @@ driver = keystone.policy.backends.sql.Policy
[trust]
driver = keystone.trust.backends.sql.Trust
+
diff --git a/tests/backend_liveldap.conf b/keystone/tests/backend_liveldap.conf
index 297d96d6..297d96d6 100644
--- a/tests/backend_liveldap.conf
+++ b/keystone/tests/backend_liveldap.conf
diff --git a/tests/backend_mysql.conf b/keystone/tests/backend_mysql.conf
index ee3b276e..ee3b276e 100644
--- a/tests/backend_mysql.conf
+++ b/keystone/tests/backend_mysql.conf
diff --git a/tests/backend_pam.conf b/keystone/tests/backend_pam.conf
index 41f868c7..41f868c7 100644
--- a/tests/backend_pam.conf
+++ b/keystone/tests/backend_pam.conf
diff --git a/tests/backend_postgresql.conf b/keystone/tests/backend_postgresql.conf
index 8468ad33..8468ad33 100644
--- a/tests/backend_postgresql.conf
+++ b/keystone/tests/backend_postgresql.conf
diff --git a/tests/backend_sql.conf b/keystone/tests/backend_sql.conf
index 0baf610c..0baf610c 100644
--- a/tests/backend_sql.conf
+++ b/keystone/tests/backend_sql.conf
diff --git a/tests/backend_sql_disk.conf b/keystone/tests/backend_sql_disk.conf
index 0f8dfea7..0f8dfea7 100644
--- a/tests/backend_sql_disk.conf
+++ b/keystone/tests/backend_sql_disk.conf
diff --git a/tests/backend_tls_liveldap.conf b/keystone/tests/backend_tls_liveldap.conf
index 409af674..409af674 100644
--- a/tests/backend_tls_liveldap.conf
+++ b/keystone/tests/backend_tls_liveldap.conf
diff --git a/keystone/test.py b/keystone/tests/core.py
index 9118b2ea..8d075335 100644
--- a/keystone/test.py
+++ b/keystone/tests/core.py
@@ -40,7 +40,6 @@ from keystone import assignment
from keystone import catalog
from keystone.common import dependency
from keystone.common import kvs
-from keystone.common import logging
from keystone.common import sql
from keystone.common import utils
from keystone.common import wsgi
@@ -49,6 +48,7 @@ from keystone.contrib import ec2
from keystone import credential
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
from keystone import policy
from keystone import token
@@ -57,9 +57,9 @@ from keystone import trust
LOG = logging.getLogger(__name__)
-ROOTDIR = os.path.dirname(os.path.abspath(os.curdir))
+ROOTDIR = os.path.dirname(os.path.abspath('..'))
VENDOR = os.path.join(ROOTDIR, 'vendor')
-TESTSDIR = os.path.join(ROOTDIR, 'tests')
+TESTSDIR = os.path.join(ROOTDIR, 'keystone', 'tests')
ETCDIR = os.path.join(ROOTDIR, 'etc')
TMPDIR = os.path.join(TESTSDIR, 'tmp')
@@ -68,9 +68,6 @@ CONF = config.CONF
cd = os.chdir
-logging.getLogger('routes.middleware').level = logging.WARN
-
-
def rootdir(*p):
return os.path.join(ROOTDIR, *p)
diff --git a/tests/default_catalog.templates b/keystone/tests/default_catalog.templates
index f26c949a..f26c949a 100644
--- a/tests/default_catalog.templates
+++ b/keystone/tests/default_catalog.templates
diff --git a/tests/default_fixtures.py b/keystone/tests/default_fixtures.py
index 2695da88..2695da88 100644
--- a/tests/default_fixtures.py
+++ b/keystone/tests/default_fixtures.py
diff --git a/tests/legacy_d5.mysql b/keystone/tests/legacy_d5.mysql
index 57b31feb..57b31feb 100644
--- a/tests/legacy_d5.mysql
+++ b/keystone/tests/legacy_d5.mysql
diff --git a/tests/legacy_d5.sqlite b/keystone/tests/legacy_d5.sqlite
index d96dbf40..d96dbf40 100644
--- a/tests/legacy_d5.sqlite
+++ b/keystone/tests/legacy_d5.sqlite
diff --git a/tests/legacy_diablo.mysql b/keystone/tests/legacy_diablo.mysql
index 543f439f..543f439f 100644
--- a/tests/legacy_diablo.mysql
+++ b/keystone/tests/legacy_diablo.mysql
diff --git a/tests/legacy_diablo.sqlite b/keystone/tests/legacy_diablo.sqlite
index edf15be4..edf15be4 100644
--- a/tests/legacy_diablo.sqlite
+++ b/keystone/tests/legacy_diablo.sqlite
diff --git a/tests/legacy_essex.mysql b/keystone/tests/legacy_essex.mysql
index eade2cbf..eade2cbf 100644
--- a/tests/legacy_essex.mysql
+++ b/keystone/tests/legacy_essex.mysql
diff --git a/tests/legacy_essex.sqlite b/keystone/tests/legacy_essex.sqlite
index 72326d76..72326d76 100644
--- a/tests/legacy_essex.sqlite
+++ b/keystone/tests/legacy_essex.sqlite
diff --git a/tests/test_auth.py b/keystone/tests/test_auth.py
index db5314be..598b11d3 100644
--- a/tests/test_auth.py
+++ b/keystone/tests/test_auth.py
@@ -16,7 +16,7 @@ import copy
import datetime
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone import auth
from keystone import config
@@ -179,7 +179,8 @@ class AuthBadRequests(AuthTest):
def test_authenticate_password_too_large(self):
"""Verify sending large 'password' raises the right exception."""
- body_dict = _build_user_auth(username='FOO', password='0' * 8193)
+ length = CONF.identity.max_password_length + 1
+ body_dict = _build_user_auth(username='FOO', password='0' * length)
self.assertRaises(exception.ValidationSizeError,
self.controller.authenticate,
{}, body_dict)
diff --git a/tests/test_auth_plugin.conf b/keystone/tests/test_auth_plugin.conf
index edec8f79..edec8f79 100644
--- a/tests/test_auth_plugin.conf
+++ b/keystone/tests/test_auth_plugin.conf
diff --git a/tests/test_auth_plugin.py b/keystone/tests/test_auth_plugin.py
index d158ec46..e3346cf1 100644
--- a/tests/test_auth_plugin.py
+++ b/keystone/tests/test_auth_plugin.py
@@ -16,7 +16,7 @@
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone import auth
from keystone import exception
diff --git a/tests/test_backend.py b/keystone/tests/test_backend.py
index 7e4d820e..52628985 100644
--- a/tests/test_backend.py
+++ b/keystone/tests/test_backend.py
@@ -17,7 +17,7 @@
import datetime
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone.catalog import core
from keystone import config
@@ -1453,6 +1453,18 @@ class IdentityTests(object):
tenants = self.identity_api.get_projects_for_user(self.user_foo['id'])
self.assertIn(self.tenant_baz['id'], tenants)
+ def test_add_user_to_project_missing_default_role(self):
+ self.assignment_api.delete_role(CONF.member_role_id)
+ self.assertRaises(exception.RoleNotFound,
+ self.assignment_api.get_role,
+ CONF.member_role_id)
+ self.identity_api.add_user_to_project(self.tenant_baz['id'],
+ self.user_foo['id'])
+ tenants = self.identity_api.get_projects_for_user(self.user_foo['id'])
+ self.assertIn(self.tenant_baz['id'], tenants)
+ default_role = self.assignment_api.get_role(CONF.member_role_id)
+ self.assertIsNotNone(default_role)
+
def test_add_user_to_project_404(self):
self.assertRaises(exception.ProjectNotFound,
self.identity_api.add_user_to_project,
@@ -1616,7 +1628,7 @@ class IdentityTests(object):
tenant)
def test_create_user_long_name_fails(self):
- user = {'id': 'fake1', 'name': 'a' * 65,
+ user = {'id': 'fake1', 'name': 'a' * 256,
'domain_id': DEFAULT_DOMAIN_ID}
self.assertRaises(exception.ValidationError,
self.identity_api.create_user,
@@ -1689,7 +1701,7 @@ class IdentityTests(object):
user = {'id': 'fake1', 'name': 'fake1',
'domain_id': DEFAULT_DOMAIN_ID}
self.identity_api.create_user('fake1', user)
- user['name'] = 'a' * 65
+ user['name'] = 'a' * 256
self.assertRaises(exception.ValidationError,
self.identity_api.update_user,
'fake1',
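
The new test_add_user_to_project_missing_default_role case above relies on add_user_to_project recreating the default member role when it has been deleted. A rough sketch of that behaviour, assuming helper names such as create_role and add_role_to_user_and_project on the assignment API (the patched manager code is not shown here, so treat the names as assumptions):

from keystone import config
from keystone import exception

CONF = config.CONF


def add_user_to_project(assignment_api, tenant_id, user_id):
    """Grant the default member role, recreating it if it was deleted."""
    try:
        assignment_api.get_role(CONF.member_role_id)
    except exception.RoleNotFound:
        # Assumed behaviour: recreate the default role on demand.
        assignment_api.create_role(
            CONF.member_role_id,
            {'id': CONF.member_role_id,
             'name': CONF.member_role_name})
    assignment_api.add_role_to_user_and_project(
        user_id, tenant_id, CONF.member_role_id)
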
diff --git a/tests/test_backend_kvs.py b/keystone/tests/test_backend_kvs.py
index d92a7510..34b87c60 100644
--- a/tests/test_backend_kvs.py
+++ b/keystone/tests/test_backend_kvs.py
@@ -15,10 +15,9 @@
# under the License.
import uuid
-from keystone import test
-
from keystone import exception
from keystone import identity
+from keystone.tests import core as test
import default_fixtures
import test_backend
diff --git a/tests/test_backend_ldap.py b/keystone/tests/test_backend_ldap.py
index ec2b2737..6f9cfef9 100644
--- a/tests/test_backend_ldap.py
+++ b/keystone/tests/test_backend_ldap.py
@@ -23,7 +23,7 @@ from keystone.common import sql
from keystone import config
from keystone import exception
from keystone import identity
-from keystone import test
+from keystone.tests import core as test
import default_fixtures
import test_backend
@@ -454,10 +454,12 @@ class LDAPIdentity(test.TestCase, BaseLDAPIdentity):
self.assertNotIn('name', role_ref)
def test_user_enable_attribute_mask(self):
- CONF.ldap.user_enabled_attribute = 'enabled'
CONF.ldap.user_enabled_mask = 2
CONF.ldap.user_enabled_default = 512
self.clear_database()
+ self.load_backends()
+ self.load_fixtures(default_fixtures)
+
user = {'id': 'fake1', 'name': 'fake1', 'enabled': True}
self.identity_api.create_user('fake1', user)
user_ref = self.identity_api.get_user('fake1')
diff --git a/tests/test_backend_memcache.py b/keystone/tests/test_backend_memcache.py
index 7516e0dd..0377c0e6 100644
--- a/tests/test_backend_memcache.py
+++ b/keystone/tests/test_backend_memcache.py
@@ -20,7 +20,7 @@ import uuid
import memcache
-from keystone import test
+from keystone.tests import core as test
from keystone.common import utils
from keystone import exception
diff --git a/tests/test_backend_pam.py b/keystone/tests/test_backend_pam.py
index b66faa9c..65817837 100644
--- a/tests/test_backend_pam.py
+++ b/keystone/tests/test_backend_pam.py
@@ -16,7 +16,7 @@
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone.identity.backends import pam as identity_pam
diff --git a/tests/test_backend_sql.py b/keystone/tests/test_backend_sql.py
index 89276e86..24159eb6 100644
--- a/tests/test_backend_sql.py
+++ b/keystone/tests/test_backend_sql.py
@@ -18,11 +18,10 @@ import uuid
import sqlalchemy
-from keystone import test
-
from keystone.common import sql
from keystone import config
from keystone import exception
+from keystone.tests import core as test
import default_fixtures
import test_backend
@@ -82,7 +81,7 @@ class SqlModels(SqlTests):
def test_user_model(self):
cols = (('id', sql.String, 64),
- ('name', sql.String, 64),
+ ('name', sql.String, 255),
('password', sql.String, 128),
('domain_id', sql.String, 64),
('enabled', sql.Boolean, None),
diff --git a/tests/test_backend_templated.py b/keystone/tests/test_backend_templated.py
index bfa19192..603ad82a 100644
--- a/tests/test_backend_templated.py
+++ b/keystone/tests/test_backend_templated.py
@@ -16,7 +16,7 @@
import os
-from keystone import test
+from keystone.tests import core as test
from keystone import exception
diff --git a/tests/test_catalog.py b/keystone/tests/test_catalog.py
index 3c00b1e8..3c00b1e8 100644
--- a/tests/test_catalog.py
+++ b/keystone/tests/test_catalog.py
diff --git a/tests/test_cert_setup.py b/keystone/tests/test_cert_setup.py
index e6c395e9..88fa6d75 100644
--- a/tests/test_cert_setup.py
+++ b/keystone/tests/test_cert_setup.py
@@ -18,7 +18,7 @@
import os
import shutil
-from keystone import test
+from keystone.tests import core as test
from keystone.common import openssl
from keystone import exception
diff --git a/tests/test_config.py b/keystone/tests/test_config.py
index 3165a4f4..28b372a6 100644
--- a/tests/test_config.py
+++ b/keystone/tests/test_config.py
@@ -1,4 +1,4 @@
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone import exception
diff --git a/tests/test_content_types.py b/keystone/tests/test_content_types.py
index ebb5dcef..7c874732 100644
--- a/tests/test_content_types.py
+++ b/keystone/tests/test_content_types.py
@@ -20,7 +20,7 @@ import uuid
from lxml import etree
import webtest
-from keystone import test
+from keystone.tests import core as test
from keystone.common import extension
from keystone.common import serializer
diff --git a/tests/test_contrib_s3_core.py b/keystone/tests/test_contrib_s3_core.py
index e2c328b5..3cf799bc 100644
--- a/tests/test_contrib_s3_core.py
+++ b/keystone/tests/test_contrib_s3_core.py
@@ -16,7 +16,7 @@
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone.contrib import ec2
from keystone.contrib import s3
diff --git a/tests/test_contrib_stats_core.py b/keystone/tests/test_contrib_stats_core.py
index 907c7d25..567c485e 100644
--- a/tests/test_contrib_stats_core.py
+++ b/keystone/tests/test_contrib_stats_core.py
@@ -17,7 +17,7 @@
from keystone.contrib import stats
from keystone import config
-from keystone import test
+from keystone.tests import core as test
CONF = config.CONF
diff --git a/tests/test_drivers.py b/keystone/tests/test_drivers.py
index c83c1a89..c83c1a89 100644
--- a/tests/test_drivers.py
+++ b/keystone/tests/test_drivers.py
diff --git a/tests/test_exception.py b/keystone/tests/test_exception.py
index d442d572..9658ed19 100644
--- a/tests/test_exception.py
+++ b/keystone/tests/test_exception.py
@@ -16,7 +16,7 @@
import uuid
-from keystone import test
+from keystone.tests import core as test
from keystone.common import wsgi
from keystone import config
diff --git a/tests/test_import_legacy.py b/keystone/tests/test_import_legacy.py
index 9e164099..b3b83c0f 100644
--- a/tests/test_import_legacy.py
+++ b/keystone/tests/test_import_legacy.py
@@ -21,7 +21,7 @@ try:
except ImportError:
from pysqlite2 import dbapi2 as dbapi
-from keystone import test
+from keystone.tests import core as test
from keystone.catalog.backends import templated as catalog_templated
from keystone.common.sql import legacy
diff --git a/tests/test_injection.py b/keystone/tests/test_injection.py
index 36cd0126..36cd0126 100644
--- a/tests/test_injection.py
+++ b/keystone/tests/test_injection.py
diff --git a/tests/test_ipv6.py b/keystone/tests/test_ipv6.py
index 9825a5fa..fa64bc43 100644
--- a/tests/test_ipv6.py
+++ b/keystone/tests/test_ipv6.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from keystone import test
+from keystone.tests import core as test
from keystone.common import environment
from keystone import config
diff --git a/tests/test_keystoneclient.py b/keystone/tests/test_keystoneclient.py
index 38062d4b..0c323ddd 100644
--- a/tests/test_keystoneclient.py
+++ b/keystone/tests/test_keystoneclient.py
@@ -17,11 +17,10 @@
import uuid
import webob
-from keystone import test
-
from keystone import config
from keystone.openstack.common import jsonutils
from keystone.openstack.common import timeutils
+from keystone.tests import core as test
import default_fixtures
@@ -379,6 +378,46 @@ class KeystoneClientTests(object):
client.tokens.authenticate,
token=token_id)
+ def test_disable_tenant_invalidates_token(self):
+ from keystoneclient import exceptions as client_exceptions
+
+ admin_client = self.get_client(admin=True)
+ foo_client = self.get_client(self.user_foo)
+ tenant_bar = admin_client.tenants.get(self.tenant_bar['id'])
+
+ # Disable the tenant.
+ tenant_bar.update(enabled=False)
+
+ # Test that the token has been removed.
+ self.assertRaises(client_exceptions.Unauthorized,
+ foo_client.tokens.authenticate,
+ token=foo_client.auth_token)
+
+ # Test that the user access has been disabled.
+ self.assertRaises(client_exceptions.Unauthorized,
+ self.get_client,
+ self.user_foo)
+
+ def test_delete_tenant_invalidates_token(self):
+ from keystoneclient import exceptions as client_exceptions
+
+ admin_client = self.get_client(admin=True)
+ foo_client = self.get_client(self.user_foo)
+ tenant_bar = admin_client.tenants.get(self.tenant_bar['id'])
+
+ # Delete the tenant.
+ tenant_bar.delete()
+
+ # Test that the token has been removed.
+ self.assertRaises(client_exceptions.Unauthorized,
+ foo_client.tokens.authenticate,
+ token=foo_client.auth_token)
+
+ # Test that the user access has been disabled.
+ self.assertRaises(client_exceptions.Unauthorized,
+ self.get_client,
+ self.user_foo)
+
def test_disable_user_invalidates_token(self):
from keystoneclient import exceptions as client_exceptions
@@ -1166,6 +1205,12 @@ class KcEssex3TestCase(CompatTestCase, KeystoneClientTests):
def test_policy_crud(self):
self.skipTest('N/A due to lack of endpoint CRUD')
+ def test_disable_tenant_invalidates_token(self):
+ self.skipTest('N/A')
+
+ def test_delete_tenant_invalidates_token(self):
+ self.skipTest('N/A')
+
class Kc11TestCase(CompatTestCase, KeystoneClientTests):
def get_checkout(self):
diff --git a/tests/test_keystoneclient_sql.py b/keystone/tests/test_keystoneclient_sql.py
index 166d808c..105d8353 100644
--- a/tests/test_keystoneclient_sql.py
+++ b/keystone/tests/test_keystoneclient_sql.py
@@ -16,10 +16,9 @@
import uuid
-from keystone import test
-
from keystone.common import sql
from keystone import config
+from keystone.tests import core as test
import test_keystoneclient
diff --git a/tests/test_middleware.py b/keystone/tests/test_middleware.py
index 9f9d3fd2..df33d172 100644
--- a/tests/test_middleware.py
+++ b/keystone/tests/test_middleware.py
@@ -16,7 +16,7 @@
import webob
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone import middleware
diff --git a/tests/test_no_admin_token_auth.py b/keystone/tests/test_no_admin_token_auth.py
index ffdaa7a8..3a7113d8 100644
--- a/tests/test_no_admin_token_auth.py
+++ b/keystone/tests/test_no_admin_token_auth.py
@@ -2,7 +2,7 @@
import os
import webtest
-from keystone import test
+from keystone.tests import core as test
def _generate_paste_config():
diff --git a/tests/test_overrides.conf b/keystone/tests/test_overrides.conf
index ef7524b7..aac29f26 100644
--- a/tests/test_overrides.conf
+++ b/keystone/tests/test_overrides.conf
@@ -15,6 +15,6 @@ driver = keystone.trust.backends.kvs.Trust
driver = keystone.token.backends.kvs.Token
[signing]
-certfile = ../examples/pki/certs/signing_cert.pem
-keyfile = ../examples/pki/private/signing_key.pem
-ca_certs = ../examples/pki/certs/cacert.pem
+certfile = ../../examples/pki/certs/signing_cert.pem
+keyfile = ../../examples/pki/private/signing_key.pem
+ca_certs = ../../examples/pki/certs/cacert.pem
diff --git a/tests/test_pki_token_provider.conf b/keystone/tests/test_pki_token_provider.conf
index 255972c3..255972c3 100644
--- a/tests/test_pki_token_provider.conf
+++ b/keystone/tests/test_pki_token_provider.conf
diff --git a/tests/test_policy.py b/keystone/tests/test_policy.py
index 010a5abf..bdf91c94 100644
--- a/tests/test_policy.py
+++ b/keystone/tests/test_policy.py
@@ -19,7 +19,7 @@ import StringIO
import tempfile
import urllib2
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone import exception
diff --git a/tests/test_s3_token_middleware.py b/keystone/tests/test_s3_token_middleware.py
index ec31f2ac..ec31f2ac 100644
--- a/tests/test_s3_token_middleware.py
+++ b/keystone/tests/test_s3_token_middleware.py
diff --git a/tests/test_serializer.py b/keystone/tests/test_serializer.py
index 2024949b..260a533c 100644
--- a/tests/test_serializer.py
+++ b/keystone/tests/test_serializer.py
@@ -17,7 +17,7 @@
import copy
from keystone.common import serializer
-from keystone import test
+from keystone.tests import core as test
class XmlSerializerTestCase(test.TestCase):
diff --git a/tests/test_singular_plural.py b/keystone/tests/test_singular_plural.py
index ea3ad27c..ea3ad27c 100644
--- a/tests/test_singular_plural.py
+++ b/keystone/tests/test_singular_plural.py
diff --git a/tests/test_sizelimit.py b/keystone/tests/test_sizelimit.py
index abd2b639..a37b0e31 100644
--- a/tests/test_sizelimit.py
+++ b/keystone/tests/test_sizelimit.py
@@ -14,7 +14,7 @@
import webob
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone import exception
diff --git a/tests/test_sql_core.py b/keystone/tests/test_sql_core.py
index e60005f5..e3379152 100644
--- a/tests/test_sql_core.py
+++ b/keystone/tests/test_sql_core.py
@@ -14,7 +14,7 @@
from keystone.common import sql
-from keystone import test
+from keystone.tests import core as test
class CallbackMonitor:
diff --git a/keystone/tests/test_sql_migrate_extensions.py b/keystone/tests/test_sql_migrate_extensions.py
new file mode 100644
index 00000000..4a529559
--- /dev/null
+++ b/keystone/tests/test_sql_migrate_extensions.py
@@ -0,0 +1,47 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+To run these tests against a live database:
+1. Modify the file `keystone/tests/backend_sql.conf` to use the connection
+   for your live database
+2. Set up a blank, live database.
+3. Run the tests using
+    ./run_tests.sh -N test_sql_migrate_extensions
+ WARNING::
+ Your database will be wiped.
+        Do not do this against a database with valuable data as
+ all data will be lost.
+"""
+
+from keystone.contrib import example
+
+import test_sql_upgrade
+
+
+class SqlUpgradeExampleExtension(test_sql_upgrade.SqlMigrateBase):
+ def repo_package(self):
+ return example
+
+ def test_upgrade(self):
+ self.assertTableDoesNotExist('example')
+ self.upgrade(1, repository=self.repo_path)
+ self.assertTableColumns('example', ['id', 'type', 'extra'])
+
+ def test_downgrade(self):
+ self.upgrade(1, repository=self.repo_path)
+ self.assertTableColumns('example', ['id', 'type', 'extra'])
+ self.downgrade(0, repository=self.repo_path)
+ self.assertTableDoesNotExist('example')
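
The upgrade/downgrade assertions above expect the example extension to ship a sqlalchemy-migrate repository whose first version creates an example table with id, type and extra columns. A hypothetical sketch of such a version script (column types and lengths are assumptions, not taken from the patch):

import sqlalchemy as sql


def upgrade(migrate_engine):
    meta = sql.MetaData()
    meta.bind = migrate_engine
    example_table = sql.Table(
        'example',
        meta,
        sql.Column('id', sql.String(64), primary_key=True),
        sql.Column('type', sql.String(255), nullable=False),
        sql.Column('extra', sql.Text()))
    example_table.create(migrate_engine, checkfirst=True)


def downgrade(migrate_engine):
    meta = sql.MetaData()
    meta.bind = migrate_engine
    example_table = sql.Table('example', meta, autoload=True)
    example_table.drop(migrate_engine, checkfirst=True)
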
diff --git a/tests/test_sql_upgrade.py b/keystone/tests/test_sql_upgrade.py
index cf82b814..7d60ced4 100644
--- a/tests/test_sql_upgrade.py
+++ b/keystone/tests/test_sql_upgrade.py
@@ -15,8 +15,8 @@
# under the License.
"""
To run these tests against a live database:
-1. Modify the file `tests/backend_sql.conf` to use the connection for your
- live database
+1. Modify the file `keystone/tests/backend_sql.conf` to use the connection for
+ your live database
2. Set up a blank, live database.
3. run the tests using
./run_tests.sh -N test_sql_upgrade
@@ -32,7 +32,7 @@ import uuid
from migrate.versioning import api as versioning_api
import sqlalchemy
-from keystone import test
+from keystone.tests import core as test
from keystone.common import sql
from keystone.common.sql import migration
@@ -45,8 +45,7 @@ CONF = config.CONF
DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id
-class SqlUpgradeTests(test.TestCase):
-
+class SqlMigrateBase(test.TestCase):
def initialize_sql(self):
self.metadata = sqlalchemy.MetaData()
self.metadata.bind = self.engine
@@ -55,12 +54,15 @@ class SqlUpgradeTests(test.TestCase):
test.testsdir('test_overrides.conf'),
test.testsdir('backend_sql.conf')]
- #override this to sepcify the complete list of configuration files
+ #override this to specify the complete list of configuration files
def config_files(self):
return self._config_file_list
+ def repo_package(self):
+ return None
+
def setUp(self):
- super(SqlUpgradeTests, self).setUp()
+ super(SqlMigrateBase, self).setUp()
self.config(self.config_files())
self.base = sql.Base()
@@ -71,7 +73,7 @@ class SqlUpgradeTests(test.TestCase):
autocommit=False)
self.initialize_sql()
- self.repo_path = migration._find_migrate_repo()
+ self.repo_path = migration.find_migrate_repo(self.repo_package())
self.schema = versioning_api.ControlledSchema.create(
self.engine,
self.repo_path, 0)
@@ -85,7 +87,64 @@ class SqlUpgradeTests(test.TestCase):
autoload=True)
self.downgrade(0)
table.drop(self.engine, checkfirst=True)
- super(SqlUpgradeTests, self).tearDown()
+ super(SqlMigrateBase, self).tearDown()
+
+ def select_table(self, name):
+ table = sqlalchemy.Table(name,
+ self.metadata,
+ autoload=True)
+ s = sqlalchemy.select([table])
+ return s
+
+ def assertTableExists(self, table_name):
+ try:
+ self.select_table(table_name)
+ except sqlalchemy.exc.NoSuchTableError:
+ raise AssertionError('Table "%s" does not exist' % table_name)
+
+ def assertTableDoesNotExist(self, table_name):
+ """Asserts that a given table exists cannot be selected by name."""
+ # Switch to a different metadata otherwise you might still
+ # detect renamed or dropped tables
+ try:
+ temp_metadata = sqlalchemy.MetaData()
+ temp_metadata.bind = self.engine
+ sqlalchemy.Table(table_name, temp_metadata, autoload=True)
+ except sqlalchemy.exc.NoSuchTableError:
+ pass
+ else:
+ raise AssertionError('Table "%s" already exists' % table_name)
+
+ def upgrade(self, *args, **kwargs):
+ self._migrate(*args, **kwargs)
+
+ def downgrade(self, *args, **kwargs):
+ self._migrate(*args, downgrade=True, **kwargs)
+
+ def _migrate(self, version, repository=None, downgrade=False,
+ current_schema=None):
+ repository = repository or self.repo_path
+ err = ''
+ version = versioning_api._migrate_version(self.schema,
+ version,
+ not downgrade,
+ err)
+ if not current_schema:
+ current_schema = self.schema
+ changeset = current_schema.changeset(version)
+ for ver, change in changeset:
+ self.schema.runchange(ver, change, changeset.step)
+ self.assertEqual(self.schema.version, version)
+
+ def assertTableColumns(self, table_name, expected_cols):
+ """Asserts that the table contains the expected set of columns."""
+ self.initialize_sql()
+ table = self.select_table(table_name)
+ actual_cols = [col.name for col in table.columns]
+ self.assertEqual(expected_cols, actual_cols, '%s table' % table_name)
+
+
+class SqlUpgradeTests(SqlMigrateBase):
def test_blank_db_to_start(self):
self.assertTableDoesNotExist('user')
@@ -108,13 +167,6 @@ class SqlUpgradeTests(test.TestCase):
self.downgrade(x - 1)
self.upgrade(x)
- def assertTableColumns(self, table_name, expected_cols):
- """Asserts that the table contains the expected set of columns."""
- self.initialize_sql()
- table = self.select_table(table_name)
- actual_cols = [col.name for col in table.columns]
- self.assertEqual(expected_cols, actual_cols, '%s table' % table_name)
-
def test_upgrade_add_initial_tables(self):
self.upgrade(1)
self.assertTableColumns("user", ["id", "name", "extra"])
@@ -504,6 +556,42 @@ class SqlUpgradeTests(test.TestCase):
insert.execute(d)
session.commit()
+ def test_upgrade_31_to_32(self):
+ self.upgrade(32)
+
+ user_table = self.select_table("user")
+ self.assertEquals(user_table.c.name.type.length, 255)
+
+ def test_downgrade_32_to_31(self):
+ self.upgrade(32)
+ session = self.Session()
+ # NOTE(aloga): we need a different metadata object
+ user_table = sqlalchemy.Table('user',
+ sqlalchemy.MetaData(),
+ autoload=True,
+ autoload_with=self.engine)
+ user_id = uuid.uuid4().hex
+ ins = user_table.insert().values(
+ {'id': user_id,
+ 'name': 'a' * 255,
+ 'password': uuid.uuid4().hex,
+ 'enabled': True,
+ 'domain_id': DEFAULT_DOMAIN_ID,
+ 'extra': '{}'})
+ session.execute(ins)
+ session.commit()
+
+ self.downgrade(31)
+ # Check that username has been truncated
+ q = session.query(user_table.c.name)
+ q = q.filter(user_table.c.id == user_id)
+ r = q.one()
+ user_name = r[0]
+ self.assertEquals(len(user_name), 64)
+
+ user_table = self.select_table("user")
+ self.assertEquals(user_table.c.name.type.length, 64)
+
def test_downgrade_to_0(self):
self.upgrade(self.max_version)
@@ -1186,6 +1274,24 @@ class SqlUpgradeTests(test.TestCase):
self.assertEqual(cred.user_id,
ec2_credential['user_id'])
+ def test_drop_credential_indexes(self):
+ self.upgrade(31)
+ table = sqlalchemy.Table('credential', self.metadata, autoload=True)
+ self.assertEqual(len(table.indexes), 0)
+
+ def test_downgrade_30(self):
+ self.upgrade(31)
+ self.downgrade(30)
+ table = sqlalchemy.Table('credential', self.metadata, autoload=True)
+ index_data = [(idx.name, idx.columns.keys())
+ for idx in table.indexes]
+ if self.engine.name == 'mysql':
+ self.assertIn(('user_id', ['user_id']), index_data)
+ self.assertIn(('credential_project_id_fkey', ['project_id']),
+ index_data)
+ else:
+ self.assertEqual(len(index_data), 0)
+
def populate_user_table(self, with_pass_enab=False,
with_pass_enab_domain=False):
# Populate the appropriate fields in the user
@@ -1284,50 +1390,6 @@ class SqlUpgradeTests(test.TestCase):
'extra': json.dumps(extra)})
self.engine.execute(ins)
- def select_table(self, name):
- table = sqlalchemy.Table(name,
- self.metadata,
- autoload=True)
- s = sqlalchemy.select([table])
- return s
-
- def assertTableExists(self, table_name):
- try:
- self.select_table(table_name)
- except sqlalchemy.exc.NoSuchTableError:
- raise AssertionError('Table "%s" does not exist' % table_name)
-
- def assertTableDoesNotExist(self, table_name):
- """Asserts that a given table exists cannot be selected by name."""
- # Switch to a different metadata otherwise you might still
- # detect renamed or dropped tables
- try:
- temp_metadata = sqlalchemy.MetaData()
- temp_metadata.bind = self.engine
- sqlalchemy.Table(table_name, temp_metadata, autoload=True)
- except sqlalchemy.exc.NoSuchTableError:
- pass
- else:
- raise AssertionError('Table "%s" already exists' % table_name)
-
- def upgrade(self, *args, **kwargs):
- self._migrate(*args, **kwargs)
-
- def downgrade(self, *args, **kwargs):
- self._migrate(*args, downgrade=True, **kwargs)
-
- def _migrate(self, version, repository=None, downgrade=False):
- repository = repository or self.repo_path
- err = ''
- version = versioning_api._migrate_version(self.schema,
- version,
- not downgrade,
- err)
- changeset = self.schema.changeset(version)
- for ver, change in changeset:
- self.schema.runchange(ver, change, changeset.step)
- self.assertEqual(self.schema.version, version)
-
def _mysql_check_all_tables_innodb(self):
database = self.engine.url.database
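
SqlMigrateBase now resolves its repository through migration.find_migrate_repo(self.repo_package()), so extension test classes can point the same upgrade/downgrade helpers at a contrib package's repository. A hypothetical sketch of how such a lookup could work, assuming the repository lives in a migrate_repo directory next to the package; the real helper in keystone.common.sql.migration may differ:

import os


def find_migrate_repo(package=None, repo_name='migrate_repo'):
    """Return the path of a package's migrate repository (assumed)."""
    if package is None:
        from keystone.common import sql
        package = sql
    path = os.path.join(
        os.path.abspath(os.path.dirname(package.__file__)), repo_name)
    if not os.path.isdir(path):
        raise Exception('Migration repository path %s not found.' % path)
    return path
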
diff --git a/tests/test_ssl.py b/keystone/tests/test_ssl.py
index 8de5cc19..cb6b5fdc 100644
--- a/tests/test_ssl.py
+++ b/keystone/tests/test_ssl.py
@@ -18,7 +18,7 @@
import os
import ssl
-from keystone import test
+from keystone.tests import core as test
from keystone.common import environment
from keystone import config
diff --git a/tests/test_token_bind.py b/keystone/tests/test_token_bind.py
index 20488a91..ae398ea1 100644
--- a/tests/test_token_bind.py
+++ b/keystone/tests/test_token_bind.py
@@ -15,7 +15,7 @@
from keystone.common import wsgi
from keystone import config
from keystone import exception
-from keystone import test
+from keystone.tests import core as test
CONF = config.CONF
diff --git a/tests/test_token_provider.py b/keystone/tests/test_token_provider.py
index a7e92717..08fab35d 100644
--- a/tests/test_token_provider.py
+++ b/keystone/tests/test_token_provider.py
@@ -17,7 +17,7 @@
import uuid
from keystone import exception
-from keystone import test
+from keystone.tests import core as test
from keystone import token
diff --git a/tests/test_url_middleware.py b/keystone/tests/test_url_middleware.py
index 2a36e8c2..436eb8d4 100644
--- a/tests/test_url_middleware.py
+++ b/keystone/tests/test_url_middleware.py
@@ -16,7 +16,7 @@
import webob
-from keystone import test
+from keystone.tests import core as test
from keystone import middleware
diff --git a/tests/test_utils.py b/keystone/tests/test_utils.py
index 4a65bea1..19535a7b 100644
--- a/tests/test_utils.py
+++ b/keystone/tests/test_utils.py
@@ -29,7 +29,7 @@
# License for the specific language governing permissions and limitations
# under the License.
-from keystone import test
+from keystone.tests import core as test
from keystone.common import utils
diff --git a/tests/test_uuid_token_provider.conf b/keystone/tests/test_uuid_token_provider.conf
index d127ea3b..d127ea3b 100644
--- a/tests/test_uuid_token_provider.conf
+++ b/keystone/tests/test_uuid_token_provider.conf
diff --git a/tests/test_v3.py b/keystone/tests/test_v3.py
index 4f00de7d..7db14c84 100644
--- a/tests/test_v3.py
+++ b/keystone/tests/test_v3.py
@@ -4,13 +4,12 @@ import uuid
from lxml import etree
import webtest
-from keystone import test
-
from keystone import auth
from keystone.common import serializer
from keystone import config
from keystone.openstack.common import timeutils
from keystone.policy.backends import rules
+from keystone.tests import core as test
import test_content_types
diff --git a/tests/test_v3_auth.py b/keystone/tests/test_v3_auth.py
index 11d66700..1f4425ce 100644
--- a/tests/test_v3_auth.py
+++ b/keystone/tests/test_v3_auth.py
@@ -19,7 +19,7 @@ from keystone import auth
from keystone.common import cms
from keystone import config
from keystone import exception
-from keystone import test
+from keystone.tests import core as test
import test_v3
@@ -545,6 +545,67 @@ class TestTokenRevoking(test_v3.RestfulTestCase):
headers={'X-Subject-Token': token},
expected_status=204)
+ def test_disabling_project_revokes_token(self):
+ resp = self.post(
+ '/auth/tokens',
+ body=self.build_authentication_request(
+ user_id=self.user3['id'],
+ password=self.user3['password'],
+ project_id=self.projectA['id']))
+ token = resp.headers.get('X-Subject-Token')
+
+ # confirm token is valid
+ self.head('/auth/tokens',
+ headers={'X-Subject-Token': token},
+ expected_status=204)
+
+ # disable the project, which should invalidate the token
+ self.patch(
+ '/projects/%(project_id)s' % {'project_id': self.projectA['id']},
+ body={'project': {'enabled': False}})
+
+ # user should no longer have access to the project
+ self.head('/auth/tokens',
+ headers={'X-Subject-Token': token},
+ expected_status=401)
+ resp = self.post(
+ '/auth/tokens',
+ body=self.build_authentication_request(
+ user_id=self.user3['id'],
+ password=self.user3['password'],
+ project_id=self.projectA['id']),
+ expected_status=401)
+
+ def test_deleting_project_revokes_token(self):
+ resp = self.post(
+ '/auth/tokens',
+ body=self.build_authentication_request(
+ user_id=self.user3['id'],
+ password=self.user3['password'],
+ project_id=self.projectA['id']))
+ token = resp.headers.get('X-Subject-Token')
+
+ # confirm token is valid
+ self.head('/auth/tokens',
+ headers={'X-Subject-Token': token},
+ expected_status=204)
+
+ # delete the project, which should invalidate the token
+ self.delete(
+ '/projects/%(project_id)s' % {'project_id': self.projectA['id']})
+
+ # user should no longer have access to the project
+ self.head('/auth/tokens',
+ headers={'X-Subject-Token': token},
+ expected_status=401)
+ resp = self.post(
+ '/auth/tokens',
+ body=self.build_authentication_request(
+ user_id=self.user3['id'],
+ password=self.user3['password'],
+ project_id=self.projectA['id']),
+ expected_status=401)
+
def test_deleting_group_grant_revokes_tokens(self):
"""Test deleting a group grant revokes tokens.
diff --git a/tests/test_v3_catalog.py b/keystone/tests/test_v3_catalog.py
index 408670ec..408670ec 100644
--- a/tests/test_v3_catalog.py
+++ b/keystone/tests/test_v3_catalog.py
diff --git a/tests/test_v3_credential.py b/keystone/tests/test_v3_credential.py
index 6040cca3..6040cca3 100644
--- a/tests/test_v3_credential.py
+++ b/keystone/tests/test_v3_credential.py
diff --git a/tests/test_v3_identity.py b/keystone/tests/test_v3_identity.py
index f1e19c42..f1e19c42 100644
--- a/tests/test_v3_identity.py
+++ b/keystone/tests/test_v3_identity.py
diff --git a/tests/test_v3_policy.py b/keystone/tests/test_v3_policy.py
index d988efd2..d988efd2 100644
--- a/tests/test_v3_policy.py
+++ b/keystone/tests/test_v3_policy.py
diff --git a/tests/test_v3_protection.py b/keystone/tests/test_v3_protection.py
index 38e32813..38e32813 100644
--- a/tests/test_v3_protection.py
+++ b/keystone/tests/test_v3_protection.py
diff --git a/tests/test_versions.py b/keystone/tests/test_versions.py
index c5864c37..933fb246 100644
--- a/tests/test_versions.py
+++ b/keystone/tests/test_versions.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from keystone import test
+from keystone.tests import core as test
from keystone import config
from keystone import controllers
diff --git a/tests/test_wsgi.py b/keystone/tests/test_wsgi.py
index 003f7571..781159e2 100644
--- a/tests/test_wsgi.py
+++ b/keystone/tests/test_wsgi.py
@@ -14,11 +14,10 @@
# License for the specific language governing permissions and limitations
# under the License.
-from keystone import test
-
from keystone.common import wsgi
from keystone import exception
from keystone.openstack.common import jsonutils
+from keystone.tests import core as test
class FakeApp(wsgi.Application):
@@ -37,37 +36,6 @@ class BaseWSGITest(test.TestCase):
req.environ['wsgiorg.routing_args'] = [None, args]
return req
- def test_mask_password(self):
- message = ("test = 'password': 'aaaaaa', 'param1': 'value1', "
- "\"new_password\": 'bbbbbb'")
- self.assertEqual(wsgi.mask_password(message, True),
- u"test = 'password': '***', 'param1': 'value1', "
- "\"new_password\": '***'")
-
- message = "test = 'password' : 'aaaaaa'"
- self.assertEqual(wsgi.mask_password(message, False, '111'),
- "test = 'password' : '111'")
-
- message = u"test = u'password' : u'aaaaaa'"
- self.assertEqual(wsgi.mask_password(message, True),
- u"test = u'password' : u'***'")
-
- message = 'test = "password" : "aaaaaaaaa"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "password" : "***"')
-
- message = 'test = "original_password" : "aaaaaaaaa"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "original_password" : "***"')
-
- message = 'test = "original_password" : ""'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "original_password" : "***"')
-
- message = 'test = "param1" : "value"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "param1" : "value"')
-
class ApplicationTest(BaseWSGITest):
def test_response_content_type(self):
@@ -210,3 +178,36 @@ class MiddlewareTest(BaseWSGITest):
app = factory(self.app)
self.assertIn("testkey", app.kwargs)
self.assertEquals("test", app.kwargs["testkey"])
+
+
+class WSGIFunctionTest(test.TestCase):
+ def test_mask_password(self):
+ message = ("test = 'password': 'aaaaaa', 'param1': 'value1', "
+ "\"new_password\": 'bbbbbb'")
+ self.assertEqual(wsgi.mask_password(message, True),
+ u"test = 'password': '***', 'param1': 'value1', "
+ "\"new_password\": '***'")
+
+ message = "test = 'password' : 'aaaaaa'"
+ self.assertEqual(wsgi.mask_password(message, False, '111'),
+ "test = 'password' : '111'")
+
+ message = u"test = u'password' : u'aaaaaa'"
+ self.assertEqual(wsgi.mask_password(message, True),
+ u"test = u'password' : u'***'")
+
+ message = 'test = "password" : "aaaaaaaaa"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "password" : "***"')
+
+ message = 'test = "original_password" : "aaaaaaaaa"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "original_password" : "***"')
+
+ message = 'test = "original_password" : ""'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "original_password" : "***"')
+
+ message = 'test = "param1" : "value"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "param1" : "value"')
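
The relocated mask_password tests above pin down the expected behaviour: values following password-like keys are replaced with a secret marker, while unrelated parameters pass through unchanged. A hypothetical Python 2 style sketch that satisfies those cases; the actual keystone.common.wsgi implementation may use different key lists and patterns:

import re

# Assumed set of keys whose values should be masked.
_MASK_KEYS = ('password', 'new_password', 'original_password')


def mask_password(message, is_unicode=False, secret='***'):
    """Replace the value of any password-like key in message with secret."""
    if is_unicode:
        message = unicode(message)
    for key in _MASK_KEYS:
        if key not in message:
            continue
        # Matches forms like 'password': 'x', "key" : "x" and u'key' : u'x'.
        pattern = r"(['\"]%s['\"]\s*:?\s*u?['\"]).*?(['\"])" % key
        message = re.sub(pattern, r'\g<1>' + secret + r'\g<2>', message)
    return message
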
diff --git a/keystone/tests/tmp/.gitkeep b/keystone/tests/tmp/.gitkeep
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/tests/tmp/.gitkeep
diff --git a/keystone/token/backends/kvs.py b/keystone/token/backends/kvs.py
index c3c3e769..171d77df 100644
--- a/keystone/token/backends/kvs.py
+++ b/keystone/token/backends/kvs.py
@@ -17,8 +17,8 @@
import copy
from keystone.common import kvs
-from keystone.common import logging
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
from keystone import token
diff --git a/keystone/token/backends/memcache.py b/keystone/token/backends/memcache.py
index 06e89d60..a07a516b 100644
--- a/keystone/token/backends/memcache.py
+++ b/keystone/token/backends/memcache.py
@@ -19,11 +19,11 @@ import copy
import memcache
-from keystone.common import logging
from keystone.common import utils
from keystone import config
from keystone import exception
from keystone.openstack.common import jsonutils
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
from keystone import token
diff --git a/keystone/token/backends/sql.py b/keystone/token/backends/sql.py
index 0e8a916d..82eab651 100644
--- a/keystone/token/backends/sql.py
+++ b/keystone/token/backends/sql.py
@@ -17,7 +17,6 @@
import copy
import datetime
-
from keystone.common import sql
from keystone import exception
from keystone.openstack.common import timeutils
@@ -30,9 +29,13 @@ class TokenModel(sql.ModelBase, sql.DictBase):
id = sql.Column(sql.String(64), primary_key=True)
expires = sql.Column(sql.DateTime(), default=None)
extra = sql.Column(sql.JsonBlob())
- valid = sql.Column(sql.Boolean(), default=True)
+ valid = sql.Column(sql.Boolean(), default=True, nullable=False)
user_id = sql.Column(sql.String(64))
- trust_id = sql.Column(sql.String(64), nullable=True)
+ trust_id = sql.Column(sql.String(64))
+ __table_args__ = (
+ sql.Index('ix_token_expires', 'expires'),
+ sql.Index('ix_token_valid', 'valid')
+ )
class Token(sql.Base, token.Driver):
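
The token model now declares indexes on expires and valid. For illustration only, a sketch of the kind of query those indexes are intended to serve, such as listing invalid or expired tokens without a full table scan; the helper name and session handling are assumptions:

import datetime

import sqlalchemy


def list_invalid_or_expired(session, token_model):
    """Hypothetical query served by ix_token_valid and ix_token_expires."""
    now = datetime.datetime.utcnow()
    return session.query(token_model).filter(
        sqlalchemy.or_(
            token_model.valid == False,        # uses ix_token_valid  # noqa
            token_model.expires < now)).all()  # uses ix_token_expires
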
diff --git a/keystone/token/controllers.py b/keystone/token/controllers.py
index 9ebc29fe..954ff8e8 100644
--- a/keystone/token/controllers.py
+++ b/keystone/token/controllers.py
@@ -3,11 +3,10 @@ import json
from keystone.common import cms
from keystone.common import controller
from keystone.common import dependency
-from keystone.common import logging
-from keystone.common import utils
from keystone.common import wsgi
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
from keystone.token import core
from keystone.token import provider as token_provider
@@ -215,10 +214,9 @@ class Auth(controller.V2Controller):
attribute='password', target='passwordCredentials')
password = auth['passwordCredentials']['password']
- max_pw_size = utils.MAX_PASSWORD_LENGTH
- if password and len(password) > max_pw_size:
- raise exception.ValidationSizeError(attribute='password',
- size=max_pw_size)
+ if password and len(password) > CONF.identity.max_password_length:
+ raise exception.ValidationSizeError(
+ attribute='password', size=CONF.identity.max_password_length)
if ("userId" not in auth['passwordCredentials'] and
"username" not in auth['passwordCredentials']):
diff --git a/keystone/token/core.py b/keystone/token/core.py
index bc27b80d..3959586b 100644
--- a/keystone/token/core.py
+++ b/keystone/token/core.py
@@ -21,10 +21,10 @@ import datetime
from keystone.common import cms
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
diff --git a/keystone/token/provider.py b/keystone/token/provider.py
index 2864be6f..f2acb0e1 100644
--- a/keystone/token/provider.py
+++ b/keystone/token/provider.py
@@ -18,10 +18,10 @@
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/keystone/token/providers/pki.py b/keystone/token/providers/pki.py
index 81abe5d4..64dde473 100644
--- a/keystone/token/providers/pki.py
+++ b/keystone/token/providers/pki.py
@@ -20,9 +20,9 @@ import json
from keystone.common import cms
from keystone.common import environment
-from keystone.common import logging
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
from keystone.token.providers import uuid
diff --git a/keystone/trust/backends/sql.py b/keystone/trust/backends/sql.py
index daa8e3f7..9e92ad71 100644
--- a/keystone/trust/backends/sql.py
+++ b/keystone/trust/backends/sql.py
@@ -26,11 +26,11 @@ class TrustModel(sql.ModelBase, sql.DictBase):
'project_id', 'impersonation', 'expires_at']
id = sql.Column(sql.String(64), primary_key=True)
#user id Of owner
- trustor_user_id = sql.Column(sql.String(64), unique=False, nullable=False,)
+ trustor_user_id = sql.Column(sql.String(64), nullable=False,)
#user_id of user allowed to consume this preauth
- trustee_user_id = sql.Column(sql.String(64), unique=False, nullable=False)
- project_id = sql.Column(sql.String(64), unique=False, nullable=True)
- impersonation = sql.Column(sql.Boolean)
+ trustee_user_id = sql.Column(sql.String(64), nullable=False)
+ project_id = sql.Column(sql.String(64))
+ impersonation = sql.Column(sql.Boolean, nullable=False)
deleted_at = sql.Column(sql.DateTime)
expires_at = sql.Column(sql.DateTime)
extra = sql.Column(sql.JsonBlob())
diff --git a/keystone/trust/controllers.py b/keystone/trust/controllers.py
index 7a94fe29..3d8df459 100644
--- a/keystone/trust/controllers.py
+++ b/keystone/trust/controllers.py
@@ -2,10 +2,10 @@ import uuid
from keystone.common import controller
from keystone.common import dependency
-from keystone.common import logging
from keystone import config
from keystone import exception
from keystone import identity
+from keystone.openstack.common import log as logging
from keystone.openstack.common import timeutils
diff --git a/keystone/trust/core.py b/keystone/trust/core.py
index 5c4fc90f..e4ff74de 100644
--- a/keystone/trust/core.py
+++ b/keystone/trust/core.py
@@ -17,10 +17,10 @@
"""Main entry point into the Identity service."""
from keystone.common import dependency
-from keystone.common import logging
from keystone.common import manager
from keystone import config
from keystone import exception
+from keystone.openstack.common import log as logging
CONF = config.CONF
diff --git a/requirements.txt b/requirements.txt
index e54bb6a0..b57a91fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ pam>=0.1.4
WebOb>=1.0.8
eventlet
greenlet
+netaddr
PasteDeploy
paste
routes
@@ -15,3 +16,4 @@ lxml
iso8601>=0.1.4
python-keystoneclient>=0.3.0
oslo.config>=1.1.0
+Babel>=0.9.6
diff --git a/run_tests.sh b/run_tests.sh
index f9b8b7c4..7916b4d3 100755
--- a/run_tests.sh
+++ b/run_tests.sh
@@ -27,8 +27,9 @@ function usage {
echo " -x, --stop Stop running tests after the first error or failure."
echo " -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added."
echo " -u, --update Update the virtual environment with any newer package versions"
- echo " -p, --pep8 Just run pep8"
- echo " -P, --no-pep8 Don't run pep8"
+ echo " -p, --pep8 Just run flake8"
+ echo " -8, --8 Just run flake8, don't show PEP8 text for each error"
+ echo " -P, --no-pep8 Don't run flake8"
echo " -c, --coverage Generate coverage report"
echo " -h, --help Print this usage message"
echo " -xintegration Ignore all keystoneclient test cases (integration tests)"
@@ -95,10 +96,10 @@ fi
function cleanup_test_db {
# Default test settings will leave around some test*.db files
- # TODO(termie): this could probably be moved into tests/__init__.py
+ # TODO(termie): this could probably be moved into keystone/tests/__init__.py
# but there have been some issues with creating that
# file for some users
- rm -f tests/test*.db
+ rm -f keystone/tests/*.db
}
function run_tests {
diff --git a/setup.cfg b/setup.cfg
index 83d43963..8bce3b3d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -64,4 +64,4 @@ detailed-errors=1
cover-package = keystone
cover-html = true
cover-erase = true
-where=tests
+where=keystone/tests
diff --git a/test-requirements.txt b/test-requirements.txt
index 48ac1280..223c4456 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -40,8 +40,5 @@ keyring
netifaces
-# For translations processing
-Babel
-
# For documentation
oslo.sphinx