-rw-r--r--  doc/source/developing.rst | 30
-rw-r--r--  etc/keystone.conf.sample | 3
-rw-r--r--  keystone/assignment/backends/ldap.py | 25
-rw-r--r--  keystone/assignment/core.py | 20
-rw-r--r--  keystone/catalog/backends/sql.py | 77
-rw-r--r--  keystone/cli.py | 54
-rw-r--r--  keystone/common/config.py | 1
-rw-r--r--  keystone/common/ldap/core.py | 11
-rw-r--r--  keystone/common/sql/core.py | 5
-rw-r--r--  keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py | 40
-rw-r--r--  keystone/common/sql/migration.py | 42
-rw-r--r--  keystone/common/utils.py | 13
-rw-r--r--  keystone/contrib/example/__init__.py | 0
-rw-r--r--  keystone/contrib/example/migrate_repo/__init__.py | 0
-rw-r--r--  keystone/contrib/example/migrate_repo/migrate.cfg | 25
-rw-r--r--  keystone/contrib/example/migrate_repo/versions/001_example_table.py | 45
-rw-r--r--  keystone/contrib/example/migrate_repo/versions/__init__.py | 0
-rw-r--r--  keystone/identity/backends/ldap.py | 33
-rw-r--r--  keystone/openstack/common/gettextutils.py | 263
-rw-r--r--  keystone/openstack/common/importutils.py | 7
-rw-r--r--  keystone/openstack/common/jsonutils.py | 67
-rw-r--r--  keystone/openstack/common/local.py | 47
-rw-r--r--  keystone/openstack/common/log.py | 559
-rw-r--r--  keystone/openstack/common/timeutils.py | 32
-rw-r--r--  keystone/token/backends/sql.py | 9
-rw-r--r--  keystone/token/controllers.py | 8
-rw-r--r--  keystone/trust/backends/sql.py | 8
-rw-r--r--  requirements.txt | 1
-rw-r--r--  tests/test_auth.py | 3
-rw-r--r--  tests/test_backend.py | 12
-rw-r--r--  tests/test_sql_migrate_extensions.py | 47
-rw-r--r--  tests/test_sql_upgrade.py | 140
-rw-r--r--  tests/test_wsgi.py | 64
33 files changed, 1431 insertions, 260 deletions
diff --git a/doc/source/developing.rst b/doc/source/developing.rst
index c14ef7ab..2cf4b98e 100644
--- a/doc/source/developing.rst
+++ b/doc/source/developing.rst
@@ -71,6 +71,36 @@ place::
.. _`python-keystoneclient`: https://github.com/openstack/python-keystoneclient
+Database Schema Migrations
+--------------------------
+
+Keystone uses SQLAlchemy-migrate_ to migrate the SQL database between
+revisions. For core components, the migrations are kept in a central
+repository under `keystone/common/sql/migrate_repo`.
+
+.. _SQLAlchemy-migrate: http://code.google.com/p/sqlalchemy-migrate/
+
+Extensions should be created as directories under `keystone/contrib`. An
+extension that requires SQL migrations should not change the common repository,
+but should instead have its own repository, located in the extension's
+directory at `keystone/contrib/<extension>/migrate_repo`. In addition, it
+needs a subdirectory named `versions`. For example, if the
+extension name is `my_extension` then the directory structure would be
+`keystone/contrib/my_extension/migrate_repo/versions/`. For the migration
+to work, both the `migrate_repo` and `versions` subdirectories must have
+empty `__init__.py` files. SQLAlchemy-migrate looks for a configuration file
+named `migrate.cfg` in the `migrate_repo` directory; it uses a key/value
+ini file format. A sample config file with the minimal set of values is::
+
+ [db_settings]
+ repository_id=my_extension
+ version_table=migrate_version
+ required_dbs=[]
+
+The directory `keystone/contrib/example` contains a sample extension migration.
+
+Migrations for extensions must be run explicitly. To run a migration for a
+specific extension, run `keystone-manage --extension <name> db_sync`.
+
Initial Sample Data
-------------------
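A minimal sketch of a first migration script for a hypothetical
`keystone/contrib/my_extension/migrate_repo/versions/001_my_extension_table.py`
(the table name and columns are illustrative, mirroring the in-tree example)::

    import sqlalchemy as sql


    def upgrade(migrate_engine):
        # Bind to the engine that SQLAlchemy-migrate hands in; never
        # create your own engine inside a migration.
        meta = sql.MetaData()
        meta.bind = migrate_engine
        table = sql.Table(
            'my_extension',
            meta,
            sql.Column('id', sql.String(64), primary_key=True),
            sql.Column('extra', sql.Text()))
        table.create(migrate_engine, checkfirst=True)


    def downgrade(migrate_engine):
        meta = sql.MetaData()
        meta.bind = migrate_engine
        table = sql.Table('my_extension', meta, autoload=True)
        table.drop(migrate_engine, checkfirst=True)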
diff --git a/etc/keystone.conf.sample b/etc/keystone.conf.sample
index a49a9a5e..90efe5f6 100644
--- a/etc/keystone.conf.sample
+++ b/etc/keystone.conf.sample
@@ -100,6 +100,9 @@
# exist in order to maintain support for your v2 clients.
# default_domain_id = default
+# Maximum supported length for user passwords; decrease to improve performance.
+# max_password_length = 4096
+
[credential]
# driver = keystone.credential.backends.sql.Credential
diff --git a/keystone/assignment/backends/ldap.py b/keystone/assignment/backends/ldap.py
index 9b273e40..718d38c3 100644
--- a/keystone/assignment/backends/ldap.py
+++ b/keystone/assignment/backends/ldap.py
@@ -263,28 +263,19 @@ class ProjectApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
DEFAULT_OBJECTCLASS = 'groupOfNames'
DEFAULT_ID_ATTR = 'cn'
DEFAULT_MEMBER_ATTRIBUTE = 'member'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.ProjectNotFound
notfound_arg = 'project_id' # NOTE(yorik-sar): while options_name = tenant
options_name = 'tenant'
- attribute_mapping = {'name': 'ou',
- 'description': 'description',
- 'tenantId': 'cn',
- 'enabled': 'enabled',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'name': 'name',
+ 'description': 'desc',
+ 'enabled': 'enabled',
+ 'domain_id': 'domain_id'}
model = models.Project
def __init__(self, conf):
super(ProjectApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.tenant_name_attribute
- self.attribute_mapping['description'] = conf.ldap.tenant_desc_attribute
- self.attribute_mapping['enabled'] = conf.ldap.tenant_enabled_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.tenant_domain_id_attribute)
self.member_attribute = (getattr(conf.ldap, 'tenant_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'tenant_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def create(self, values):
self.affirm_unique(values)
@@ -381,21 +372,15 @@ class RoleApi(common_ldap.BaseLdap):
DEFAULT_STRUCTURAL_CLASSES = []
DEFAULT_OBJECTCLASS = 'organizationalRole'
DEFAULT_MEMBER_ATTRIBUTE = 'roleOccupant'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.RoleNotFound
options_name = 'role'
- attribute_mapping = {'name': 'ou',
- #'serviceId': 'service_id',
- }
+ attribute_options_names = {'name': 'name'}
model = models.Role
def __init__(self, conf):
super(RoleApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.role_name_attribute
self.member_attribute = (getattr(conf.ldap, 'role_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'role_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def get(self, id, filter=None):
model = super(RoleApi, self).get(id, filter)
diff --git a/keystone/assignment/core.py b/keystone/assignment/core.py
index 64edb3fa..0a2ee681 100644
--- a/keystone/assignment/core.py
+++ b/keystone/assignment/core.py
@@ -178,9 +178,23 @@ class Manager(manager.Manager):
keystone.exception.UserNotFound
"""
- self.driver.add_role_to_user_and_project(user_id,
- tenant_id,
- config.CONF.member_role_id)
+ try:
+ self.driver.add_role_to_user_and_project(
+ user_id,
+ tenant_id,
+ config.CONF.member_role_id)
+ except exception.RoleNotFound:
+ LOG.info(_("Creating the default role %s "
+ "because it does not exist.") %
+ config.CONF.member_role_id)
+ role = {'id': CONF.member_role_id,
+ 'name': CONF.member_role_name}
+ self.driver.create_role(config.CONF.member_role_id, role)
+            # Now that the default role exists, the add should succeed.
+ self.driver.add_role_to_user_and_project(
+ user_id,
+ tenant_id,
+ config.CONF.member_role_id)
def remove_user_from_project(self, tenant_id, user_id):
"""Remove user from a tenant
diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py
index 1dad5a80..d7b2123a 100644
--- a/keystone/catalog/backends/sql.py
+++ b/keystone/catalog/backends/sql.py
@@ -32,6 +32,7 @@ class Service(sql.ModelBase, sql.DictBase):
id = sql.Column(sql.String(64), primary_key=True)
type = sql.Column(sql.String(255))
extra = sql.Column(sql.JsonBlob())
+ endpoints = sql.relationship("Endpoint", backref="service")
class Endpoint(sql.ModelBase, sql.DictBase):
@@ -40,12 +41,12 @@ class Endpoint(sql.ModelBase, sql.DictBase):
'legacy_endpoint_id']
id = sql.Column(sql.String(64), primary_key=True)
legacy_endpoint_id = sql.Column(sql.String(64))
- interface = sql.Column(sql.String(8), primary_key=True)
- region = sql.Column('region', sql.String(255))
+ interface = sql.Column(sql.String(8), nullable=False)
+ region = sql.Column(sql.String(255))
service_id = sql.Column(sql.String(64),
sql.ForeignKey('service.id'),
nullable=False)
- url = sql.Column(sql.Text())
+ url = sql.Column(sql.Text(), nullable=False)
extra = sql.Column(sql.JsonBlob())
@@ -150,28 +151,26 @@ class Catalog(sql.Base, catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
+ session = self.get_session()
+ endpoints = (session.query(Endpoint).
+ options(sql.joinedload(Endpoint.service)).
+ all())
+
catalog = {}
- services = {}
- for endpoint in self.list_endpoints():
- # look up the service
- services.setdefault(
- endpoint['service_id'],
- self.get_service(endpoint['service_id']))
- service = services[endpoint['service_id']]
-
- # add the endpoint to the catalog if it's not already there
- catalog.setdefault(endpoint['region'], {})
- catalog[endpoint['region']].setdefault(
- service['type'], {
- 'id': endpoint['id'],
- 'name': service['name'],
- 'publicURL': '', # this may be overridden, but must exist
- })
-
- # add the interface's url
- url = core.format_url(endpoint.get('url'), d)
+
+ for endpoint in endpoints:
+ region = endpoint['region']
+ service_type = endpoint.service['type']
+ default_service = {
+ 'id': endpoint['id'],
+ 'name': endpoint.service['name'],
+ 'publicURL': ''
+ }
+ catalog.setdefault(region, {})
+ catalog[region].setdefault(service_type, default_service)
+ url = core.format_url(endpoint['url'], d)
interface_url = '%sURL' % endpoint['interface']
- catalog[endpoint['region']][service['type']][interface_url] = url
+ catalog[region][service_type][interface_url] = url
return catalog
@@ -180,27 +179,19 @@ class Catalog(sql.Base, catalog.Driver):
d.update({'tenant_id': tenant_id,
'user_id': user_id})
- services = {}
- for endpoint in self.list_endpoints():
- # look up the service
- service_id = endpoint['service_id']
- services.setdefault(
- service_id,
- self.get_service(service_id))
- service = services[service_id]
+ session = self.get_session()
+ services = (session.query(Service).
+ options(sql.joinedload(Service.endpoints)).
+ all())
+
+ def make_v3_endpoint(endpoint):
del endpoint['service_id']
endpoint['url'] = core.format_url(endpoint['url'], d)
- if 'endpoints' in services[service_id]:
- services[service_id]['endpoints'].append(endpoint)
- else:
- services[service_id]['endpoints'] = [endpoint]
-
- catalog = []
- for service_id, service in services.iteritems():
- formatted_service = {}
- formatted_service['id'] = service['id']
- formatted_service['type'] = service['type']
- formatted_service['endpoints'] = service['endpoints']
- catalog.append(formatted_service)
+ return endpoint
+
+ catalog = [{'endpoints': [make_v3_endpoint(ep.to_dict())
+ for ep in svc.endpoints],
+ 'id': svc.id,
+ 'type': svc.type} for svc in services]
return catalog
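Both rewritten methods fetch the whole catalog in a single eager-loaded query
instead of one `get_service()` call per endpoint. A standalone sketch of the
pattern, assuming a `session` and the `Endpoint` model defined above::

    from sqlalchemy.orm import joinedload

    # One round trip: endpoints joined with their services,
    # rather than a service lookup inside the loop.
    endpoints = (session.query(Endpoint).
                 options(joinedload(Endpoint.service)).
                 all())
    for endpoint in endpoints:
        print endpoint.interface, endpoint.service.type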
diff --git a/keystone/cli.py b/keystone/cli.py
index 21d2ad40..18c095ce 100644
--- a/keystone/cli.py
+++ b/keystone/cli.py
@@ -20,12 +20,15 @@ import grp
import os
import pwd
+from migrate import exceptions
+
from oslo.config import cfg
import pbr.version
from keystone.common import openssl
from keystone.common.sql import migration
from keystone import config
+from keystone import contrib
from keystone.openstack.common import importutils
from keystone.openstack.common import jsonutils
from keystone import token
@@ -57,14 +60,35 @@ class DbSync(BaseApp):
'version. If not provided, db_sync will '
'migrate the database to the latest known '
'version.'))
+ parser.add_argument('--extension', default=None,
+ help=('Migrate the database for the specified '
+ 'extension. If not provided, db_sync will '
+ 'migrate the common repository.'))
+
return parser
@staticmethod
def main():
- for k in ['identity', 'catalog', 'policy', 'token', 'credential']:
- driver = importutils.import_object(getattr(CONF, k).driver)
- if hasattr(driver, 'db_sync'):
- driver.db_sync(CONF.command.version)
+ version = CONF.command.version
+ extension = CONF.command.extension
+ if not extension:
+ migration.db_sync(version=version)
+ else:
+ package_name = "%s.%s.migrate_repo" % (contrib.__name__, extension)
+ try:
+ package = importutils.import_module(package_name)
+ repo_path = os.path.abspath(os.path.dirname(package.__file__))
+ except ImportError:
+ print _("This extension does not provide migrations.")
+ exit(0)
+ try:
+ # Register the repo with the version control API
+ # If it already knows about the repo, it will throw
+ # an exception that we can safely ignore
+ migration.db_version_control(version=None, repo_path=repo_path)
+ except exceptions.DatabaseAlreadyControlledError:
+ pass
+ migration.db_sync(version=None, repo_path=repo_path)
class DbVersion(BaseApp):
@@ -72,9 +96,29 @@ class DbVersion(BaseApp):
name = 'db_version'
+ @classmethod
+ def add_argument_parser(cls, subparsers):
+ parser = super(DbVersion, cls).add_argument_parser(subparsers)
+ parser.add_argument('--extension', default=None,
+                            help=('Print the migration version of the '
+                                  'specified extension. If not provided, '
+                                  'print the common repository version.'))
+
@staticmethod
def main():
- print(migration.db_version())
+ extension = CONF.command.extension
+ if extension:
+ try:
+ package_name = ("%s.%s.migrate_repo" %
+ (contrib.__name__, extension))
+ package = importutils.import_module(package_name)
+ repo_path = os.path.abspath(os.path.dirname(package.__file__))
+ print(migration.db_version(repo_path))
+ except ImportError:
+ print _("This extension does not provide migrations.")
+ exit(1)
+ else:
+ print(migration.db_version())
class BaseCertificateSetup(BaseApp):
diff --git a/keystone/common/config.py b/keystone/common/config.py
index 10c47a35..cd525369 100644
--- a/keystone/common/config.py
+++ b/keystone/common/config.py
@@ -210,6 +210,7 @@ def configure():
# identity
register_str('default_domain_id', group='identity', default='default')
+ register_int('max_password_length', group='identity', default=4096)
# trust
register_bool('enabled', group='trust', default=True)
diff --git a/keystone/common/ldap/core.py b/keystone/common/ldap/core.py
index 7a2dfee7..39ea78de 100644
--- a/keystone/common/ldap/core.py
+++ b/keystone/common/ldap/core.py
@@ -114,7 +114,7 @@ class BaseLdap(object):
notfound_arg = None
options_name = None
model = None
- attribute_mapping = {}
+ attribute_options_names = {}
attribute_ignore = []
tree_dn = None
@@ -129,6 +129,7 @@ class BaseLdap(object):
self.tls_cacertfile = conf.ldap.tls_cacertfile
self.tls_cacertdir = conf.ldap.tls_cacertdir
self.tls_req_cert = parse_tls_cert(conf.ldap.tls_req_cert)
+ self.attribute_mapping = {}
if self.options_name is not None:
self.suffix = conf.ldap.suffix
@@ -145,6 +146,10 @@ class BaseLdap(object):
self.object_class = (getattr(conf.ldap, objclass)
or self.DEFAULT_OBJECTCLASS)
+ for k, v in self.attribute_options_names.iteritems():
+ v = '%s_%s_attribute' % (self.options_name, v)
+ self.attribute_mapping[k] = getattr(conf.ldap, v)
+
attr_mapping_opt = ('%s_additional_attribute_mapping' %
self.options_name)
attr_mapping = (getattr(conf.ldap, attr_mapping_opt)
@@ -167,6 +172,10 @@ class BaseLdap(object):
if self.notfound_arg is None:
self.notfound_arg = self.options_name + '_id'
+
+ attribute_ignore = '%s_attribute_ignore' % self.options_name
+ self.attribute_ignore = getattr(conf.ldap, attribute_ignore)
+
self.use_dumb_member = getattr(conf.ldap, 'use_dumb_member')
self.dumb_member = (getattr(conf.ldap, 'dumb_member') or
self.DUMB_MEMBER_DN)
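With this change, subclasses only declare `attribute_options_names`;
`BaseLdap.__init__` expands each value into the config option named
`<options_name>_<value>_attribute`. A sketch of the expansion, assuming a
loaded `conf` object::

    options_name = 'user'
    attribute_options_names = {'email': 'mail', 'name': 'name'}

    attribute_mapping = {}
    for k, v in attribute_options_names.iteritems():
        # e.g. 'email' resolves through conf.ldap.user_mail_attribute
        opt = '%s_%s_attribute' % (options_name, v)
        attribute_mapping[k] = getattr(conf.ldap, opt)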
diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py
index 2d3114f2..67863588 100644
--- a/keystone/common/sql/core.py
+++ b/keystone/common/sql/core.py
@@ -45,6 +45,7 @@ ModelBase = declarative.declarative_base()
# For exporting to other modules
Column = sql.Column
+Index = sql.Index
String = sql.String
ForeignKey = sql.ForeignKey
DateTime = sql.DateTime
@@ -54,6 +55,8 @@ NotFound = sql.orm.exc.NoResultFound
Boolean = sql.Boolean
Text = sql.Text
UniqueConstraint = sql.UniqueConstraint
+relationship = sql.orm.relationship
+joinedload = sql.orm.joinedload
def initialize_decorator(init):
@@ -179,6 +182,8 @@ class DictBase(object):
setattr(self, key, value)
def __getitem__(self, key):
+ if key in self.extra:
+ return self.extra[key]
return getattr(self, key)
def get(self, key, default=None):
diff --git a/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py
new file mode 100644
index 00000000..89ca04f0
--- /dev/null
+++ b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py
@@ -0,0 +1,40 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy
+
+
+def upgrade(migrate_engine):
+    # This migration is relevant only for MySQL, because on all other
+    # engines these indexes were already successfully dropped.
+ if migrate_engine.name != 'mysql':
+ return
+ meta = sqlalchemy.MetaData(bind=migrate_engine)
+ table = sqlalchemy.Table('credential', meta, autoload=True)
+ for index in table.indexes:
+ index.drop()
+
+
+def downgrade(migrate_engine):
+ if migrate_engine.name != 'mysql':
+ return
+ meta = sqlalchemy.MetaData(bind=migrate_engine)
+ table = sqlalchemy.Table('credential', meta, autoload=True)
+ index = sqlalchemy.Index('user_id', table.c['user_id'])
+ index.create()
+ index = sqlalchemy.Index('credential_project_id_fkey',
+ table.c['project_id'])
+ index.create()
diff --git a/keystone/common/sql/migration.py b/keystone/common/sql/migration.py
index 86e0254c..3cb9cd63 100644
--- a/keystone/common/sql/migration.py
+++ b/keystone/common/sql/migration.py
@@ -39,39 +39,51 @@ except ImportError:
sys.exit('python-migrate is not installed. Exiting.')
-def db_sync(version=None):
+def migrate_repository(version, current_version, repo_path):
+ if version is None or version > current_version:
+ result = versioning_api.upgrade(CONF.sql.connection,
+ repo_path, version)
+ else:
+ result = versioning_api.downgrade(
+ CONF.sql.connection, repo_path, version)
+ return result
+
+
+def db_sync(version=None, repo_path=None):
if version is not None:
try:
version = int(version)
except ValueError:
raise Exception(_('version should be an integer'))
+ if repo_path is None:
+ repo_path = find_migrate_repo()
+ current_version = db_version(repo_path=repo_path)
+ return migrate_repository(version, current_version, repo_path)
- current_version = db_version()
- repo_path = _find_migrate_repo()
- if version is None or version > current_version:
- return versioning_api.upgrade(CONF.sql.connection, repo_path, version)
- else:
- return versioning_api.downgrade(
- CONF.sql.connection, repo_path, version)
-
-def db_version():
- repo_path = _find_migrate_repo()
+def db_version(repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
try:
return versioning_api.db_version(CONF.sql.connection, repo_path)
except versioning_exceptions.DatabaseNotControlledError:
return db_version_control(0)
-def db_version_control(version=None):
- repo_path = _find_migrate_repo()
+def db_version_control(version=None, repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
versioning_api.version_control(CONF.sql.connection, repo_path, version)
return version
-def _find_migrate_repo():
+def find_migrate_repo(package=None):
"""Get the path for the migrate repository."""
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
+ if package is None:
+ file = __file__
+ else:
+ file = package.__file__
+ path = os.path.join(os.path.abspath(os.path.dirname(file)),
'migrate_repo')
assert os.path.exists(path)
return path
diff --git a/keystone/common/utils.py b/keystone/common/utils.py
index fd2d7567..9966ee67 100644
--- a/keystone/common/utils.py
+++ b/keystone/common/utils.py
@@ -36,8 +36,6 @@ config.register_int('crypt_strength', default=40000)
LOG = logging.getLogger(__name__)
-MAX_PASSWORD_LENGTH = 4096
-
def read_cached_file(filename, cache_info, reload_func=None):
"""Read from a file if it has been modified.
@@ -68,12 +66,13 @@ class SmarterEncoder(json.JSONEncoder):
def trunc_password(password):
- """Truncate passwords to the MAX_PASSWORD_LENGTH."""
+ """Truncate passwords to the max_length."""
+ max_length = CONF.identity.max_password_length
try:
- if len(password) > MAX_PASSWORD_LENGTH:
- return password[:MAX_PASSWORD_LENGTH]
- else:
- return password
+ if len(password) > max_length:
+ LOG.warning(
+ _('Truncating user password to %s characters.') % max_length)
+ return password[:max_length]
except TypeError:
raise exception.ValidationError(attribute='string', target='password')
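Because slicing never lengthens a string, the truncated slice can be returned
unconditionally; the warning fires only when truncation actually occurs.
Expected behavior with an illustrative `max_password_length` of 8::

    trunc_password('short')   # -> 'short'; no warning logged
    trunc_password('a' * 20)  # -> 'aaaaaaaa'; logs a truncation warning
    trunc_password(12345)     # non-string: raises exception.ValidationError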
diff --git a/keystone/contrib/example/__init__.py b/keystone/contrib/example/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/contrib/example/__init__.py
diff --git a/keystone/contrib/example/migrate_repo/__init__.py b/keystone/contrib/example/migrate_repo/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/__init__.py
diff --git a/keystone/contrib/example/migrate_repo/migrate.cfg b/keystone/contrib/example/migrate_repo/migrate.cfg
new file mode 100644
index 00000000..5b1b1c0a
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/migrate.cfg
@@ -0,0 +1,25 @@
+[db_settings]
+# Used to identify which repository this database is versioned under.
+# You can use the name of your project.
+repository_id=example
+
+# The name of the database table used to track the schema version.
+# This name shouldn't already be used by your project.
+# If this is changed once a database is under version control, you'll need to
+# change the table name in each database too.
+version_table=migrate_version
+
+# When committing a change script, Migrate will attempt to generate the
+# sql for all supported databases; normally, if one of them fails - probably
+# because you don't have that database installed - it is ignored and the
+# commit continues, perhaps ending successfully.
+# Databases in this list MUST compile successfully during a commit, or the
+# entire commit will fail. List the databases your application will actually
+# be using to ensure your updates to that database work properly.
+# This must be a list; example: ['postgres','sqlite']
+required_dbs=[]
+
+# When creating new change scripts, Migrate will stamp the new script with
+# a version number. By default this is latest_version + 1. You can set this
+# to 'true' to tell Migrate to use the UTC timestamp instead.
+use_timestamp_numbering=False
diff --git a/keystone/contrib/example/migrate_repo/versions/001_example_table.py b/keystone/contrib/example/migrate_repo/versions/001_example_table.py
new file mode 100644
index 00000000..bb2203d3
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/versions/001_example_table.py
@@ -0,0 +1,45 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import sqlalchemy as sql
+
+
+def upgrade(migrate_engine):
+ # Upgrade operations go here. Don't create your own engine; bind
+ # migrate_engine to your metadata
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+
+ # catalog
+
+ service_table = sql.Table(
+ 'example',
+ meta,
+ sql.Column('id', sql.String(64), primary_key=True),
+ sql.Column('type', sql.String(255)),
+ sql.Column('extra', sql.Text()))
+ service_table.create(migrate_engine, checkfirst=True)
+
+
+def downgrade(migrate_engine):
+ # Operations to reverse the above upgrade go here.
+ meta = sql.MetaData()
+ meta.bind = migrate_engine
+
+ tables = ['example']
+ for t in tables:
+ table = sql.Table(t, meta, autoload=True)
+ table.drop(migrate_engine, checkfirst=True)
diff --git a/keystone/contrib/example/migrate_repo/versions/__init__.py b/keystone/contrib/example/migrate_repo/versions/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone/contrib/example/migrate_repo/versions/__init__.py
diff --git a/keystone/identity/backends/ldap.py b/keystone/identity/backends/ldap.py
index 91ea1e41..a359c63f 100644
--- a/keystone/identity/backends/ldap.py
+++ b/keystone/identity/backends/ldap.py
@@ -210,29 +210,20 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
DEFAULT_STRUCTURAL_CLASSES = ['person']
DEFAULT_ID_ATTR = 'cn'
DEFAULT_OBJECTCLASS = 'inetOrgPerson'
- DEFAULT_ATTRIBUTE_IGNORE = ['tenant_id', 'tenants']
NotFound = exception.UserNotFound
options_name = 'user'
- attribute_mapping = {'password': 'userPassword',
- 'email': 'mail',
- 'name': 'sn',
- 'enabled': 'enabled',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'password': 'pass',
+ 'email': 'mail',
+ 'name': 'name',
+ 'enabled': 'enabled',
+ 'domain_id': 'domain_id'}
model = models.User
def __init__(self, conf):
super(UserApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.user_name_attribute
- self.attribute_mapping['email'] = conf.ldap.user_mail_attribute
- self.attribute_mapping['password'] = conf.ldap.user_pass_attribute
- self.attribute_mapping['enabled'] = conf.ldap.user_enabled_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.user_domain_id_attribute)
self.enabled_mask = conf.ldap.user_enabled_mask
self.enabled_default = conf.ldap.user_enabled_default
- self.attribute_ignore = (getattr(conf.ldap, 'user_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def _ldap_res_to_model(self, res):
obj = super(UserApi, self)._ldap_res_to_model(res)
@@ -277,25 +268,17 @@ class GroupApi(common_ldap.BaseLdap):
DEFAULT_OBJECTCLASS = 'groupOfNames'
DEFAULT_ID_ATTR = 'cn'
DEFAULT_MEMBER_ATTRIBUTE = 'member'
- DEFAULT_ATTRIBUTE_IGNORE = []
NotFound = exception.GroupNotFound
options_name = 'group'
- attribute_mapping = {'name': 'ou',
- 'description': 'description',
- 'groupId': 'cn',
- 'domain_id': 'domain_id'}
+ attribute_options_names = {'description': 'desc',
+ 'name': 'name',
+ 'domain_id': 'domain_id'}
model = models.Group
def __init__(self, conf):
super(GroupApi, self).__init__(conf)
- self.attribute_mapping['name'] = conf.ldap.group_name_attribute
- self.attribute_mapping['description'] = conf.ldap.group_desc_attribute
- self.attribute_mapping['domain_id'] = (
- conf.ldap.group_domain_id_attribute)
self.member_attribute = (getattr(conf.ldap, 'group_member_attribute')
or self.DEFAULT_MEMBER_ATTRIBUTE)
- self.attribute_ignore = (getattr(conf.ldap, 'group_attribute_ignore')
- or self.DEFAULT_ATTRIBUTE_IGNORE)
def create(self, values):
self.affirm_unique(values)
diff --git a/keystone/openstack/common/gettextutils.py b/keystone/openstack/common/gettextutils.py
index 55ba3387..ed085370 100644
--- a/keystone/openstack/common/gettextutils.py
+++ b/keystone/openstack/common/gettextutils.py
@@ -1,6 +1,7 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 Red Hat, Inc.
+# Copyright 2013 IBM Corp.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -23,18 +24,27 @@ Usual usage in an openstack.common module:
from keystone.openstack.common.gettextutils import _
"""
+import copy
import gettext
+import logging.handlers
import os
+import re
+import UserString
+
+from babel import localedata
+import six
_localedir = os.environ.get('keystone'.upper() + '_LOCALEDIR')
_t = gettext.translation('keystone', localedir=_localedir, fallback=True)
+_AVAILABLE_LANGUAGES = []
+
def _(msg):
return _t.ugettext(msg)
-def install(domain):
+def install(domain, lazy=False):
"""Install a _() function using the given translation domain.
Given a translation domain, install a _() function using gettext's
@@ -44,7 +54,252 @@ def install(domain):
overriding the default localedir (e.g. /usr/share/locale) using
a translation-domain-specific environment variable (e.g.
NOVA_LOCALEDIR).
+
+ :param domain: the translation domain
+ :param lazy: indicates whether or not to install the lazy _() function.
+ The lazy _() introduces a way to do deferred translation
+ of messages by installing a _ that builds Message objects,
+ instead of strings, which can then be lazily translated into
+ any available locale.
+ """
+ if lazy:
+ # NOTE(mrodden): Lazy gettext functionality.
+ #
+ # The following introduces a deferred way to do translations on
+ # messages in OpenStack. We override the standard _() function
+ # and % (format string) operation to build Message objects that can
+ # later be translated when we have more information.
+ #
+ # Also included below is an example LocaleHandler that translates
+ # Messages to an associated locale, effectively allowing many logs,
+ # each with their own locale.
+
+ def _lazy_gettext(msg):
+ """Create and return a Message object.
+
+ Lazy gettext function for a given domain, it is a factory method
+ for a project/module to get a lazy gettext function for its own
+ translation domain (i.e. nova, glance, cinder, etc.)
+
+ Message encapsulates a string so that we can translate
+ it later when needed.
+ """
+ return Message(msg, domain)
+
+ import __builtin__
+ __builtin__.__dict__['_'] = _lazy_gettext
+ else:
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ gettext.install(domain,
+ localedir=os.environ.get(localedir),
+ unicode=True)
+
+
+class Message(UserString.UserString, object):
+ """Class used to encapsulate translatable messages."""
+ def __init__(self, msg, domain):
+ # _msg is the gettext msgid and should never change
+ self._msg = msg
+ self._left_extra_msg = ''
+ self._right_extra_msg = ''
+ self.params = None
+ self.locale = None
+ self.domain = domain
+
+ @property
+ def data(self):
+ # NOTE(mrodden): this should always resolve to a unicode string
+ # that best represents the state of the message currently
+
+ localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR')
+ if self.locale:
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ languages=[self.locale],
+ fallback=True)
+ else:
+ # use system locale for translations
+ lang = gettext.translation(self.domain,
+ localedir=localedir,
+ fallback=True)
+
+ full_msg = (self._left_extra_msg +
+ lang.ugettext(self._msg) +
+ self._right_extra_msg)
+
+ if self.params is not None:
+ full_msg = full_msg % self.params
+
+ return six.text_type(full_msg)
+
+ def _save_dictionary_parameter(self, dict_param):
+ full_msg = self.data
+ # look for %(blah) fields in string;
+ # ignore %% and deal with the
+ # case where % is first character on the line
+ keys = re.findall('(?:[^%]|^)%\((\w*)\)[a-z]', full_msg)
+
+ # if we don't find any %(blah) blocks but have a %s
+ if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg):
+ # apparently the full dictionary is the parameter
+ params = copy.deepcopy(dict_param)
+ else:
+ params = {}
+ for key in keys:
+ try:
+ params[key] = copy.deepcopy(dict_param[key])
+ except TypeError:
+ # cast uncopyable thing to unicode string
+ params[key] = unicode(dict_param[key])
+
+ return params
+
+ def _save_parameters(self, other):
+ # we check for None later to see if
+ # we actually have parameters to inject,
+ # so encapsulate if our parameter is actually None
+ if other is None:
+ self.params = (other, )
+ elif isinstance(other, dict):
+ self.params = self._save_dictionary_parameter(other)
+ else:
+ # fallback to casting to unicode,
+ # this will handle the problematic python code-like
+ # objects that cannot be deep-copied
+ try:
+ self.params = copy.deepcopy(other)
+ except TypeError:
+ self.params = unicode(other)
+
+ return self
+
+ # overrides to be more string-like
+ def __unicode__(self):
+ return self.data
+
+ def __str__(self):
+ return self.data.encode('utf-8')
+
+ def __getstate__(self):
+ to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg',
+ 'domain', 'params', 'locale']
+ new_dict = self.__dict__.fromkeys(to_copy)
+ for attr in to_copy:
+ new_dict[attr] = copy.deepcopy(self.__dict__[attr])
+
+ return new_dict
+
+ def __setstate__(self, state):
+ for (k, v) in state.items():
+ setattr(self, k, v)
+
+ # operator overloads
+ def __add__(self, other):
+ copied = copy.deepcopy(self)
+ copied._right_extra_msg += other.__str__()
+ return copied
+
+ def __radd__(self, other):
+ copied = copy.deepcopy(self)
+ copied._left_extra_msg += other.__str__()
+ return copied
+
+ def __mod__(self, other):
+ # do a format string to catch and raise
+ # any possible KeyErrors from missing parameters
+ self.data % other
+ copied = copy.deepcopy(self)
+ return copied._save_parameters(other)
+
+ def __mul__(self, other):
+ return self.data * other
+
+ def __rmul__(self, other):
+ return other * self.data
+
+ def __getitem__(self, key):
+ return self.data[key]
+
+ def __getslice__(self, start, end):
+ return self.data.__getslice__(start, end)
+
+ def __getattribute__(self, name):
+ # NOTE(mrodden): handle lossy operations that we can't deal with yet
+ # These override the UserString implementation, since UserString
+ # uses our __class__ attribute to try and build a new message
+ # after running the inner data string through the operation.
+ # At that point, we have lost the gettext message id and can just
+ # safely resolve to a string instead.
+ ops = ['capitalize', 'center', 'decode', 'encode',
+ 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip',
+ 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill']
+ if name in ops:
+ return getattr(self.data, name)
+ else:
+ return UserString.UserString.__getattribute__(self, name)
+
+
+def get_available_languages(domain):
+ """Lists the available languages for the given translation domain.
+
+ :param domain: the domain to get languages for
"""
- gettext.install(domain,
- localedir=os.environ.get(domain.upper() + '_LOCALEDIR'),
- unicode=True)
+ if _AVAILABLE_LANGUAGES:
+ return _AVAILABLE_LANGUAGES
+
+ localedir = '%s_LOCALEDIR' % domain.upper()
+ find = lambda x: gettext.find(domain,
+ localedir=os.environ.get(localedir),
+ languages=[x])
+
+ # NOTE(mrodden): en_US should always be available (and first in case
+ # order matters) since our in-line message strings are en_US
+ _AVAILABLE_LANGUAGES.append('en_US')
+ # NOTE(luisg): Babel <1.0 used a function called list(), which was
+ # renamed to locale_identifiers() in >=1.0, the requirements master list
+ # requires >=0.9.6, uncapped, so defensively work with both. We can remove
+    # this check when the master list updates to >=1.0, and all projects update
+ list_identifiers = (getattr(localedata, 'list', None) or
+ getattr(localedata, 'locale_identifiers'))
+ locale_identifiers = list_identifiers()
+ for i in locale_identifiers:
+ if find(i) is not None:
+ _AVAILABLE_LANGUAGES.append(i)
+ return _AVAILABLE_LANGUAGES
+
+
+def get_localized_message(message, user_locale):
+ """Gets a localized version of the given message in the given locale."""
+ if (isinstance(message, Message)):
+ if user_locale:
+ message.locale = user_locale
+ return unicode(message)
+ else:
+ return message
+
+
+class LocaleHandler(logging.Handler):
+ """Handler that can have a locale associated to translate Messages.
+
+ A quick example of how to utilize the Message class above.
+ LocaleHandler takes a locale and a target logging.Handler object
+ to forward LogRecord objects to after translating the internal Message.
+ """
+
+ def __init__(self, locale, target):
+ """Initialize a LocaleHandler
+
+ :param locale: locale to use for translating messages
+ :param target: logging.Handler object to forward
+ LogRecord objects to after translation
+ """
+ logging.Handler.__init__(self)
+ self.locale = locale
+ self.target = target
+
+ def emit(self, record):
+ if isinstance(record.msg, Message):
+ # set the locale and resolve to a string
+ record.msg.locale = self.locale
+
+ self.target.emit(record)
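A sketch of how the lazy mode is intended to be used (the locale value is
illustrative)::

    from keystone.openstack.common import gettextutils

    gettextutils.install('keystone', lazy=True)  # _() now builds Messages

    msg = _('File %s not found') % 'foo.txt'     # still untranslated
    # Translation is deferred until a locale is chosen:
    print gettextutils.get_localized_message(msg, 'es')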
diff --git a/keystone/openstack/common/importutils.py b/keystone/openstack/common/importutils.py
index 3bd277f4..7a303f93 100644
--- a/keystone/openstack/common/importutils.py
+++ b/keystone/openstack/common/importutils.py
@@ -24,7 +24,7 @@ import traceback
def import_class(import_str):
- """Returns a class from a string including module and class"""
+ """Returns a class from a string including module and class."""
mod_str, _sep, class_str = import_str.rpartition('.')
try:
__import__(mod_str)
@@ -41,8 +41,9 @@ def import_object(import_str, *args, **kwargs):
def import_object_ns(name_space, import_str, *args, **kwargs):
- """
- Import a class and return an instance of it, first by trying
+ """Tries to import object from default namespace.
+
+ Imports a class and return an instance of it, first by trying
to find the class in a default namespace, then failing back to
a full path if not found in the default namespace.
"""
diff --git a/keystone/openstack/common/jsonutils.py b/keystone/openstack/common/jsonutils.py
index d73e4c26..ecea09bb 100644
--- a/keystone/openstack/common/jsonutils.py
+++ b/keystone/openstack/common/jsonutils.py
@@ -38,11 +38,24 @@ import functools
import inspect
import itertools
import json
+import types
import xmlrpclib
+import netaddr
+import six
+
from keystone.openstack.common import timeutils
+_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod,
+ inspect.isfunction, inspect.isgeneratorfunction,
+ inspect.isgenerator, inspect.istraceback, inspect.isframe,
+ inspect.iscode, inspect.isbuiltin, inspect.isroutine,
+ inspect.isabstract]
+
+_simple_types = (types.NoneType, int, basestring, bool, float, long)
+
+
def to_primitive(value, convert_instances=False, convert_datetime=True,
level=0, max_depth=3):
"""Convert a complex object into primitives.
@@ -58,19 +71,32 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
Therefore, convert_instances=True is lossy ... be aware.
"""
- nasty = [inspect.ismodule, inspect.isclass, inspect.ismethod,
- inspect.isfunction, inspect.isgeneratorfunction,
- inspect.isgenerator, inspect.istraceback, inspect.isframe,
- inspect.iscode, inspect.isbuiltin, inspect.isroutine,
- inspect.isabstract]
- for test in nasty:
- if test(value):
- return unicode(value)
-
- # value of itertools.count doesn't get caught by inspects
- # above and results in infinite loop when list(value) is called.
+ # handle obvious types first - order of basic types determined by running
+ # full tests on nova project, resulting in the following counts:
+ # 572754 <type 'NoneType'>
+ # 460353 <type 'int'>
+ # 379632 <type 'unicode'>
+ # 274610 <type 'str'>
+ # 199918 <type 'dict'>
+ # 114200 <type 'datetime.datetime'>
+ # 51817 <type 'bool'>
+ # 26164 <type 'list'>
+ # 6491 <type 'float'>
+ # 283 <type 'tuple'>
+ # 19 <type 'long'>
+ if isinstance(value, _simple_types):
+ return value
+
+ if isinstance(value, datetime.datetime):
+ if convert_datetime:
+ return timeutils.strtime(value)
+ else:
+ return value
+
+ # value of itertools.count doesn't get caught by nasty_type_tests
+ # and results in infinite loop when list(value) is called.
if type(value) == itertools.count:
- return unicode(value)
+ return six.text_type(value)
# FIXME(vish): Workaround for LP bug 852095. Without this workaround,
# tests that raise an exception in a mocked method that
@@ -91,17 +117,18 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
convert_datetime=convert_datetime,
level=level,
max_depth=max_depth)
+ if isinstance(value, dict):
+ return dict((k, recursive(v)) for k, v in value.iteritems())
+ elif isinstance(value, (list, tuple)):
+ return [recursive(lv) for lv in value]
+
# It's not clear why xmlrpclib created their own DateTime type, but
# for our purposes, make it a datetime type which is explicitly
# handled
if isinstance(value, xmlrpclib.DateTime):
value = datetime.datetime(*tuple(value.timetuple())[:6])
- if isinstance(value, (list, tuple)):
- return [recursive(v) for v in value]
- elif isinstance(value, dict):
- return dict((k, recursive(v)) for k, v in value.iteritems())
- elif convert_datetime and isinstance(value, datetime.datetime):
+ if convert_datetime and isinstance(value, datetime.datetime):
return timeutils.strtime(value)
elif hasattr(value, 'iteritems'):
return recursive(dict(value.iteritems()), level=level + 1)
@@ -111,12 +138,16 @@ def to_primitive(value, convert_instances=False, convert_datetime=True,
# Likely an instance of something. Watch for cycles.
# Ignore class member vars.
return recursive(value.__dict__, level=level + 1)
+ elif isinstance(value, netaddr.IPAddress):
+ return six.text_type(value)
else:
+ if any(test(value) for test in _nasty_type_tests):
+ return six.text_type(value)
return value
except TypeError:
# Class objects are tricky since they may define something like
# __iter__ defined but it isn't callable as list().
- return unicode(value)
+ return six.text_type(value)
def dumps(value, default=to_primitive, **kwargs):
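The reordering runs the cheap `isinstance` checks for the most common types
first and leaves the expensive `_nasty_type_tests` probes for the fallback
path; callers are unaffected::

    import datetime
    from keystone.openstack.common import jsonutils

    data = {'when': datetime.datetime(2013, 1, 1), 'ids': (1, 2)}
    jsonutils.dumps(data)  # datetimes serialized via timeutils.strtime()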
diff --git a/keystone/openstack/common/local.py b/keystone/openstack/common/local.py
new file mode 100644
index 00000000..e82f17d0
--- /dev/null
+++ b/keystone/openstack/common/local.py
@@ -0,0 +1,47 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Local storage of variables using weak references"""
+
+import threading
+import weakref
+
+
+class WeakLocal(threading.local):
+ def __getattribute__(self, attr):
+ rval = super(WeakLocal, self).__getattribute__(attr)
+ if rval:
+ # NOTE(mikal): this bit is confusing. What is stored is a weak
+ # reference, not the value itself. We therefore need to lookup
+ # the weak reference and return the inner value here.
+ rval = rval()
+ return rval
+
+ def __setattr__(self, attr, value):
+ value = weakref.ref(value)
+ return super(WeakLocal, self).__setattr__(attr, value)
+
+
+# NOTE(mikal): the name "store" should be deprecated in the future
+store = WeakLocal()
+
+# A "weak" store uses weak references and allows an object to fall out of scope
+# when it falls out of scope in the code that uses the thread local storage. A
+# "strong" store will hold a reference to the object so that it never falls out
+# of scope.
+weak_store = WeakLocal()
+strong_store = threading.local()
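A sketch of the difference between the two stores; CPython's reference
counting collects the weakly held object as soon as its last strong
reference goes away::

    from keystone.openstack.common import local


    class Context(object):
        pass

    local.strong_store.context = Context()  # held until reassigned
    local.weak_store.context = Context()    # only a weak reference is kept
    # No strong reference remains, so the lookup returns None:
    print local.weak_store.context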
diff --git a/keystone/openstack/common/log.py b/keystone/openstack/common/log.py
new file mode 100644
index 00000000..5a43c326
--- /dev/null
+++ b/keystone/openstack/common/log.py
@@ -0,0 +1,559 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Openstack logging handler.
+
+This module adds to logging functionality by adding the option to specify
+a context object when calling the various log methods. If the context object
+is not specified, default formatting is used. Additionally, an instance uuid
+may be passed as part of the log message, which is intended to make it easier
+for admins to find messages related to a specific instance.
+
+It also allows setting of formatting information through conf.
+
+"""
+
+import inspect
+import itertools
+import logging
+import logging.config
+import logging.handlers
+import os
+import sys
+import traceback
+
+from oslo.config import cfg
+from six import moves
+
+from keystone.openstack.common.gettextutils import _ # noqa
+from keystone.openstack.common import importutils
+from keystone.openstack.common import jsonutils
+from keystone.openstack.common import local
+
+
+_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
+common_cli_opts = [
+ cfg.BoolOpt('debug',
+ short='d',
+ default=False,
+ help='Print debugging output (set logging level to '
+ 'DEBUG instead of default WARNING level).'),
+ cfg.BoolOpt('verbose',
+ short='v',
+ default=False,
+ help='Print more verbose output (set logging level to '
+ 'INFO instead of default WARNING level).'),
+]
+
+logging_cli_opts = [
+ cfg.StrOpt('log-config',
+ metavar='PATH',
+ help='If this option is specified, the logging configuration '
+ 'file specified is used and overrides any other logging '
+ 'options specified. Please see the Python logging module '
+ 'documentation for details on logging configuration '
+ 'files.'),
+ cfg.StrOpt('log-format',
+ default=None,
+ metavar='FORMAT',
+ help='DEPRECATED. '
+ 'A logging.Formatter log message format string which may '
+ 'use any of the available logging.LogRecord attributes. '
+ 'This option is deprecated. Please use '
+ 'logging_context_format_string and '
+ 'logging_default_format_string instead.'),
+ cfg.StrOpt('log-date-format',
+ default=_DEFAULT_LOG_DATE_FORMAT,
+ metavar='DATE_FORMAT',
+ help='Format string for %%(asctime)s in log records. '
+ 'Default: %(default)s'),
+ cfg.StrOpt('log-file',
+ metavar='PATH',
+ deprecated_name='logfile',
+ help='(Optional) Name of log file to output to. '
+ 'If no default is set, logging will go to stdout.'),
+ cfg.StrOpt('log-dir',
+ deprecated_name='logdir',
+ help='(Optional) The base directory used for relative '
+ '--log-file paths'),
+ cfg.BoolOpt('use-syslog',
+ default=False,
+ help='Use syslog for logging.'),
+ cfg.StrOpt('syslog-log-facility',
+ default='LOG_USER',
+ help='syslog facility to receive log lines')
+]
+
+generic_log_opts = [
+ cfg.BoolOpt('use_stderr',
+ default=True,
+ help='Log output to standard error')
+]
+
+log_opts = [
+ cfg.StrOpt('logging_context_format_string',
+ default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
+ '%(name)s [%(request_id)s %(user)s %(tenant)s] '
+ '%(instance)s%(message)s',
+ help='format string to use for log messages with context'),
+ cfg.StrOpt('logging_default_format_string',
+ default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s '
+ '%(name)s [-] %(instance)s%(message)s',
+ help='format string to use for log messages without context'),
+ cfg.StrOpt('logging_debug_format_suffix',
+ default='%(funcName)s %(pathname)s:%(lineno)d',
+ help='data to append to log format when level is DEBUG'),
+ cfg.StrOpt('logging_exception_prefix',
+ default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s '
+ '%(instance)s',
+ help='prefix each line of exception output with this format'),
+ cfg.ListOpt('default_log_levels',
+ default=[
+ 'amqplib=WARN',
+ 'sqlalchemy=WARN',
+ 'boto=WARN',
+ 'suds=INFO',
+ 'keystone=INFO',
+ 'eventlet.wsgi.server=WARN'
+ ],
+ help='list of logger=LEVEL pairs'),
+ cfg.BoolOpt('publish_errors',
+ default=False,
+ help='publish error events'),
+ cfg.BoolOpt('fatal_deprecations',
+ default=False,
+ help='make deprecations fatal'),
+
+ # NOTE(mikal): there are two options here because sometimes we are handed
+ # a full instance (and could include more information), and other times we
+ # are just handed a UUID for the instance.
+ cfg.StrOpt('instance_format',
+ default='[instance: %(uuid)s] ',
+ help='If an instance is passed with the log message, format '
+ 'it like this'),
+ cfg.StrOpt('instance_uuid_format',
+ default='[instance: %(uuid)s] ',
+ help='If an instance UUID is passed with the log message, '
+ 'format it like this'),
+]
+
+CONF = cfg.CONF
+CONF.register_cli_opts(common_cli_opts)
+CONF.register_cli_opts(logging_cli_opts)
+CONF.register_opts(generic_log_opts)
+CONF.register_opts(log_opts)
+
+# our new audit level
+# NOTE(jkoelker) Since we synthesized an audit level, make the logging
+# module aware of it so it acts like other levels.
+logging.AUDIT = logging.INFO + 1
+logging.addLevelName(logging.AUDIT, 'AUDIT')
+
+
+try:
+ NullHandler = logging.NullHandler
+except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7
+ class NullHandler(logging.Handler):
+ def handle(self, record):
+ pass
+
+ def emit(self, record):
+ pass
+
+ def createLock(self):
+ self.lock = None
+
+
+def _dictify_context(context):
+ if context is None:
+ return None
+ if not isinstance(context, dict) and getattr(context, 'to_dict', None):
+ context = context.to_dict()
+ return context
+
+
+def _get_binary_name():
+ return os.path.basename(inspect.stack()[-1][1])
+
+
+def _get_log_file_path(binary=None):
+ logfile = CONF.log_file
+ logdir = CONF.log_dir
+
+ if logfile and not logdir:
+ return logfile
+
+ if logfile and logdir:
+ return os.path.join(logdir, logfile)
+
+ if logdir:
+ binary = binary or _get_binary_name()
+ return '%s.log' % (os.path.join(logdir, binary),)
+
+
+class BaseLoggerAdapter(logging.LoggerAdapter):
+
+ def audit(self, msg, *args, **kwargs):
+ self.log(logging.AUDIT, msg, *args, **kwargs)
+
+
+class LazyAdapter(BaseLoggerAdapter):
+ def __init__(self, name='unknown', version='unknown'):
+ self._logger = None
+ self.extra = {}
+ self.name = name
+ self.version = version
+
+ @property
+ def logger(self):
+ if not self._logger:
+ self._logger = getLogger(self.name, self.version)
+ return self._logger
+
+
+class ContextAdapter(BaseLoggerAdapter):
+ warn = logging.LoggerAdapter.warning
+
+ def __init__(self, logger, project_name, version_string):
+ self.logger = logger
+ self.project = project_name
+ self.version = version_string
+
+ @property
+ def handlers(self):
+ return self.logger.handlers
+
+ def deprecated(self, msg, *args, **kwargs):
+ stdmsg = _("Deprecated: %s") % msg
+ if CONF.fatal_deprecations:
+ self.critical(stdmsg, *args, **kwargs)
+ raise DeprecatedConfig(msg=stdmsg)
+ else:
+ self.warn(stdmsg, *args, **kwargs)
+
+ def process(self, msg, kwargs):
+ if 'extra' not in kwargs:
+ kwargs['extra'] = {}
+ extra = kwargs['extra']
+
+ context = kwargs.pop('context', None)
+ if not context:
+ context = getattr(local.store, 'context', None)
+ if context:
+ extra.update(_dictify_context(context))
+
+ instance = kwargs.pop('instance', None)
+ instance_extra = ''
+ if instance:
+ instance_extra = CONF.instance_format % instance
+ else:
+ instance_uuid = kwargs.pop('instance_uuid', None)
+ if instance_uuid:
+ instance_extra = (CONF.instance_uuid_format
+ % {'uuid': instance_uuid})
+ extra.update({'instance': instance_extra})
+
+ extra.update({"project": self.project})
+ extra.update({"version": self.version})
+ extra['extra'] = extra.copy()
+ return msg, kwargs
+
+
+class JSONFormatter(logging.Formatter):
+ def __init__(self, fmt=None, datefmt=None):
+ # NOTE(jkoelker) we ignore the fmt argument, but its still there
+ # since logging.config.fileConfig passes it.
+ self.datefmt = datefmt
+
+ def formatException(self, ei, strip_newlines=True):
+ lines = traceback.format_exception(*ei)
+ if strip_newlines:
+ lines = [itertools.ifilter(
+ lambda x: x,
+ line.rstrip().splitlines()) for line in lines]
+ lines = list(itertools.chain(*lines))
+ return lines
+
+ def format(self, record):
+ message = {'message': record.getMessage(),
+ 'asctime': self.formatTime(record, self.datefmt),
+ 'name': record.name,
+ 'msg': record.msg,
+ 'args': record.args,
+ 'levelname': record.levelname,
+ 'levelno': record.levelno,
+ 'pathname': record.pathname,
+ 'filename': record.filename,
+ 'module': record.module,
+ 'lineno': record.lineno,
+ 'funcname': record.funcName,
+ 'created': record.created,
+ 'msecs': record.msecs,
+ 'relative_created': record.relativeCreated,
+ 'thread': record.thread,
+ 'thread_name': record.threadName,
+ 'process_name': record.processName,
+ 'process': record.process,
+ 'traceback': None}
+
+ if hasattr(record, 'extra'):
+ message['extra'] = record.extra
+
+ if record.exc_info:
+ message['traceback'] = self.formatException(record.exc_info)
+
+ return jsonutils.dumps(message)
+
+
+def _create_logging_excepthook(product_name):
+ def logging_excepthook(type, value, tb):
+ extra = {}
+ if CONF.verbose:
+ extra['exc_info'] = (type, value, tb)
+ getLogger(product_name).critical(str(value), **extra)
+ return logging_excepthook
+
+
+class LogConfigError(Exception):
+
+ message = _('Error loading logging config %(log_config)s: %(err_msg)s')
+
+ def __init__(self, log_config, err_msg):
+ self.log_config = log_config
+ self.err_msg = err_msg
+
+ def __str__(self):
+ return self.message % dict(log_config=self.log_config,
+ err_msg=self.err_msg)
+
+
+def _load_log_config(log_config):
+ try:
+ logging.config.fileConfig(log_config)
+ except moves.configparser.Error as exc:
+ raise LogConfigError(log_config, str(exc))
+
+
+def setup(product_name):
+ """Setup logging."""
+ if CONF.log_config:
+ _load_log_config(CONF.log_config)
+ else:
+ _setup_logging_from_conf()
+ sys.excepthook = _create_logging_excepthook(product_name)
+
+
+def set_defaults(logging_context_format_string):
+ cfg.set_defaults(log_opts,
+ logging_context_format_string=
+ logging_context_format_string)
+
+
+def _find_facility_from_conf():
+ facility_names = logging.handlers.SysLogHandler.facility_names
+ facility = getattr(logging.handlers.SysLogHandler,
+ CONF.syslog_log_facility,
+ None)
+
+ if facility is None and CONF.syslog_log_facility in facility_names:
+ facility = facility_names.get(CONF.syslog_log_facility)
+
+ if facility is None:
+ valid_facilities = facility_names.keys()
+ consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON',
+ 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS',
+ 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP',
+ 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3',
+ 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7']
+ valid_facilities.extend(consts)
+ raise TypeError(_('syslog facility must be one of: %s') %
+ ', '.join("'%s'" % fac
+ for fac in valid_facilities))
+
+ return facility
+
+
+def _setup_logging_from_conf():
+ log_root = getLogger(None).logger
+ for handler in log_root.handlers:
+ log_root.removeHandler(handler)
+
+ if CONF.use_syslog:
+ facility = _find_facility_from_conf()
+ syslog = logging.handlers.SysLogHandler(address='/dev/log',
+ facility=facility)
+ log_root.addHandler(syslog)
+
+ logpath = _get_log_file_path()
+ if logpath:
+ filelog = logging.handlers.WatchedFileHandler(logpath)
+ log_root.addHandler(filelog)
+
+ if CONF.use_stderr:
+ streamlog = ColorHandler()
+ log_root.addHandler(streamlog)
+
+ elif not CONF.log_file:
+ # pass sys.stdout as a positional argument
+ # python2.6 calls the argument strm, in 2.7 it's stream
+ streamlog = logging.StreamHandler(sys.stdout)
+ log_root.addHandler(streamlog)
+
+ if CONF.publish_errors:
+ handler = importutils.import_object(
+ "keystone.openstack.common.log_handler.PublishErrorsHandler",
+ logging.ERROR)
+ log_root.addHandler(handler)
+
+ datefmt = CONF.log_date_format
+ for handler in log_root.handlers:
+ # NOTE(alaski): CONF.log_format overrides everything currently. This
+ # should be deprecated in favor of context aware formatting.
+ if CONF.log_format:
+ handler.setFormatter(logging.Formatter(fmt=CONF.log_format,
+ datefmt=datefmt))
+ log_root.info('Deprecated: log_format is now deprecated and will '
+ 'be removed in the next release')
+ else:
+ handler.setFormatter(ContextFormatter(datefmt=datefmt))
+
+ if CONF.debug:
+ log_root.setLevel(logging.DEBUG)
+ elif CONF.verbose:
+ log_root.setLevel(logging.INFO)
+ else:
+ log_root.setLevel(logging.WARNING)
+
+ for pair in CONF.default_log_levels:
+ mod, _sep, level_name = pair.partition('=')
+ level = logging.getLevelName(level_name)
+ logger = logging.getLogger(mod)
+ logger.setLevel(level)
+
+_loggers = {}
+
+
+def getLogger(name='unknown', version='unknown'):
+ if name not in _loggers:
+ _loggers[name] = ContextAdapter(logging.getLogger(name),
+ name,
+ version)
+ return _loggers[name]
+
+
+def getLazyLogger(name='unknown', version='unknown'):
+    """Returns a lazy logger.
+
+ Creates a pass-through logger that does not create the real logger
+ until it is really needed and delegates all calls to the real logger
+ once it is created.
+ """
+ return LazyAdapter(name, version)
+
+
+class WritableLogger(object):
+ """A thin wrapper that responds to `write` and logs."""
+
+ def __init__(self, logger, level=logging.INFO):
+ self.logger = logger
+ self.level = level
+
+ def write(self, msg):
+ self.logger.log(self.level, msg)
+
+
+class ContextFormatter(logging.Formatter):
+ """A context.RequestContext aware formatter configured through flags.
+
+ The flags used to set format strings are: logging_context_format_string
+ and logging_default_format_string. You can also specify
+ logging_debug_format_suffix to append extra formatting if the log level is
+ debug.
+
+ For information about what variables are available for the formatter see:
+ http://docs.python.org/library/logging.html#formatter
+
+ """
+
+ def format(self, record):
+ """Uses contextstring if request_id is set, otherwise default."""
+        # NOTE(sdague): default the fancier formatting params
+ # to an empty string so we don't throw an exception if
+ # they get used
+ for key in ('instance', 'color'):
+ if key not in record.__dict__:
+ record.__dict__[key] = ''
+
+ if record.__dict__.get('request_id', None):
+ self._fmt = CONF.logging_context_format_string
+ else:
+ self._fmt = CONF.logging_default_format_string
+
+ if (record.levelno == logging.DEBUG and
+ CONF.logging_debug_format_suffix):
+ self._fmt += " " + CONF.logging_debug_format_suffix
+
+        # Cache this on the record, Logger will respect our formatted copy
+ if record.exc_info:
+ record.exc_text = self.formatException(record.exc_info, record)
+ return logging.Formatter.format(self, record)
+
+ def formatException(self, exc_info, record=None):
+ """Format exception output with CONF.logging_exception_prefix."""
+ if not record:
+ return logging.Formatter.formatException(self, exc_info)
+
+ stringbuffer = moves.StringIO()
+ traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
+ None, stringbuffer)
+ lines = stringbuffer.getvalue().split('\n')
+ stringbuffer.close()
+
+ if CONF.logging_exception_prefix.find('%(asctime)') != -1:
+ record.asctime = self.formatTime(record, self.datefmt)
+
+ formatted_lines = []
+ for line in lines:
+ pl = CONF.logging_exception_prefix % record.__dict__
+ fl = '%s%s' % (pl, line)
+ formatted_lines.append(fl)
+ return '\n'.join(formatted_lines)
+
+
+class ColorHandler(logging.StreamHandler):
+ LEVEL_COLORS = {
+ logging.DEBUG: '\033[00;32m', # GREEN
+ logging.INFO: '\033[00;36m', # CYAN
+ logging.AUDIT: '\033[01;36m', # BOLD CYAN
+ logging.WARN: '\033[01;33m', # BOLD YELLOW
+ logging.ERROR: '\033[01;31m', # BOLD RED
+ logging.CRITICAL: '\033[01;31m', # BOLD RED
+ }
+
+ def format(self, record):
+ record.color = self.LEVEL_COLORS[record.levelno]
+ return logging.StreamHandler.format(self, record)
+
+
+class DeprecatedConfig(Exception):
+ message = _("Fatal call to deprecated config: %(msg)s")
+
+ def __init__(self, msg):
+        super(DeprecatedConfig, self).__init__(self.message % dict(msg=msg))
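
The logging module above is driven entirely by configuration: `setup()` either
loads the file named by CONF.log_config or builds syslog/file/stream handlers
from the individual options, and `getLogger()` hands back a cached
ContextAdapter per name. A minimal sketch of how a consuming service might
wire this up (assuming configuration has already been parsed)::

    from keystone.openstack.common import log

    # setup() reads CONF.log_config, CONF.use_syslog, CONF.debug and
    # friends, so it must run after the config files are parsed.
    log.setup('keystone')

    # Adapters are cached per name; repeated calls return the same object.
    LOG = log.getLogger(__name__)
    LOG.info('service starting')
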
diff --git a/keystone/openstack/common/timeutils.py b/keystone/openstack/common/timeutils.py
index 60943659..aa9f7080 100644
--- a/keystone/openstack/common/timeutils.py
+++ b/keystone/openstack/common/timeutils.py
@@ -23,6 +23,7 @@ import calendar
import datetime
import iso8601
+import six
# ISO 8601 extended time format with microseconds
@@ -32,7 +33,7 @@ PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND
def isotime(at=None, subsecond=False):
- """Stringify time in ISO 8601 format"""
+ """Stringify time in ISO 8601 format."""
if not at:
at = utcnow()
st = at.strftime(_ISO8601_TIME_FORMAT
@@ -44,13 +45,13 @@ def isotime(at=None, subsecond=False):
def parse_isotime(timestr):
- """Parse time from ISO 8601 format"""
+ """Parse time from ISO 8601 format."""
try:
return iso8601.parse_date(timestr)
except iso8601.ParseError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
except TypeError as e:
- raise ValueError(e.message)
+ raise ValueError(unicode(e))
def strtime(at=None, fmt=PERFECT_TIME_FORMAT):
@@ -66,7 +67,7 @@ def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT):
def normalize_time(timestamp):
- """Normalize time in arbitrary timezone to UTC naive object"""
+ """Normalize time in arbitrary timezone to UTC naive object."""
offset = timestamp.utcoffset()
if offset is None:
return timestamp
@@ -75,14 +76,14 @@ def normalize_time(timestamp):
def is_older_than(before, seconds):
"""Return True if before is older than seconds."""
- if isinstance(before, basestring):
+ if isinstance(before, six.string_types):
before = parse_strtime(before).replace(tzinfo=None)
return utcnow() - before > datetime.timedelta(seconds=seconds)
def is_newer_than(after, seconds):
"""Return True if after is newer than seconds."""
- if isinstance(after, basestring):
+ if isinstance(after, six.string_types):
after = parse_strtime(after).replace(tzinfo=None)
return after - utcnow() > datetime.timedelta(seconds=seconds)
@@ -103,7 +104,7 @@ def utcnow():
def iso8601_from_timestamp(timestamp):
- """Returns a iso8601 formated date from timestamp"""
+    """Returns an ISO 8601 formatted date from timestamp."""
return isotime(datetime.datetime.utcfromtimestamp(timestamp))
@@ -111,9 +112,9 @@ utcnow.override_time = None
def set_time_override(override_time=datetime.datetime.utcnow()):
- """
- Override utils.utcnow to return a constant time or a list thereof,
- one at a time.
+ """Overrides utils.utcnow.
+
+ Make it return a constant time or a list thereof, one at a time.
"""
utcnow.override_time = override_time
@@ -141,7 +142,8 @@ def clear_time_override():
def marshall_now(now=None):
"""Make an rpc-safe datetime with microseconds.
- Note: tzinfo is stripped, but not required for relative times."""
+ Note: tzinfo is stripped, but not required for relative times.
+ """
if not now:
now = utcnow()
return dict(day=now.day, month=now.month, year=now.year, hour=now.hour,
@@ -161,7 +163,8 @@ def unmarshall_time(tyme):
def delta_seconds(before, after):
- """
+ """Return the difference between two timing objects.
+
Compute the difference in seconds between two date, time, or
datetime objects (as a float, to microsecond resolution).
"""
@@ -174,8 +177,7 @@ def delta_seconds(before, after):
def is_soon(dt, window):
- """
- Determines if time is going to happen in the next window seconds.
+ """Determines if time is going to happen in the next window seconds.
:params dt: the time
:params window: minimum seconds to remain to consider the time not soon
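
Two behavioural points in the timeutils hunks: the `basestring` checks become
`six.string_types` so they hold on both Python 2 and 3, and parse failures are
re-raised as ValueError built from `unicode(e)` rather than the Python 2-only
`e.message` attribute. A quick illustration of the error path::

    from keystone.openstack.common import timeutils

    try:
        timeutils.parse_isotime('not-a-timestamp')
    except ValueError as exc:
        # iso8601.ParseError and TypeError are both translated into
        # ValueError, so callers handle a single standard exception.
        print(exc)
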
diff --git a/keystone/token/backends/sql.py b/keystone/token/backends/sql.py
index 0e8a916d..82eab651 100644
--- a/keystone/token/backends/sql.py
+++ b/keystone/token/backends/sql.py
@@ -17,7 +17,6 @@
import copy
import datetime
-
from keystone.common import sql
from keystone import exception
from keystone.openstack.common import timeutils
@@ -30,9 +29,13 @@ class TokenModel(sql.ModelBase, sql.DictBase):
id = sql.Column(sql.String(64), primary_key=True)
expires = sql.Column(sql.DateTime(), default=None)
extra = sql.Column(sql.JsonBlob())
- valid = sql.Column(sql.Boolean(), default=True)
+ valid = sql.Column(sql.Boolean(), default=True, nullable=False)
user_id = sql.Column(sql.String(64))
- trust_id = sql.Column(sql.String(64), nullable=True)
+ trust_id = sql.Column(sql.String(64))
+ __table_args__ = (
+ sql.Index('ix_token_expires', 'expires'),
+ sql.Index('ix_token_valid', 'valid')
+ )
class Token(sql.Base, token.Driver):
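
The new `__table_args__` indexes cover the columns the token backend filters
on when listing and flushing tokens, and `valid` becomes NOT NULL so the index
never has to represent a third state. A sketch of the kind of query
`ix_token_expires` serves (the session handling here is illustrative, not the
backend's actual code)::

    import datetime

    def purge_expired_tokens(session):
        now = datetime.datetime.utcnow()
        # With ix_token_expires in place this filter can walk the index
        # instead of scanning the whole token table.
        query = session.query(TokenModel).filter(TokenModel.expires < now)
        query.delete(synchronize_session=False)
        session.commit()
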
diff --git a/keystone/token/controllers.py b/keystone/token/controllers.py
index 9ebc29fe..91514493 100644
--- a/keystone/token/controllers.py
+++ b/keystone/token/controllers.py
@@ -4,7 +4,6 @@ from keystone.common import cms
from keystone.common import controller
from keystone.common import dependency
from keystone.common import logging
-from keystone.common import utils
from keystone.common import wsgi
from keystone import config
from keystone import exception
@@ -215,10 +214,9 @@ class Auth(controller.V2Controller):
attribute='password', target='passwordCredentials')
password = auth['passwordCredentials']['password']
- max_pw_size = utils.MAX_PASSWORD_LENGTH
- if password and len(password) > max_pw_size:
- raise exception.ValidationSizeError(attribute='password',
- size=max_pw_size)
+ if password and len(password) > CONF.identity.max_password_length:
+ raise exception.ValidationSizeError(
+ attribute='password', size=CONF.identity.max_password_length)
if ("userId" not in auth['passwordCredentials'] and
"username" not in auth['passwordCredentials']):
diff --git a/keystone/trust/backends/sql.py b/keystone/trust/backends/sql.py
index daa8e3f7..9e92ad71 100644
--- a/keystone/trust/backends/sql.py
+++ b/keystone/trust/backends/sql.py
@@ -26,11 +26,11 @@ class TrustModel(sql.ModelBase, sql.DictBase):
'project_id', 'impersonation', 'expires_at']
id = sql.Column(sql.String(64), primary_key=True)
#user id Of owner
- trustor_user_id = sql.Column(sql.String(64), unique=False, nullable=False,)
+    trustor_user_id = sql.Column(sql.String(64), nullable=False)
#user_id of user allowed to consume this preauth
- trustee_user_id = sql.Column(sql.String(64), unique=False, nullable=False)
- project_id = sql.Column(sql.String(64), unique=False, nullable=True)
- impersonation = sql.Column(sql.Boolean)
+ trustee_user_id = sql.Column(sql.String(64), nullable=False)
+ project_id = sql.Column(sql.String(64))
+ impersonation = sql.Column(sql.Boolean, nullable=False)
deleted_at = sql.Column(sql.DateTime)
expires_at = sql.Column(sql.DateTime)
extra = sql.Column(sql.JsonBlob())
diff --git a/requirements.txt b/requirements.txt
index e54bb6a0..f7161d2c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,7 @@ pam>=0.1.4
WebOb>=1.0.8
eventlet
greenlet
+netaddr
PasteDeploy
paste
routes
diff --git a/tests/test_auth.py b/tests/test_auth.py
index db5314be..e8e6c7a9 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -179,7 +179,8 @@ class AuthBadRequests(AuthTest):
def test_authenticate_password_too_large(self):
"""Verify sending large 'password' raises the right exception."""
- body_dict = _build_user_auth(username='FOO', password='0' * 8193)
+ length = CONF.identity.max_password_length + 1
+ body_dict = _build_user_auth(username='FOO', password='0' * length)
self.assertRaises(exception.ValidationSizeError,
self.controller.authenticate,
{}, body_dict)
diff --git a/tests/test_backend.py b/tests/test_backend.py
index 7e4d820e..75a94773 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -1453,6 +1453,18 @@ class IdentityTests(object):
tenants = self.identity_api.get_projects_for_user(self.user_foo['id'])
self.assertIn(self.tenant_baz['id'], tenants)
+ def test_add_user_to_project_missing_default_role(self):
+ self.assignment_api.delete_role(CONF.member_role_id)
+ self.assertRaises(exception.RoleNotFound,
+ self.assignment_api.get_role,
+ CONF.member_role_id)
+ self.identity_api.add_user_to_project(self.tenant_baz['id'],
+ self.user_foo['id'])
+ tenants = self.identity_api.get_projects_for_user(self.user_foo['id'])
+ self.assertIn(self.tenant_baz['id'], tenants)
+ default_role = self.assignment_api.get_role(CONF.member_role_id)
+ self.assertIsNotNone(default_role)
+
def test_add_user_to_project_404(self):
self.assertRaises(exception.ProjectNotFound,
self.identity_api.add_user_to_project,
diff --git a/tests/test_sql_migrate_extensions.py b/tests/test_sql_migrate_extensions.py
new file mode 100644
index 00000000..4a529559
--- /dev/null
+++ b/tests/test_sql_migrate_extensions.py
@@ -0,0 +1,47 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2012 OpenStack LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+To run these tests against a live database:
+1. Modify the file `tests/backend_sql.conf` to use the connection for your
+   live database.
+2. Set up a blank, live database.
+3. Run the tests using
+    ./run_tests.sh -N test_sql_migrate_extensions
+    WARNING::
+        Your database will be wiped.
+        Do not do this against a database with valuable data as
+        all data will be lost.
+"""
+
+from keystone.contrib import example
+
+import test_sql_upgrade
+
+
+class SqlUpgradeExampleExtension(test_sql_upgrade.SqlMigrateBase):
+ def repo_package(self):
+ return example
+
+ def test_upgrade(self):
+ self.assertTableDoesNotExist('example')
+ self.upgrade(1, repository=self.repo_path)
+ self.assertTableColumns('example', ['id', 'type', 'extra'])
+
+ def test_downgrade(self):
+ self.upgrade(1, repository=self.repo_path)
+ self.assertTableColumns('example', ['id', 'type', 'extra'])
+ self.downgrade(0, repository=self.repo_path)
+ self.assertTableDoesNotExist('example')
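
The test drives the example extension's repository from version 0 to 1 and
back, so version 001 must create an `example` table with exactly the `id`,
`type` and `extra` columns asserted on. A sketch of what such a
SQLAlchemy-migrate version script typically looks like (the column types are
assumptions; only the table and column names are fixed by the test)::

    import sqlalchemy as sql

    def upgrade(migrate_engine):
        meta = sql.MetaData(bind=migrate_engine)
        example = sql.Table(
            'example', meta,
            sql.Column('id', sql.String(64), primary_key=True),
            sql.Column('type', sql.String(255)),
            sql.Column('extra', sql.Text()))
        # checkfirst makes the upgrade idempotent if the table exists.
        example.create(migrate_engine, checkfirst=True)

    def downgrade(migrate_engine):
        meta = sql.MetaData(bind=migrate_engine)
        sql.Table('example', meta, autoload=True).drop()
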
diff --git a/tests/test_sql_upgrade.py b/tests/test_sql_upgrade.py
index cf82b814..9540c4cd 100644
--- a/tests/test_sql_upgrade.py
+++ b/tests/test_sql_upgrade.py
@@ -45,8 +45,7 @@ CONF = config.CONF
DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id
-class SqlUpgradeTests(test.TestCase):
-
+class SqlMigrateBase(test.TestCase):
def initialize_sql(self):
self.metadata = sqlalchemy.MetaData()
self.metadata.bind = self.engine
@@ -55,12 +54,15 @@ class SqlUpgradeTests(test.TestCase):
test.testsdir('test_overrides.conf'),
test.testsdir('backend_sql.conf')]
- #override this to sepcify the complete list of configuration files
+ #override this to specify the complete list of configuration files
def config_files(self):
return self._config_file_list
+ def repo_package(self):
+ return None
+
def setUp(self):
- super(SqlUpgradeTests, self).setUp()
+ super(SqlMigrateBase, self).setUp()
self.config(self.config_files())
self.base = sql.Base()
@@ -71,7 +73,7 @@ class SqlUpgradeTests(test.TestCase):
autocommit=False)
self.initialize_sql()
- self.repo_path = migration._find_migrate_repo()
+ self.repo_path = migration.find_migrate_repo(self.repo_package())
self.schema = versioning_api.ControlledSchema.create(
self.engine,
self.repo_path, 0)
@@ -85,7 +87,64 @@ class SqlUpgradeTests(test.TestCase):
autoload=True)
self.downgrade(0)
table.drop(self.engine, checkfirst=True)
- super(SqlUpgradeTests, self).tearDown()
+ super(SqlMigrateBase, self).tearDown()
+
+ def select_table(self, name):
+ table = sqlalchemy.Table(name,
+ self.metadata,
+ autoload=True)
+ s = sqlalchemy.select([table])
+ return s
+
+ def assertTableExists(self, table_name):
+ try:
+ self.select_table(table_name)
+ except sqlalchemy.exc.NoSuchTableError:
+ raise AssertionError('Table "%s" does not exist' % table_name)
+
+ def assertTableDoesNotExist(self, table_name):
+        """Asserts that a given table cannot be selected by name."""
+ # Switch to a different metadata otherwise you might still
+ # detect renamed or dropped tables
+ try:
+ temp_metadata = sqlalchemy.MetaData()
+ temp_metadata.bind = self.engine
+ sqlalchemy.Table(table_name, temp_metadata, autoload=True)
+ except sqlalchemy.exc.NoSuchTableError:
+ pass
+ else:
+ raise AssertionError('Table "%s" already exists' % table_name)
+
+ def upgrade(self, *args, **kwargs):
+ self._migrate(*args, **kwargs)
+
+ def downgrade(self, *args, **kwargs):
+ self._migrate(*args, downgrade=True, **kwargs)
+
+ def _migrate(self, version, repository=None, downgrade=False,
+ current_schema=None):
+ repository = repository or self.repo_path
+ err = ''
+ version = versioning_api._migrate_version(self.schema,
+ version,
+ not downgrade,
+ err)
+ if not current_schema:
+ current_schema = self.schema
+ changeset = current_schema.changeset(version)
+ for ver, change in changeset:
+ self.schema.runchange(ver, change, changeset.step)
+ self.assertEqual(self.schema.version, version)
+
+ def assertTableColumns(self, table_name, expected_cols):
+ """Asserts that the table contains the expected set of columns."""
+ self.initialize_sql()
+ table = self.select_table(table_name)
+ actual_cols = [col.name for col in table.columns]
+ self.assertEqual(expected_cols, actual_cols, '%s table' % table_name)
+
+
+class SqlUpgradeTests(SqlMigrateBase):
def test_blank_db_to_start(self):
self.assertTableDoesNotExist('user')
@@ -108,13 +167,6 @@ class SqlUpgradeTests(test.TestCase):
self.downgrade(x - 1)
self.upgrade(x)
- def assertTableColumns(self, table_name, expected_cols):
- """Asserts that the table contains the expected set of columns."""
- self.initialize_sql()
- table = self.select_table(table_name)
- actual_cols = [col.name for col in table.columns]
- self.assertEqual(expected_cols, actual_cols, '%s table' % table_name)
-
def test_upgrade_add_initial_tables(self):
self.upgrade(1)
self.assertTableColumns("user", ["id", "name", "extra"])
@@ -1186,6 +1238,24 @@ class SqlUpgradeTests(test.TestCase):
self.assertEqual(cred.user_id,
ec2_credential['user_id'])
+ def test_drop_credential_indexes(self):
+ self.upgrade(31)
+ table = sqlalchemy.Table('credential', self.metadata, autoload=True)
+ self.assertEqual(len(table.indexes), 0)
+
+ def test_downgrade_30(self):
+ self.upgrade(31)
+ self.downgrade(30)
+ table = sqlalchemy.Table('credential', self.metadata, autoload=True)
+ index_data = [(idx.name, idx.columns.keys())
+ for idx in table.indexes]
+ if self.engine.name == 'mysql':
+ self.assertIn(('user_id', ['user_id']), index_data)
+ self.assertIn(('credential_project_id_fkey', ['project_id']),
+ index_data)
+ else:
+ self.assertEqual(len(index_data), 0)
+
def populate_user_table(self, with_pass_enab=False,
with_pass_enab_domain=False):
# Populate the appropriate fields in the user
@@ -1284,50 +1354,6 @@ class SqlUpgradeTests(test.TestCase):
'extra': json.dumps(extra)})
self.engine.execute(ins)
- def select_table(self, name):
- table = sqlalchemy.Table(name,
- self.metadata,
- autoload=True)
- s = sqlalchemy.select([table])
- return s
-
- def assertTableExists(self, table_name):
- try:
- self.select_table(table_name)
- except sqlalchemy.exc.NoSuchTableError:
- raise AssertionError('Table "%s" does not exist' % table_name)
-
- def assertTableDoesNotExist(self, table_name):
- """Asserts that a given table exists cannot be selected by name."""
- # Switch to a different metadata otherwise you might still
- # detect renamed or dropped tables
- try:
- temp_metadata = sqlalchemy.MetaData()
- temp_metadata.bind = self.engine
- sqlalchemy.Table(table_name, temp_metadata, autoload=True)
- except sqlalchemy.exc.NoSuchTableError:
- pass
- else:
- raise AssertionError('Table "%s" already exists' % table_name)
-
- def upgrade(self, *args, **kwargs):
- self._migrate(*args, **kwargs)
-
- def downgrade(self, *args, **kwargs):
- self._migrate(*args, downgrade=True, **kwargs)
-
- def _migrate(self, version, repository=None, downgrade=False):
- repository = repository or self.repo_path
- err = ''
- version = versioning_api._migrate_version(self.schema,
- version,
- not downgrade,
- err)
- changeset = self.schema.changeset(version)
- for ver, change in changeset:
- self.schema.runchange(ver, change, changeset.step)
- self.assertEqual(self.schema.version, version)
-
def _mysql_check_all_tables_innodb(self):
database = self.engine.url.database
diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py
index 003f7571..362df922 100644
--- a/tests/test_wsgi.py
+++ b/tests/test_wsgi.py
@@ -37,37 +37,6 @@ class BaseWSGITest(test.TestCase):
req.environ['wsgiorg.routing_args'] = [None, args]
return req
- def test_mask_password(self):
- message = ("test = 'password': 'aaaaaa', 'param1': 'value1', "
- "\"new_password\": 'bbbbbb'")
- self.assertEqual(wsgi.mask_password(message, True),
- u"test = 'password': '***', 'param1': 'value1', "
- "\"new_password\": '***'")
-
- message = "test = 'password' : 'aaaaaa'"
- self.assertEqual(wsgi.mask_password(message, False, '111'),
- "test = 'password' : '111'")
-
- message = u"test = u'password' : u'aaaaaa'"
- self.assertEqual(wsgi.mask_password(message, True),
- u"test = u'password' : u'***'")
-
- message = 'test = "password" : "aaaaaaaaa"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "password" : "***"')
-
- message = 'test = "original_password" : "aaaaaaaaa"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "original_password" : "***"')
-
- message = 'test = "original_password" : ""'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "original_password" : "***"')
-
- message = 'test = "param1" : "value"'
- self.assertEqual(wsgi.mask_password(message),
- 'test = "param1" : "value"')
-
class ApplicationTest(BaseWSGITest):
def test_response_content_type(self):
@@ -210,3 +179,36 @@ class MiddlewareTest(BaseWSGITest):
app = factory(self.app)
self.assertIn("testkey", app.kwargs)
self.assertEquals("test", app.kwargs["testkey"])
+
+
+class WSGIFunctionTest(test.TestCase):
+ def test_mask_password(self):
+ message = ("test = 'password': 'aaaaaa', 'param1': 'value1', "
+ "\"new_password\": 'bbbbbb'")
+ self.assertEqual(wsgi.mask_password(message, True),
+ u"test = 'password': '***', 'param1': 'value1', "
+ "\"new_password\": '***'")
+
+ message = "test = 'password' : 'aaaaaa'"
+ self.assertEqual(wsgi.mask_password(message, False, '111'),
+ "test = 'password' : '111'")
+
+ message = u"test = u'password' : u'aaaaaa'"
+ self.assertEqual(wsgi.mask_password(message, True),
+ u"test = u'password' : u'***'")
+
+ message = 'test = "password" : "aaaaaaaaa"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "password" : "***"')
+
+ message = 'test = "original_password" : "aaaaaaaaa"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "original_password" : "***"')
+
+ message = 'test = "original_password" : ""'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "original_password" : "***"')
+
+ message = 'test = "param1" : "value"'
+ self.assertEqual(wsgi.mask_password(message),
+ 'test = "param1" : "value"')
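
From the assertions above, `mask_password` takes the message, an is_unicode
flag and an optional replacement secret defaulting to '***' (the parameter
names are inferred from the call sites, not from a published signature). Its
intended use is scrubbing credentials before request bodies reach the logs::

    import logging

    from keystone.common import wsgi

    LOG = logging.getLogger(__name__)

    body = "credentials = 'password' : 's3cret'"
    # The secret value is replaced; the surrounding text is preserved.
    LOG.debug(wsgi.mask_password(body))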