diff options
33 files changed, 1431 insertions, 260 deletions
diff --git a/doc/source/developing.rst b/doc/source/developing.rst index c14ef7ab..2cf4b98e 100644 --- a/doc/source/developing.rst +++ b/doc/source/developing.rst @@ -71,6 +71,36 @@ place:: .. _`python-keystoneclient`: https://github.com/openstack/python-keystoneclient +Database Schema Migrations +-------------------------- + +Keystone uses SQLAlchemy-migrate +_`SQLAlchemy-migrate`:http://code.google.com/p/sqlalchemy-migrate/ to migrate the SQL database +between revisions. For core components, the migrations are kept in a central +repository under keystone/common/sql/migrate_repo. + +Extensions should be created as directories under `keystone/contrib`. An +extension that requires sql migrations should not change the common repository, +but should instead have its own repository. This repository must be in the +extension's directory in `keystone/contrib/<extension>/migrate_repo.` In +addition it needs a subdirectory named `versions`. For example, if the +extension name is `my_extension` then the directory structure would be +`keystone/contrib/my_extension/migrate_repo/versions/`. For the migration +o work, both the migrate_repo and versions subdirectories must have empty +__init__.py files. SQLAlchemy-migrate will look for a configuration file in +the migrate_repo named migrate.cfg. This conforms to a Key/value ini file +format. A sample config file with the minimal set of values is:: + + [db_settings] + repository_id=my_extension + version_table=migrate_version + required_dbs=[] + +The directory `keystone/contrib/example` contains a sample extension migration. + +Migrations for extension must be explicitly run. To run a migration for a specific +extension, run `keystone-manage --extension <name> db_sync`. + Initial Sample Data ------------------- diff --git a/etc/keystone.conf.sample b/etc/keystone.conf.sample index a49a9a5e..90efe5f6 100644 --- a/etc/keystone.conf.sample +++ b/etc/keystone.conf.sample @@ -100,6 +100,9 @@ # exist to order to maintain support for your v2 clients. # default_domain_id = default +# Maximum supported length for user passwords; decrease to improve performance. +# max_password_length = 4096 + [credential] # driver = keystone.credential.backends.sql.Credential diff --git a/keystone/assignment/backends/ldap.py b/keystone/assignment/backends/ldap.py index 9b273e40..718d38c3 100644 --- a/keystone/assignment/backends/ldap.py +++ b/keystone/assignment/backends/ldap.py @@ -263,28 +263,19 @@ class ProjectApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap): DEFAULT_OBJECTCLASS = 'groupOfNames' DEFAULT_ID_ATTR = 'cn' DEFAULT_MEMBER_ATTRIBUTE = 'member' - DEFAULT_ATTRIBUTE_IGNORE = [] NotFound = exception.ProjectNotFound notfound_arg = 'project_id' # NOTE(yorik-sar): while options_name = tenant options_name = 'tenant' - attribute_mapping = {'name': 'ou', - 'description': 'description', - 'tenantId': 'cn', - 'enabled': 'enabled', - 'domain_id': 'domain_id'} + attribute_options_names = {'name': 'name', + 'description': 'desc', + 'enabled': 'enabled', + 'domain_id': 'domain_id'} model = models.Project def __init__(self, conf): super(ProjectApi, self).__init__(conf) - self.attribute_mapping['name'] = conf.ldap.tenant_name_attribute - self.attribute_mapping['description'] = conf.ldap.tenant_desc_attribute - self.attribute_mapping['enabled'] = conf.ldap.tenant_enabled_attribute - self.attribute_mapping['domain_id'] = ( - conf.ldap.tenant_domain_id_attribute) self.member_attribute = (getattr(conf.ldap, 'tenant_member_attribute') or self.DEFAULT_MEMBER_ATTRIBUTE) - self.attribute_ignore = (getattr(conf.ldap, 'tenant_attribute_ignore') - or self.DEFAULT_ATTRIBUTE_IGNORE) def create(self, values): self.affirm_unique(values) @@ -381,21 +372,15 @@ class RoleApi(common_ldap.BaseLdap): DEFAULT_STRUCTURAL_CLASSES = [] DEFAULT_OBJECTCLASS = 'organizationalRole' DEFAULT_MEMBER_ATTRIBUTE = 'roleOccupant' - DEFAULT_ATTRIBUTE_IGNORE = [] NotFound = exception.RoleNotFound options_name = 'role' - attribute_mapping = {'name': 'ou', - #'serviceId': 'service_id', - } + attribute_options_names = {'name': 'name'} model = models.Role def __init__(self, conf): super(RoleApi, self).__init__(conf) - self.attribute_mapping['name'] = conf.ldap.role_name_attribute self.member_attribute = (getattr(conf.ldap, 'role_member_attribute') or self.DEFAULT_MEMBER_ATTRIBUTE) - self.attribute_ignore = (getattr(conf.ldap, 'role_attribute_ignore') - or self.DEFAULT_ATTRIBUTE_IGNORE) def get(self, id, filter=None): model = super(RoleApi, self).get(id, filter) diff --git a/keystone/assignment/core.py b/keystone/assignment/core.py index 64edb3fa..0a2ee681 100644 --- a/keystone/assignment/core.py +++ b/keystone/assignment/core.py @@ -178,9 +178,23 @@ class Manager(manager.Manager): keystone.exception.UserNotFound """ - self.driver.add_role_to_user_and_project(user_id, - tenant_id, - config.CONF.member_role_id) + try: + self.driver.add_role_to_user_and_project( + user_id, + tenant_id, + config.CONF.member_role_id) + except exception.RoleNotFound: + LOG.info(_("Creating the default role %s " + "because it does not exist.") % + config.CONF.member_role_id) + role = {'id': CONF.member_role_id, + 'name': CONF.member_role_name} + self.driver.create_role(config.CONF.member_role_id, role) + #now that default role exists, the add should succeed + self.driver.add_role_to_user_and_project( + user_id, + tenant_id, + config.CONF.member_role_id) def remove_user_from_project(self, tenant_id, user_id): """Remove user from a tenant diff --git a/keystone/catalog/backends/sql.py b/keystone/catalog/backends/sql.py index 1dad5a80..d7b2123a 100644 --- a/keystone/catalog/backends/sql.py +++ b/keystone/catalog/backends/sql.py @@ -32,6 +32,7 @@ class Service(sql.ModelBase, sql.DictBase): id = sql.Column(sql.String(64), primary_key=True) type = sql.Column(sql.String(255)) extra = sql.Column(sql.JsonBlob()) + endpoints = sql.relationship("Endpoint", backref="service") class Endpoint(sql.ModelBase, sql.DictBase): @@ -40,12 +41,12 @@ class Endpoint(sql.ModelBase, sql.DictBase): 'legacy_endpoint_id'] id = sql.Column(sql.String(64), primary_key=True) legacy_endpoint_id = sql.Column(sql.String(64)) - interface = sql.Column(sql.String(8), primary_key=True) - region = sql.Column('region', sql.String(255)) + interface = sql.Column(sql.String(8), nullable=False) + region = sql.Column(sql.String(255)) service_id = sql.Column(sql.String(64), sql.ForeignKey('service.id'), nullable=False) - url = sql.Column(sql.Text()) + url = sql.Column(sql.Text(), nullable=False) extra = sql.Column(sql.JsonBlob()) @@ -150,28 +151,26 @@ class Catalog(sql.Base, catalog.Driver): d.update({'tenant_id': tenant_id, 'user_id': user_id}) + session = self.get_session() + endpoints = (session.query(Endpoint). + options(sql.joinedload(Endpoint.service)). + all()) + catalog = {} - services = {} - for endpoint in self.list_endpoints(): - # look up the service - services.setdefault( - endpoint['service_id'], - self.get_service(endpoint['service_id'])) - service = services[endpoint['service_id']] - - # add the endpoint to the catalog if it's not already there - catalog.setdefault(endpoint['region'], {}) - catalog[endpoint['region']].setdefault( - service['type'], { - 'id': endpoint['id'], - 'name': service['name'], - 'publicURL': '', # this may be overridden, but must exist - }) - - # add the interface's url - url = core.format_url(endpoint.get('url'), d) + + for endpoint in endpoints: + region = endpoint['region'] + service_type = endpoint.service['type'] + default_service = { + 'id': endpoint['id'], + 'name': endpoint.service['name'], + 'publicURL': '' + } + catalog.setdefault(region, {}) + catalog[region].setdefault(service_type, default_service) + url = core.format_url(endpoint['url'], d) interface_url = '%sURL' % endpoint['interface'] - catalog[endpoint['region']][service['type']][interface_url] = url + catalog[region][service_type][interface_url] = url return catalog @@ -180,27 +179,19 @@ class Catalog(sql.Base, catalog.Driver): d.update({'tenant_id': tenant_id, 'user_id': user_id}) - services = {} - for endpoint in self.list_endpoints(): - # look up the service - service_id = endpoint['service_id'] - services.setdefault( - service_id, - self.get_service(service_id)) - service = services[service_id] + session = self.get_session() + services = (session.query(Service). + options(sql.joinedload(Service.endpoints)). + all()) + + def make_v3_endpoint(endpoint): del endpoint['service_id'] endpoint['url'] = core.format_url(endpoint['url'], d) - if 'endpoints' in services[service_id]: - services[service_id]['endpoints'].append(endpoint) - else: - services[service_id]['endpoints'] = [endpoint] - - catalog = [] - for service_id, service in services.iteritems(): - formatted_service = {} - formatted_service['id'] = service['id'] - formatted_service['type'] = service['type'] - formatted_service['endpoints'] = service['endpoints'] - catalog.append(formatted_service) + return endpoint + + catalog = [{'endpoints': [make_v3_endpoint(ep.to_dict()) + for ep in svc.endpoints], + 'id': svc.id, + 'type': svc.type} for svc in services] return catalog diff --git a/keystone/cli.py b/keystone/cli.py index 21d2ad40..18c095ce 100644 --- a/keystone/cli.py +++ b/keystone/cli.py @@ -20,12 +20,15 @@ import grp import os import pwd +from migrate import exceptions + from oslo.config import cfg import pbr.version from keystone.common import openssl from keystone.common.sql import migration from keystone import config +from keystone import contrib from keystone.openstack.common import importutils from keystone.openstack.common import jsonutils from keystone import token @@ -57,14 +60,35 @@ class DbSync(BaseApp): 'version. If not provided, db_sync will ' 'migrate the database to the latest known ' 'version.')) + parser.add_argument('--extension', default=None, + help=('Migrate the database for the specified ' + 'extension. If not provided, db_sync will ' + 'migrate the common repository.')) + return parser @staticmethod def main(): - for k in ['identity', 'catalog', 'policy', 'token', 'credential']: - driver = importutils.import_object(getattr(CONF, k).driver) - if hasattr(driver, 'db_sync'): - driver.db_sync(CONF.command.version) + version = CONF.command.version + extension = CONF.command.extension + if not extension: + migration.db_sync(version=version) + else: + package_name = "%s.%s.migrate_repo" % (contrib.__name__, extension) + try: + package = importutils.import_module(package_name) + repo_path = os.path.abspath(os.path.dirname(package.__file__)) + except ImportError: + print _("This extension does not provide migrations.") + exit(0) + try: + # Register the repo with the version control API + # If it already knows about the repo, it will throw + # an exception that we can safely ignore + migration.db_version_control(version=None, repo_path=repo_path) + except exceptions.DatabaseAlreadyControlledError: + pass + migration.db_sync(version=None, repo_path=repo_path) class DbVersion(BaseApp): @@ -72,9 +96,29 @@ class DbVersion(BaseApp): name = 'db_version' + @classmethod + def add_argument_parser(cls, subparsers): + parser = super(DbVersion, cls).add_argument_parser(subparsers) + parser.add_argument('--extension', default=None, + help=('Migrate the database for the specified ' + 'extension. If not provided, db_sync will ' + 'migrate the common repository.')) + @staticmethod def main(): - print(migration.db_version()) + extension = CONF.command.extension + if extension: + try: + package_name = ("%s.%s.migrate_repo" % + (contrib.__name__, extension)) + package = importutils.import_module(package_name) + repo_path = os.path.abspath(os.path.dirname(package.__file__)) + print(migration.db_version(repo_path)) + except ImportError: + print _("This extension does not provide migrations.") + exit(1) + else: + print(migration.db_version()) class BaseCertificateSetup(BaseApp): diff --git a/keystone/common/config.py b/keystone/common/config.py index 10c47a35..cd525369 100644 --- a/keystone/common/config.py +++ b/keystone/common/config.py @@ -210,6 +210,7 @@ def configure(): # identity register_str('default_domain_id', group='identity', default='default') + register_int('max_password_length', group='identity', default=4096) # trust register_bool('enabled', group='trust', default=True) diff --git a/keystone/common/ldap/core.py b/keystone/common/ldap/core.py index 7a2dfee7..39ea78de 100644 --- a/keystone/common/ldap/core.py +++ b/keystone/common/ldap/core.py @@ -114,7 +114,7 @@ class BaseLdap(object): notfound_arg = None options_name = None model = None - attribute_mapping = {} + attribute_options_names = {} attribute_ignore = [] tree_dn = None @@ -129,6 +129,7 @@ class BaseLdap(object): self.tls_cacertfile = conf.ldap.tls_cacertfile self.tls_cacertdir = conf.ldap.tls_cacertdir self.tls_req_cert = parse_tls_cert(conf.ldap.tls_req_cert) + self.attribute_mapping = {} if self.options_name is not None: self.suffix = conf.ldap.suffix @@ -145,6 +146,10 @@ class BaseLdap(object): self.object_class = (getattr(conf.ldap, objclass) or self.DEFAULT_OBJECTCLASS) + for k, v in self.attribute_options_names.iteritems(): + v = '%s_%s_attribute' % (self.options_name, v) + self.attribute_mapping[k] = getattr(conf.ldap, v) + attr_mapping_opt = ('%s_additional_attribute_mapping' % self.options_name) attr_mapping = (getattr(conf.ldap, attr_mapping_opt) @@ -167,6 +172,10 @@ class BaseLdap(object): if self.notfound_arg is None: self.notfound_arg = self.options_name + '_id' + + attribute_ignore = '%s_attribute_ignore' % self.options_name + self.attribute_ignore = getattr(conf.ldap, attribute_ignore) + self.use_dumb_member = getattr(conf.ldap, 'use_dumb_member') self.dumb_member = (getattr(conf.ldap, 'dumb_member') or self.DUMB_MEMBER_DN) diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py index 2d3114f2..67863588 100644 --- a/keystone/common/sql/core.py +++ b/keystone/common/sql/core.py @@ -45,6 +45,7 @@ ModelBase = declarative.declarative_base() # For exporting to other modules Column = sql.Column +Index = sql.Index String = sql.String ForeignKey = sql.ForeignKey DateTime = sql.DateTime @@ -54,6 +55,8 @@ NotFound = sql.orm.exc.NoResultFound Boolean = sql.Boolean Text = sql.Text UniqueConstraint = sql.UniqueConstraint +relationship = sql.orm.relationship +joinedload = sql.orm.joinedload def initialize_decorator(init): @@ -179,6 +182,8 @@ class DictBase(object): setattr(self, key, value) def __getitem__(self, key): + if key in self.extra: + return self.extra[key] return getattr(self, key) def get(self, key, default=None): diff --git a/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py new file mode 100644 index 00000000..89ca04f0 --- /dev/null +++ b/keystone/common/sql/migrate_repo/versions/031_drop_credential_indexes.py @@ -0,0 +1,40 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sqlalchemy + + +def upgrade(migrate_engine): + #This migration is relevant only for mysql because for all other + #migrate engines these indexes were successfully dropped. + if migrate_engine.name != 'mysql': + return + meta = sqlalchemy.MetaData(bind=migrate_engine) + table = sqlalchemy.Table('credential', meta, autoload=True) + for index in table.indexes: + index.drop() + + +def downgrade(migrate_engine): + if migrate_engine.name != 'mysql': + return + meta = sqlalchemy.MetaData(bind=migrate_engine) + table = sqlalchemy.Table('credential', meta, autoload=True) + index = sqlalchemy.Index('user_id', table.c['user_id']) + index.create() + index = sqlalchemy.Index('credential_project_id_fkey', + table.c['project_id']) + index.create() diff --git a/keystone/common/sql/migration.py b/keystone/common/sql/migration.py index 86e0254c..3cb9cd63 100644 --- a/keystone/common/sql/migration.py +++ b/keystone/common/sql/migration.py @@ -39,39 +39,51 @@ except ImportError: sys.exit('python-migrate is not installed. Exiting.') -def db_sync(version=None): +def migrate_repository(version, current_version, repo_path): + if version is None or version > current_version: + result = versioning_api.upgrade(CONF.sql.connection, + repo_path, version) + else: + result = versioning_api.downgrade( + CONF.sql.connection, repo_path, version) + return result + + +def db_sync(version=None, repo_path=None): if version is not None: try: version = int(version) except ValueError: raise Exception(_('version should be an integer')) + if repo_path is None: + repo_path = find_migrate_repo() + current_version = db_version(repo_path=repo_path) + return migrate_repository(version, current_version, repo_path) - current_version = db_version() - repo_path = _find_migrate_repo() - if version is None or version > current_version: - return versioning_api.upgrade(CONF.sql.connection, repo_path, version) - else: - return versioning_api.downgrade( - CONF.sql.connection, repo_path, version) - -def db_version(): - repo_path = _find_migrate_repo() +def db_version(repo_path=None): + if repo_path is None: + repo_path = find_migrate_repo() try: return versioning_api.db_version(CONF.sql.connection, repo_path) except versioning_exceptions.DatabaseNotControlledError: return db_version_control(0) -def db_version_control(version=None): - repo_path = _find_migrate_repo() +def db_version_control(version=None, repo_path=None): + if repo_path is None: + repo_path = find_migrate_repo() versioning_api.version_control(CONF.sql.connection, repo_path, version) return version -def _find_migrate_repo(): +def find_migrate_repo(package=None): """Get the path for the migrate repository.""" - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), + if package is None: + file = __file__ + else: + file = package.__file__ + path = os.path.join(os.path.abspath(os.path.dirname(file)), 'migrate_repo') assert os.path.exists(path) return path diff --git a/keystone/common/utils.py b/keystone/common/utils.py index fd2d7567..9966ee67 100644 --- a/keystone/common/utils.py +++ b/keystone/common/utils.py @@ -36,8 +36,6 @@ config.register_int('crypt_strength', default=40000) LOG = logging.getLogger(__name__) -MAX_PASSWORD_LENGTH = 4096 - def read_cached_file(filename, cache_info, reload_func=None): """Read from a file if it has been modified. @@ -68,12 +66,13 @@ class SmarterEncoder(json.JSONEncoder): def trunc_password(password): - """Truncate passwords to the MAX_PASSWORD_LENGTH.""" + """Truncate passwords to the max_length.""" + max_length = CONF.identity.max_password_length try: - if len(password) > MAX_PASSWORD_LENGTH: - return password[:MAX_PASSWORD_LENGTH] - else: - return password + if len(password) > max_length: + LOG.warning( + _('Truncating user password to %s characters.') % max_length) + return password[:max_length] except TypeError: raise exception.ValidationError(attribute='string', target='password') diff --git a/keystone/contrib/example/__init__.py b/keystone/contrib/example/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/keystone/contrib/example/__init__.py diff --git a/keystone/contrib/example/migrate_repo/__init__.py b/keystone/contrib/example/migrate_repo/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/keystone/contrib/example/migrate_repo/__init__.py diff --git a/keystone/contrib/example/migrate_repo/migrate.cfg b/keystone/contrib/example/migrate_repo/migrate.cfg new file mode 100644 index 00000000..5b1b1c0a --- /dev/null +++ b/keystone/contrib/example/migrate_repo/migrate.cfg @@ -0,0 +1,25 @@ +[db_settings] +# Used to identify which repository this database is versioned under. +# You can use the name of your project. +repository_id=example + +# The name of the database table used to track the schema version. +# This name shouldn't already be used by your project. +# If this is changed once a database is under version control, you'll need to +# change the table name in each database too. +version_table=migrate_version + +# When committing a change script, Migrate will attempt to generate the +# sql for all supported databases; normally, if one of them fails - probably +# because you don't have that database installed - it is ignored and the +# commit continues, perhaps ending successfully. +# Databases in this list MUST compile successfully during a commit, or the +# entire commit will fail. List the databases your application will actually +# be using to ensure your updates to that database work properly. +# This must be a list; example: ['postgres','sqlite'] +required_dbs=[] + +# When creating new change scripts, Migrate will stamp the new script with +# a version number. By default this is latest_version + 1. You can set this +# to 'true' to tell Migrate to use the UTC timestamp instead. +use_timestamp_numbering=False diff --git a/keystone/contrib/example/migrate_repo/versions/001_example_table.py b/keystone/contrib/example/migrate_repo/versions/001_example_table.py new file mode 100644 index 00000000..bb2203d3 --- /dev/null +++ b/keystone/contrib/example/migrate_repo/versions/001_example_table.py @@ -0,0 +1,45 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sqlalchemy as sql + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; bind + # migrate_engine to your metadata + meta = sql.MetaData() + meta.bind = migrate_engine + + # catalog + + service_table = sql.Table( + 'example', + meta, + sql.Column('id', sql.String(64), primary_key=True), + sql.Column('type', sql.String(255)), + sql.Column('extra', sql.Text())) + service_table.create(migrate_engine, checkfirst=True) + + +def downgrade(migrate_engine): + # Operations to reverse the above upgrade go here. + meta = sql.MetaData() + meta.bind = migrate_engine + + tables = ['example'] + for t in tables: + table = sql.Table(t, meta, autoload=True) + table.drop(migrate_engine, checkfirst=True) diff --git a/keystone/contrib/example/migrate_repo/versions/__init__.py b/keystone/contrib/example/migrate_repo/versions/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/keystone/contrib/example/migrate_repo/versions/__init__.py diff --git a/keystone/identity/backends/ldap.py b/keystone/identity/backends/ldap.py index 91ea1e41..a359c63f 100644 --- a/keystone/identity/backends/ldap.py +++ b/keystone/identity/backends/ldap.py @@ -210,29 +210,20 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap): DEFAULT_STRUCTURAL_CLASSES = ['person'] DEFAULT_ID_ATTR = 'cn' DEFAULT_OBJECTCLASS = 'inetOrgPerson' - DEFAULT_ATTRIBUTE_IGNORE = ['tenant_id', 'tenants'] NotFound = exception.UserNotFound options_name = 'user' - attribute_mapping = {'password': 'userPassword', - 'email': 'mail', - 'name': 'sn', - 'enabled': 'enabled', - 'domain_id': 'domain_id'} + attribute_options_names = {'password': 'pass', + 'email': 'mail', + 'name': 'name', + 'enabled': 'enabled', + 'domain_id': 'domain_id'} model = models.User def __init__(self, conf): super(UserApi, self).__init__(conf) - self.attribute_mapping['name'] = conf.ldap.user_name_attribute - self.attribute_mapping['email'] = conf.ldap.user_mail_attribute - self.attribute_mapping['password'] = conf.ldap.user_pass_attribute - self.attribute_mapping['enabled'] = conf.ldap.user_enabled_attribute - self.attribute_mapping['domain_id'] = ( - conf.ldap.user_domain_id_attribute) self.enabled_mask = conf.ldap.user_enabled_mask self.enabled_default = conf.ldap.user_enabled_default - self.attribute_ignore = (getattr(conf.ldap, 'user_attribute_ignore') - or self.DEFAULT_ATTRIBUTE_IGNORE) def _ldap_res_to_model(self, res): obj = super(UserApi, self)._ldap_res_to_model(res) @@ -277,25 +268,17 @@ class GroupApi(common_ldap.BaseLdap): DEFAULT_OBJECTCLASS = 'groupOfNames' DEFAULT_ID_ATTR = 'cn' DEFAULT_MEMBER_ATTRIBUTE = 'member' - DEFAULT_ATTRIBUTE_IGNORE = [] NotFound = exception.GroupNotFound options_name = 'group' - attribute_mapping = {'name': 'ou', - 'description': 'description', - 'groupId': 'cn', - 'domain_id': 'domain_id'} + attribute_options_names = {'description': 'desc', + 'name': 'name', + 'domain_id': 'domain_id'} model = models.Group def __init__(self, conf): super(GroupApi, self).__init__(conf) - self.attribute_mapping['name'] = conf.ldap.group_name_attribute - self.attribute_mapping['description'] = conf.ldap.group_desc_attribute - self.attribute_mapping['domain_id'] = ( - conf.ldap.group_domain_id_attribute) self.member_attribute = (getattr(conf.ldap, 'group_member_attribute') or self.DEFAULT_MEMBER_ATTRIBUTE) - self.attribute_ignore = (getattr(conf.ldap, 'group_attribute_ignore') - or self.DEFAULT_ATTRIBUTE_IGNORE) def create(self, values): self.affirm_unique(values) diff --git a/keystone/openstack/common/gettextutils.py b/keystone/openstack/common/gettextutils.py index 55ba3387..ed085370 100644 --- a/keystone/openstack/common/gettextutils.py +++ b/keystone/openstack/common/gettextutils.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2012 Red Hat, Inc. +# Copyright 2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -23,18 +24,27 @@ Usual usage in an openstack.common module: from keystone.openstack.common.gettextutils import _ """ +import copy import gettext +import logging.handlers import os +import re +import UserString + +from babel import localedata +import six _localedir = os.environ.get('keystone'.upper() + '_LOCALEDIR') _t = gettext.translation('keystone', localedir=_localedir, fallback=True) +_AVAILABLE_LANGUAGES = [] + def _(msg): return _t.ugettext(msg) -def install(domain): +def install(domain, lazy=False): """Install a _() function using the given translation domain. Given a translation domain, install a _() function using gettext's @@ -44,7 +54,252 @@ def install(domain): overriding the default localedir (e.g. /usr/share/locale) using a translation-domain-specific environment variable (e.g. NOVA_LOCALEDIR). + + :param domain: the translation domain + :param lazy: indicates whether or not to install the lazy _() function. + The lazy _() introduces a way to do deferred translation + of messages by installing a _ that builds Message objects, + instead of strings, which can then be lazily translated into + any available locale. + """ + if lazy: + # NOTE(mrodden): Lazy gettext functionality. + # + # The following introduces a deferred way to do translations on + # messages in OpenStack. We override the standard _() function + # and % (format string) operation to build Message objects that can + # later be translated when we have more information. + # + # Also included below is an example LocaleHandler that translates + # Messages to an associated locale, effectively allowing many logs, + # each with their own locale. + + def _lazy_gettext(msg): + """Create and return a Message object. + + Lazy gettext function for a given domain, it is a factory method + for a project/module to get a lazy gettext function for its own + translation domain (i.e. nova, glance, cinder, etc.) + + Message encapsulates a string so that we can translate + it later when needed. + """ + return Message(msg, domain) + + import __builtin__ + __builtin__.__dict__['_'] = _lazy_gettext + else: + localedir = '%s_LOCALEDIR' % domain.upper() + gettext.install(domain, + localedir=os.environ.get(localedir), + unicode=True) + + +class Message(UserString.UserString, object): + """Class used to encapsulate translatable messages.""" + def __init__(self, msg, domain): + # _msg is the gettext msgid and should never change + self._msg = msg + self._left_extra_msg = '' + self._right_extra_msg = '' + self.params = None + self.locale = None + self.domain = domain + + @property + def data(self): + # NOTE(mrodden): this should always resolve to a unicode string + # that best represents the state of the message currently + + localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR') + if self.locale: + lang = gettext.translation(self.domain, + localedir=localedir, + languages=[self.locale], + fallback=True) + else: + # use system locale for translations + lang = gettext.translation(self.domain, + localedir=localedir, + fallback=True) + + full_msg = (self._left_extra_msg + + lang.ugettext(self._msg) + + self._right_extra_msg) + + if self.params is not None: + full_msg = full_msg % self.params + + return six.text_type(full_msg) + + def _save_dictionary_parameter(self, dict_param): + full_msg = self.data + # look for %(blah) fields in string; + # ignore %% and deal with the + # case where % is first character on the line + keys = re.findall('(?:[^%]|^)%\((\w*)\)[a-z]', full_msg) + + # if we don't find any %(blah) blocks but have a %s + if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg): + # apparently the full dictionary is the parameter + params = copy.deepcopy(dict_param) + else: + params = {} + for key in keys: + try: + params[key] = copy.deepcopy(dict_param[key]) + except TypeError: + # cast uncopyable thing to unicode string + params[key] = unicode(dict_param[key]) + + return params + + def _save_parameters(self, other): + # we check for None later to see if + # we actually have parameters to inject, + # so encapsulate if our parameter is actually None + if other is None: + self.params = (other, ) + elif isinstance(other, dict): + self.params = self._save_dictionary_parameter(other) + else: + # fallback to casting to unicode, + # this will handle the problematic python code-like + # objects that cannot be deep-copied + try: + self.params = copy.deepcopy(other) + except TypeError: + self.params = unicode(other) + + return self + + # overrides to be more string-like + def __unicode__(self): + return self.data + + def __str__(self): + return self.data.encode('utf-8') + + def __getstate__(self): + to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg', + 'domain', 'params', 'locale'] + new_dict = self.__dict__.fromkeys(to_copy) + for attr in to_copy: + new_dict[attr] = copy.deepcopy(self.__dict__[attr]) + + return new_dict + + def __setstate__(self, state): + for (k, v) in state.items(): + setattr(self, k, v) + + # operator overloads + def __add__(self, other): + copied = copy.deepcopy(self) + copied._right_extra_msg += other.__str__() + return copied + + def __radd__(self, other): + copied = copy.deepcopy(self) + copied._left_extra_msg += other.__str__() + return copied + + def __mod__(self, other): + # do a format string to catch and raise + # any possible KeyErrors from missing parameters + self.data % other + copied = copy.deepcopy(self) + return copied._save_parameters(other) + + def __mul__(self, other): + return self.data * other + + def __rmul__(self, other): + return other * self.data + + def __getitem__(self, key): + return self.data[key] + + def __getslice__(self, start, end): + return self.data.__getslice__(start, end) + + def __getattribute__(self, name): + # NOTE(mrodden): handle lossy operations that we can't deal with yet + # These override the UserString implementation, since UserString + # uses our __class__ attribute to try and build a new message + # after running the inner data string through the operation. + # At that point, we have lost the gettext message id and can just + # safely resolve to a string instead. + ops = ['capitalize', 'center', 'decode', 'encode', + 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip', + 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill'] + if name in ops: + return getattr(self.data, name) + else: + return UserString.UserString.__getattribute__(self, name) + + +def get_available_languages(domain): + """Lists the available languages for the given translation domain. + + :param domain: the domain to get languages for """ - gettext.install(domain, - localedir=os.environ.get(domain.upper() + '_LOCALEDIR'), - unicode=True) + if _AVAILABLE_LANGUAGES: + return _AVAILABLE_LANGUAGES + + localedir = '%s_LOCALEDIR' % domain.upper() + find = lambda x: gettext.find(domain, + localedir=os.environ.get(localedir), + languages=[x]) + + # NOTE(mrodden): en_US should always be available (and first in case + # order matters) since our in-line message strings are en_US + _AVAILABLE_LANGUAGES.append('en_US') + # NOTE(luisg): Babel <1.0 used a function called list(), which was + # renamed to locale_identifiers() in >=1.0, the requirements master list + # requires >=0.9.6, uncapped, so defensively work with both. We can remove + # this check when the master list updates to >=1.0, and all projects udpate + list_identifiers = (getattr(localedata, 'list', None) or + getattr(localedata, 'locale_identifiers')) + locale_identifiers = list_identifiers() + for i in locale_identifiers: + if find(i) is not None: + _AVAILABLE_LANGUAGES.append(i) + return _AVAILABLE_LANGUAGES + + +def get_localized_message(message, user_locale): + """Gets a localized version of the given message in the given locale.""" + if (isinstance(message, Message)): + if user_locale: + message.locale = user_locale + return unicode(message) + else: + return message + + +class LocaleHandler(logging.Handler): + """Handler that can have a locale associated to translate Messages. + + A quick example of how to utilize the Message class above. + LocaleHandler takes a locale and a target logging.Handler object + to forward LogRecord objects to after translating the internal Message. + """ + + def __init__(self, locale, target): + """Initialize a LocaleHandler + + :param locale: locale to use for translating messages + :param target: logging.Handler object to forward + LogRecord objects to after translation + """ + logging.Handler.__init__(self) + self.locale = locale + self.target = target + + def emit(self, record): + if isinstance(record.msg, Message): + # set the locale and resolve to a string + record.msg.locale = self.locale + + self.target.emit(record) diff --git a/keystone/openstack/common/importutils.py b/keystone/openstack/common/importutils.py index 3bd277f4..7a303f93 100644 --- a/keystone/openstack/common/importutils.py +++ b/keystone/openstack/common/importutils.py @@ -24,7 +24,7 @@ import traceback def import_class(import_str): - """Returns a class from a string including module and class""" + """Returns a class from a string including module and class.""" mod_str, _sep, class_str = import_str.rpartition('.') try: __import__(mod_str) @@ -41,8 +41,9 @@ def import_object(import_str, *args, **kwargs): def import_object_ns(name_space, import_str, *args, **kwargs): - """ - Import a class and return an instance of it, first by trying + """Tries to import object from default namespace. + + Imports a class and return an instance of it, first by trying to find the class in a default namespace, then failing back to a full path if not found in the default namespace. """ diff --git a/keystone/openstack/common/jsonutils.py b/keystone/openstack/common/jsonutils.py index d73e4c26..ecea09bb 100644 --- a/keystone/openstack/common/jsonutils.py +++ b/keystone/openstack/common/jsonutils.py @@ -38,11 +38,24 @@ import functools import inspect import itertools import json +import types import xmlrpclib +import netaddr +import six + from keystone.openstack.common import timeutils +_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod, + inspect.isfunction, inspect.isgeneratorfunction, + inspect.isgenerator, inspect.istraceback, inspect.isframe, + inspect.iscode, inspect.isbuiltin, inspect.isroutine, + inspect.isabstract] + +_simple_types = (types.NoneType, int, basestring, bool, float, long) + + def to_primitive(value, convert_instances=False, convert_datetime=True, level=0, max_depth=3): """Convert a complex object into primitives. @@ -58,19 +71,32 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, Therefore, convert_instances=True is lossy ... be aware. """ - nasty = [inspect.ismodule, inspect.isclass, inspect.ismethod, - inspect.isfunction, inspect.isgeneratorfunction, - inspect.isgenerator, inspect.istraceback, inspect.isframe, - inspect.iscode, inspect.isbuiltin, inspect.isroutine, - inspect.isabstract] - for test in nasty: - if test(value): - return unicode(value) - - # value of itertools.count doesn't get caught by inspects - # above and results in infinite loop when list(value) is called. + # handle obvious types first - order of basic types determined by running + # full tests on nova project, resulting in the following counts: + # 572754 <type 'NoneType'> + # 460353 <type 'int'> + # 379632 <type 'unicode'> + # 274610 <type 'str'> + # 199918 <type 'dict'> + # 114200 <type 'datetime.datetime'> + # 51817 <type 'bool'> + # 26164 <type 'list'> + # 6491 <type 'float'> + # 283 <type 'tuple'> + # 19 <type 'long'> + if isinstance(value, _simple_types): + return value + + if isinstance(value, datetime.datetime): + if convert_datetime: + return timeutils.strtime(value) + else: + return value + + # value of itertools.count doesn't get caught by nasty_type_tests + # and results in infinite loop when list(value) is called. if type(value) == itertools.count: - return unicode(value) + return six.text_type(value) # FIXME(vish): Workaround for LP bug 852095. Without this workaround, # tests that raise an exception in a mocked method that @@ -91,17 +117,18 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, convert_datetime=convert_datetime, level=level, max_depth=max_depth) + if isinstance(value, dict): + return dict((k, recursive(v)) for k, v in value.iteritems()) + elif isinstance(value, (list, tuple)): + return [recursive(lv) for lv in value] + # It's not clear why xmlrpclib created their own DateTime type, but # for our purposes, make it a datetime type which is explicitly # handled if isinstance(value, xmlrpclib.DateTime): value = datetime.datetime(*tuple(value.timetuple())[:6]) - if isinstance(value, (list, tuple)): - return [recursive(v) for v in value] - elif isinstance(value, dict): - return dict((k, recursive(v)) for k, v in value.iteritems()) - elif convert_datetime and isinstance(value, datetime.datetime): + if convert_datetime and isinstance(value, datetime.datetime): return timeutils.strtime(value) elif hasattr(value, 'iteritems'): return recursive(dict(value.iteritems()), level=level + 1) @@ -111,12 +138,16 @@ def to_primitive(value, convert_instances=False, convert_datetime=True, # Likely an instance of something. Watch for cycles. # Ignore class member vars. return recursive(value.__dict__, level=level + 1) + elif isinstance(value, netaddr.IPAddress): + return six.text_type(value) else: + if any(test(value) for test in _nasty_type_tests): + return six.text_type(value) return value except TypeError: # Class objects are tricky since they may define something like # __iter__ defined but it isn't callable as list(). - return unicode(value) + return six.text_type(value) def dumps(value, default=to_primitive, **kwargs): diff --git a/keystone/openstack/common/local.py b/keystone/openstack/common/local.py new file mode 100644 index 00000000..e82f17d0 --- /dev/null +++ b/keystone/openstack/common/local.py @@ -0,0 +1,47 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Local storage of variables using weak references""" + +import threading +import weakref + + +class WeakLocal(threading.local): + def __getattribute__(self, attr): + rval = super(WeakLocal, self).__getattribute__(attr) + if rval: + # NOTE(mikal): this bit is confusing. What is stored is a weak + # reference, not the value itself. We therefore need to lookup + # the weak reference and return the inner value here. + rval = rval() + return rval + + def __setattr__(self, attr, value): + value = weakref.ref(value) + return super(WeakLocal, self).__setattr__(attr, value) + + +# NOTE(mikal): the name "store" should be deprecated in the future +store = WeakLocal() + +# A "weak" store uses weak references and allows an object to fall out of scope +# when it falls out of scope in the code that uses the thread local storage. A +# "strong" store will hold a reference to the object so that it never falls out +# of scope. +weak_store = WeakLocal() +strong_store = threading.local() diff --git a/keystone/openstack/common/log.py b/keystone/openstack/common/log.py new file mode 100644 index 00000000..5a43c326 --- /dev/null +++ b/keystone/openstack/common/log.py @@ -0,0 +1,559 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Openstack logging handler. + +This module adds to logging functionality by adding the option to specify +a context object when calling the various log methods. If the context object +is not specified, default formatting is used. Additionally, an instance uuid +may be passed as part of the log message, which is intended to make it easier +for admins to find messages related to a specific instance. + +It also allows setting of formatting information through conf. + +""" + +import inspect +import itertools +import logging +import logging.config +import logging.handlers +import os +import sys +import traceback + +from oslo.config import cfg +from six import moves + +from keystone.openstack.common.gettextutils import _ # noqa +from keystone.openstack.common import importutils +from keystone.openstack.common import jsonutils +from keystone.openstack.common import local + + +_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +common_cli_opts = [ + cfg.BoolOpt('debug', + short='d', + default=False, + help='Print debugging output (set logging level to ' + 'DEBUG instead of default WARNING level).'), + cfg.BoolOpt('verbose', + short='v', + default=False, + help='Print more verbose output (set logging level to ' + 'INFO instead of default WARNING level).'), +] + +logging_cli_opts = [ + cfg.StrOpt('log-config', + metavar='PATH', + help='If this option is specified, the logging configuration ' + 'file specified is used and overrides any other logging ' + 'options specified. Please see the Python logging module ' + 'documentation for details on logging configuration ' + 'files.'), + cfg.StrOpt('log-format', + default=None, + metavar='FORMAT', + help='DEPRECATED. ' + 'A logging.Formatter log message format string which may ' + 'use any of the available logging.LogRecord attributes. ' + 'This option is deprecated. Please use ' + 'logging_context_format_string and ' + 'logging_default_format_string instead.'), + cfg.StrOpt('log-date-format', + default=_DEFAULT_LOG_DATE_FORMAT, + metavar='DATE_FORMAT', + help='Format string for %%(asctime)s in log records. ' + 'Default: %(default)s'), + cfg.StrOpt('log-file', + metavar='PATH', + deprecated_name='logfile', + help='(Optional) Name of log file to output to. ' + 'If no default is set, logging will go to stdout.'), + cfg.StrOpt('log-dir', + deprecated_name='logdir', + help='(Optional) The base directory used for relative ' + '--log-file paths'), + cfg.BoolOpt('use-syslog', + default=False, + help='Use syslog for logging.'), + cfg.StrOpt('syslog-log-facility', + default='LOG_USER', + help='syslog facility to receive log lines') +] + +generic_log_opts = [ + cfg.BoolOpt('use_stderr', + default=True, + help='Log output to standard error') +] + +log_opts = [ + cfg.StrOpt('logging_context_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [%(request_id)s %(user)s %(tenant)s] ' + '%(instance)s%(message)s', + help='format string to use for log messages with context'), + cfg.StrOpt('logging_default_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [-] %(instance)s%(message)s', + help='format string to use for log messages without context'), + cfg.StrOpt('logging_debug_format_suffix', + default='%(funcName)s %(pathname)s:%(lineno)d', + help='data to append to log format when level is DEBUG'), + cfg.StrOpt('logging_exception_prefix', + default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s ' + '%(instance)s', + help='prefix each line of exception output with this format'), + cfg.ListOpt('default_log_levels', + default=[ + 'amqplib=WARN', + 'sqlalchemy=WARN', + 'boto=WARN', + 'suds=INFO', + 'keystone=INFO', + 'eventlet.wsgi.server=WARN' + ], + help='list of logger=LEVEL pairs'), + cfg.BoolOpt('publish_errors', + default=False, + help='publish error events'), + cfg.BoolOpt('fatal_deprecations', + default=False, + help='make deprecations fatal'), + + # NOTE(mikal): there are two options here because sometimes we are handed + # a full instance (and could include more information), and other times we + # are just handed a UUID for the instance. + cfg.StrOpt('instance_format', + default='[instance: %(uuid)s] ', + help='If an instance is passed with the log message, format ' + 'it like this'), + cfg.StrOpt('instance_uuid_format', + default='[instance: %(uuid)s] ', + help='If an instance UUID is passed with the log message, ' + 'format it like this'), +] + +CONF = cfg.CONF +CONF.register_cli_opts(common_cli_opts) +CONF.register_cli_opts(logging_cli_opts) +CONF.register_opts(generic_log_opts) +CONF.register_opts(log_opts) + +# our new audit level +# NOTE(jkoelker) Since we synthesized an audit level, make the logging +# module aware of it so it acts like other levels. +logging.AUDIT = logging.INFO + 1 +logging.addLevelName(logging.AUDIT, 'AUDIT') + + +try: + NullHandler = logging.NullHandler +except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7 + class NullHandler(logging.Handler): + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +def _dictify_context(context): + if context is None: + return None + if not isinstance(context, dict) and getattr(context, 'to_dict', None): + context = context.to_dict() + return context + + +def _get_binary_name(): + return os.path.basename(inspect.stack()[-1][1]) + + +def _get_log_file_path(binary=None): + logfile = CONF.log_file + logdir = CONF.log_dir + + if logfile and not logdir: + return logfile + + if logfile and logdir: + return os.path.join(logdir, logfile) + + if logdir: + binary = binary or _get_binary_name() + return '%s.log' % (os.path.join(logdir, binary),) + + +class BaseLoggerAdapter(logging.LoggerAdapter): + + def audit(self, msg, *args, **kwargs): + self.log(logging.AUDIT, msg, *args, **kwargs) + + +class LazyAdapter(BaseLoggerAdapter): + def __init__(self, name='unknown', version='unknown'): + self._logger = None + self.extra = {} + self.name = name + self.version = version + + @property + def logger(self): + if not self._logger: + self._logger = getLogger(self.name, self.version) + return self._logger + + +class ContextAdapter(BaseLoggerAdapter): + warn = logging.LoggerAdapter.warning + + def __init__(self, logger, project_name, version_string): + self.logger = logger + self.project = project_name + self.version = version_string + + @property + def handlers(self): + return self.logger.handlers + + def deprecated(self, msg, *args, **kwargs): + stdmsg = _("Deprecated: %s") % msg + if CONF.fatal_deprecations: + self.critical(stdmsg, *args, **kwargs) + raise DeprecatedConfig(msg=stdmsg) + else: + self.warn(stdmsg, *args, **kwargs) + + def process(self, msg, kwargs): + if 'extra' not in kwargs: + kwargs['extra'] = {} + extra = kwargs['extra'] + + context = kwargs.pop('context', None) + if not context: + context = getattr(local.store, 'context', None) + if context: + extra.update(_dictify_context(context)) + + instance = kwargs.pop('instance', None) + instance_extra = '' + if instance: + instance_extra = CONF.instance_format % instance + else: + instance_uuid = kwargs.pop('instance_uuid', None) + if instance_uuid: + instance_extra = (CONF.instance_uuid_format + % {'uuid': instance_uuid}) + extra.update({'instance': instance_extra}) + + extra.update({"project": self.project}) + extra.update({"version": self.version}) + extra['extra'] = extra.copy() + return msg, kwargs + + +class JSONFormatter(logging.Formatter): + def __init__(self, fmt=None, datefmt=None): + # NOTE(jkoelker) we ignore the fmt argument, but its still there + # since logging.config.fileConfig passes it. + self.datefmt = datefmt + + def formatException(self, ei, strip_newlines=True): + lines = traceback.format_exception(*ei) + if strip_newlines: + lines = [itertools.ifilter( + lambda x: x, + line.rstrip().splitlines()) for line in lines] + lines = list(itertools.chain(*lines)) + return lines + + def format(self, record): + message = {'message': record.getMessage(), + 'asctime': self.formatTime(record, self.datefmt), + 'name': record.name, + 'msg': record.msg, + 'args': record.args, + 'levelname': record.levelname, + 'levelno': record.levelno, + 'pathname': record.pathname, + 'filename': record.filename, + 'module': record.module, + 'lineno': record.lineno, + 'funcname': record.funcName, + 'created': record.created, + 'msecs': record.msecs, + 'relative_created': record.relativeCreated, + 'thread': record.thread, + 'thread_name': record.threadName, + 'process_name': record.processName, + 'process': record.process, + 'traceback': None} + + if hasattr(record, 'extra'): + message['extra'] = record.extra + + if record.exc_info: + message['traceback'] = self.formatException(record.exc_info) + + return jsonutils.dumps(message) + + +def _create_logging_excepthook(product_name): + def logging_excepthook(type, value, tb): + extra = {} + if CONF.verbose: + extra['exc_info'] = (type, value, tb) + getLogger(product_name).critical(str(value), **extra) + return logging_excepthook + + +class LogConfigError(Exception): + + message = _('Error loading logging config %(log_config)s: %(err_msg)s') + + def __init__(self, log_config, err_msg): + self.log_config = log_config + self.err_msg = err_msg + + def __str__(self): + return self.message % dict(log_config=self.log_config, + err_msg=self.err_msg) + + +def _load_log_config(log_config): + try: + logging.config.fileConfig(log_config) + except moves.configparser.Error as exc: + raise LogConfigError(log_config, str(exc)) + + +def setup(product_name): + """Setup logging.""" + if CONF.log_config: + _load_log_config(CONF.log_config) + else: + _setup_logging_from_conf() + sys.excepthook = _create_logging_excepthook(product_name) + + +def set_defaults(logging_context_format_string): + cfg.set_defaults(log_opts, + logging_context_format_string= + logging_context_format_string) + + +def _find_facility_from_conf(): + facility_names = logging.handlers.SysLogHandler.facility_names + facility = getattr(logging.handlers.SysLogHandler, + CONF.syslog_log_facility, + None) + + if facility is None and CONF.syslog_log_facility in facility_names: + facility = facility_names.get(CONF.syslog_log_facility) + + if facility is None: + valid_facilities = facility_names.keys() + consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON', + 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS', + 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP', + 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3', + 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7'] + valid_facilities.extend(consts) + raise TypeError(_('syslog facility must be one of: %s') % + ', '.join("'%s'" % fac + for fac in valid_facilities)) + + return facility + + +def _setup_logging_from_conf(): + log_root = getLogger(None).logger + for handler in log_root.handlers: + log_root.removeHandler(handler) + + if CONF.use_syslog: + facility = _find_facility_from_conf() + syslog = logging.handlers.SysLogHandler(address='/dev/log', + facility=facility) + log_root.addHandler(syslog) + + logpath = _get_log_file_path() + if logpath: + filelog = logging.handlers.WatchedFileHandler(logpath) + log_root.addHandler(filelog) + + if CONF.use_stderr: + streamlog = ColorHandler() + log_root.addHandler(streamlog) + + elif not CONF.log_file: + # pass sys.stdout as a positional argument + # python2.6 calls the argument strm, in 2.7 it's stream + streamlog = logging.StreamHandler(sys.stdout) + log_root.addHandler(streamlog) + + if CONF.publish_errors: + handler = importutils.import_object( + "keystone.openstack.common.log_handler.PublishErrorsHandler", + logging.ERROR) + log_root.addHandler(handler) + + datefmt = CONF.log_date_format + for handler in log_root.handlers: + # NOTE(alaski): CONF.log_format overrides everything currently. This + # should be deprecated in favor of context aware formatting. + if CONF.log_format: + handler.setFormatter(logging.Formatter(fmt=CONF.log_format, + datefmt=datefmt)) + log_root.info('Deprecated: log_format is now deprecated and will ' + 'be removed in the next release') + else: + handler.setFormatter(ContextFormatter(datefmt=datefmt)) + + if CONF.debug: + log_root.setLevel(logging.DEBUG) + elif CONF.verbose: + log_root.setLevel(logging.INFO) + else: + log_root.setLevel(logging.WARNING) + + for pair in CONF.default_log_levels: + mod, _sep, level_name = pair.partition('=') + level = logging.getLevelName(level_name) + logger = logging.getLogger(mod) + logger.setLevel(level) + +_loggers = {} + + +def getLogger(name='unknown', version='unknown'): + if name not in _loggers: + _loggers[name] = ContextAdapter(logging.getLogger(name), + name, + version) + return _loggers[name] + + +def getLazyLogger(name='unknown', version='unknown'): + """Returns lazy logger. + + Creates a pass-through logger that does not create the real logger + until it is really needed and delegates all calls to the real logger + once it is created. + """ + return LazyAdapter(name, version) + + +class WritableLogger(object): + """A thin wrapper that responds to `write` and logs.""" + + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + + def write(self, msg): + self.logger.log(self.level, msg) + + +class ContextFormatter(logging.Formatter): + """A context.RequestContext aware formatter configured through flags. + + The flags used to set format strings are: logging_context_format_string + and logging_default_format_string. You can also specify + logging_debug_format_suffix to append extra formatting if the log level is + debug. + + For information about what variables are available for the formatter see: + http://docs.python.org/library/logging.html#formatter + + """ + + def format(self, record): + """Uses contextstring if request_id is set, otherwise default.""" + # NOTE(sdague): default the fancier formating params + # to an empty string so we don't throw an exception if + # they get used + for key in ('instance', 'color'): + if key not in record.__dict__: + record.__dict__[key] = '' + + if record.__dict__.get('request_id', None): + self._fmt = CONF.logging_context_format_string + else: + self._fmt = CONF.logging_default_format_string + + if (record.levelno == logging.DEBUG and + CONF.logging_debug_format_suffix): + self._fmt += " " + CONF.logging_debug_format_suffix + + # Cache this on the record, Logger will respect our formated copy + if record.exc_info: + record.exc_text = self.formatException(record.exc_info, record) + return logging.Formatter.format(self, record) + + def formatException(self, exc_info, record=None): + """Format exception output with CONF.logging_exception_prefix.""" + if not record: + return logging.Formatter.formatException(self, exc_info) + + stringbuffer = moves.StringIO() + traceback.print_exception(exc_info[0], exc_info[1], exc_info[2], + None, stringbuffer) + lines = stringbuffer.getvalue().split('\n') + stringbuffer.close() + + if CONF.logging_exception_prefix.find('%(asctime)') != -1: + record.asctime = self.formatTime(record, self.datefmt) + + formatted_lines = [] + for line in lines: + pl = CONF.logging_exception_prefix % record.__dict__ + fl = '%s%s' % (pl, line) + formatted_lines.append(fl) + return '\n'.join(formatted_lines) + + +class ColorHandler(logging.StreamHandler): + LEVEL_COLORS = { + logging.DEBUG: '\033[00;32m', # GREEN + logging.INFO: '\033[00;36m', # CYAN + logging.AUDIT: '\033[01;36m', # BOLD CYAN + logging.WARN: '\033[01;33m', # BOLD YELLOW + logging.ERROR: '\033[01;31m', # BOLD RED + logging.CRITICAL: '\033[01;31m', # BOLD RED + } + + def format(self, record): + record.color = self.LEVEL_COLORS[record.levelno] + return logging.StreamHandler.format(self, record) + + +class DeprecatedConfig(Exception): + message = _("Fatal call to deprecated config: %(msg)s") + + def __init__(self, msg): + super(Exception, self).__init__(self.message % dict(msg=msg)) diff --git a/keystone/openstack/common/timeutils.py b/keystone/openstack/common/timeutils.py index 60943659..aa9f7080 100644 --- a/keystone/openstack/common/timeutils.py +++ b/keystone/openstack/common/timeutils.py @@ -23,6 +23,7 @@ import calendar import datetime import iso8601 +import six # ISO 8601 extended time format with microseconds @@ -32,7 +33,7 @@ PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND def isotime(at=None, subsecond=False): - """Stringify time in ISO 8601 format""" + """Stringify time in ISO 8601 format.""" if not at: at = utcnow() st = at.strftime(_ISO8601_TIME_FORMAT @@ -44,13 +45,13 @@ def isotime(at=None, subsecond=False): def parse_isotime(timestr): - """Parse time from ISO 8601 format""" + """Parse time from ISO 8601 format.""" try: return iso8601.parse_date(timestr) except iso8601.ParseError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) except TypeError as e: - raise ValueError(e.message) + raise ValueError(unicode(e)) def strtime(at=None, fmt=PERFECT_TIME_FORMAT): @@ -66,7 +67,7 @@ def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT): def normalize_time(timestamp): - """Normalize time in arbitrary timezone to UTC naive object""" + """Normalize time in arbitrary timezone to UTC naive object.""" offset = timestamp.utcoffset() if offset is None: return timestamp @@ -75,14 +76,14 @@ def normalize_time(timestamp): def is_older_than(before, seconds): """Return True if before is older than seconds.""" - if isinstance(before, basestring): + if isinstance(before, six.string_types): before = parse_strtime(before).replace(tzinfo=None) return utcnow() - before > datetime.timedelta(seconds=seconds) def is_newer_than(after, seconds): """Return True if after is newer than seconds.""" - if isinstance(after, basestring): + if isinstance(after, six.string_types): after = parse_strtime(after).replace(tzinfo=None) return after - utcnow() > datetime.timedelta(seconds=seconds) @@ -103,7 +104,7 @@ def utcnow(): def iso8601_from_timestamp(timestamp): - """Returns a iso8601 formated date from timestamp""" + """Returns a iso8601 formated date from timestamp.""" return isotime(datetime.datetime.utcfromtimestamp(timestamp)) @@ -111,9 +112,9 @@ utcnow.override_time = None def set_time_override(override_time=datetime.datetime.utcnow()): - """ - Override utils.utcnow to return a constant time or a list thereof, - one at a time. + """Overrides utils.utcnow. + + Make it return a constant time or a list thereof, one at a time. """ utcnow.override_time = override_time @@ -141,7 +142,8 @@ def clear_time_override(): def marshall_now(now=None): """Make an rpc-safe datetime with microseconds. - Note: tzinfo is stripped, but not required for relative times.""" + Note: tzinfo is stripped, but not required for relative times. + """ if not now: now = utcnow() return dict(day=now.day, month=now.month, year=now.year, hour=now.hour, @@ -161,7 +163,8 @@ def unmarshall_time(tyme): def delta_seconds(before, after): - """ + """Return the difference between two timing objects. + Compute the difference in seconds between two date, time, or datetime objects (as a float, to microsecond resolution). """ @@ -174,8 +177,7 @@ def delta_seconds(before, after): def is_soon(dt, window): - """ - Determines if time is going to happen in the next window seconds. + """Determines if time is going to happen in the next window seconds. :params dt: the time :params window: minimum seconds to remain to consider the time not soon diff --git a/keystone/token/backends/sql.py b/keystone/token/backends/sql.py index 0e8a916d..82eab651 100644 --- a/keystone/token/backends/sql.py +++ b/keystone/token/backends/sql.py @@ -17,7 +17,6 @@ import copy import datetime - from keystone.common import sql from keystone import exception from keystone.openstack.common import timeutils @@ -30,9 +29,13 @@ class TokenModel(sql.ModelBase, sql.DictBase): id = sql.Column(sql.String(64), primary_key=True) expires = sql.Column(sql.DateTime(), default=None) extra = sql.Column(sql.JsonBlob()) - valid = sql.Column(sql.Boolean(), default=True) + valid = sql.Column(sql.Boolean(), default=True, nullable=False) user_id = sql.Column(sql.String(64)) - trust_id = sql.Column(sql.String(64), nullable=True) + trust_id = sql.Column(sql.String(64)) + __table_args__ = ( + sql.Index('ix_token_expires', 'expires'), + sql.Index('ix_token_valid', 'valid') + ) class Token(sql.Base, token.Driver): diff --git a/keystone/token/controllers.py b/keystone/token/controllers.py index 9ebc29fe..91514493 100644 --- a/keystone/token/controllers.py +++ b/keystone/token/controllers.py @@ -4,7 +4,6 @@ from keystone.common import cms from keystone.common import controller from keystone.common import dependency from keystone.common import logging -from keystone.common import utils from keystone.common import wsgi from keystone import config from keystone import exception @@ -215,10 +214,9 @@ class Auth(controller.V2Controller): attribute='password', target='passwordCredentials') password = auth['passwordCredentials']['password'] - max_pw_size = utils.MAX_PASSWORD_LENGTH - if password and len(password) > max_pw_size: - raise exception.ValidationSizeError(attribute='password', - size=max_pw_size) + if password and len(password) > CONF.identity.max_password_length: + raise exception.ValidationSizeError( + attribute='password', size=CONF.identity.max_password_length) if ("userId" not in auth['passwordCredentials'] and "username" not in auth['passwordCredentials']): diff --git a/keystone/trust/backends/sql.py b/keystone/trust/backends/sql.py index daa8e3f7..9e92ad71 100644 --- a/keystone/trust/backends/sql.py +++ b/keystone/trust/backends/sql.py @@ -26,11 +26,11 @@ class TrustModel(sql.ModelBase, sql.DictBase): 'project_id', 'impersonation', 'expires_at'] id = sql.Column(sql.String(64), primary_key=True) #user id Of owner - trustor_user_id = sql.Column(sql.String(64), unique=False, nullable=False,) + trustor_user_id = sql.Column(sql.String(64), nullable=False,) #user_id of user allowed to consume this preauth - trustee_user_id = sql.Column(sql.String(64), unique=False, nullable=False) - project_id = sql.Column(sql.String(64), unique=False, nullable=True) - impersonation = sql.Column(sql.Boolean) + trustee_user_id = sql.Column(sql.String(64), nullable=False) + project_id = sql.Column(sql.String(64)) + impersonation = sql.Column(sql.Boolean, nullable=False) deleted_at = sql.Column(sql.DateTime) expires_at = sql.Column(sql.DateTime) extra = sql.Column(sql.JsonBlob()) diff --git a/requirements.txt b/requirements.txt index e54bb6a0..f7161d2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ pam>=0.1.4 WebOb>=1.0.8 eventlet greenlet +netaddr PasteDeploy paste routes diff --git a/tests/test_auth.py b/tests/test_auth.py index db5314be..e8e6c7a9 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -179,7 +179,8 @@ class AuthBadRequests(AuthTest): def test_authenticate_password_too_large(self): """Verify sending large 'password' raises the right exception.""" - body_dict = _build_user_auth(username='FOO', password='0' * 8193) + length = CONF.identity.max_password_length + 1 + body_dict = _build_user_auth(username='FOO', password='0' * length) self.assertRaises(exception.ValidationSizeError, self.controller.authenticate, {}, body_dict) diff --git a/tests/test_backend.py b/tests/test_backend.py index 7e4d820e..75a94773 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1453,6 +1453,18 @@ class IdentityTests(object): tenants = self.identity_api.get_projects_for_user(self.user_foo['id']) self.assertIn(self.tenant_baz['id'], tenants) + def test_add_user_to_project_missing_default_role(self): + self.assignment_api.delete_role(CONF.member_role_id) + self.assertRaises(exception.RoleNotFound, + self.assignment_api.get_role, + CONF.member_role_id) + self.identity_api.add_user_to_project(self.tenant_baz['id'], + self.user_foo['id']) + tenants = self.identity_api.get_projects_for_user(self.user_foo['id']) + self.assertIn(self.tenant_baz['id'], tenants) + default_role = self.assignment_api.get_role(CONF.member_role_id) + self.assertIsNotNone(default_role) + def test_add_user_to_project_404(self): self.assertRaises(exception.ProjectNotFound, self.identity_api.add_user_to_project, diff --git a/tests/test_sql_migrate_extensions.py b/tests/test_sql_migrate_extensions.py new file mode 100644 index 00000000..4a529559 --- /dev/null +++ b/tests/test_sql_migrate_extensions.py @@ -0,0 +1,47 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +To run these tests against a live database: +1. Modify the file `tests/backend_sql.conf` to use the connection for your + live database +2. Set up a blank, live database. +3. run the tests using + ./run_tests.sh -N test_sql_upgrade + WARNING:: + Your database will be wiped. + Do not do this against a Database with valuable data as + all data will be lost. +""" + +from keystone.contrib import example + +import test_sql_upgrade + + +class SqlUpgradeExampleExtension(test_sql_upgrade.SqlMigrateBase): + def repo_package(self): + return example + + def test_upgrade(self): + self.assertTableDoesNotExist('example') + self.upgrade(1, repository=self.repo_path) + self.assertTableColumns('example', ['id', 'type', 'extra']) + + def test_downgrade(self): + self.upgrade(1, repository=self.repo_path) + self.assertTableColumns('example', ['id', 'type', 'extra']) + self.downgrade(0, repository=self.repo_path) + self.assertTableDoesNotExist('example') diff --git a/tests/test_sql_upgrade.py b/tests/test_sql_upgrade.py index cf82b814..9540c4cd 100644 --- a/tests/test_sql_upgrade.py +++ b/tests/test_sql_upgrade.py @@ -45,8 +45,7 @@ CONF = config.CONF DEFAULT_DOMAIN_ID = CONF.identity.default_domain_id -class SqlUpgradeTests(test.TestCase): - +class SqlMigrateBase(test.TestCase): def initialize_sql(self): self.metadata = sqlalchemy.MetaData() self.metadata.bind = self.engine @@ -55,12 +54,15 @@ class SqlUpgradeTests(test.TestCase): test.testsdir('test_overrides.conf'), test.testsdir('backend_sql.conf')] - #override this to sepcify the complete list of configuration files + #override this to specify the complete list of configuration files def config_files(self): return self._config_file_list + def repo_package(self): + return None + def setUp(self): - super(SqlUpgradeTests, self).setUp() + super(SqlMigrateBase, self).setUp() self.config(self.config_files()) self.base = sql.Base() @@ -71,7 +73,7 @@ class SqlUpgradeTests(test.TestCase): autocommit=False) self.initialize_sql() - self.repo_path = migration._find_migrate_repo() + self.repo_path = migration.find_migrate_repo(self.repo_package()) self.schema = versioning_api.ControlledSchema.create( self.engine, self.repo_path, 0) @@ -85,7 +87,64 @@ class SqlUpgradeTests(test.TestCase): autoload=True) self.downgrade(0) table.drop(self.engine, checkfirst=True) - super(SqlUpgradeTests, self).tearDown() + super(SqlMigrateBase, self).tearDown() + + def select_table(self, name): + table = sqlalchemy.Table(name, + self.metadata, + autoload=True) + s = sqlalchemy.select([table]) + return s + + def assertTableExists(self, table_name): + try: + self.select_table(table_name) + except sqlalchemy.exc.NoSuchTableError: + raise AssertionError('Table "%s" does not exist' % table_name) + + def assertTableDoesNotExist(self, table_name): + """Asserts that a given table exists cannot be selected by name.""" + # Switch to a different metadata otherwise you might still + # detect renamed or dropped tables + try: + temp_metadata = sqlalchemy.MetaData() + temp_metadata.bind = self.engine + sqlalchemy.Table(table_name, temp_metadata, autoload=True) + except sqlalchemy.exc.NoSuchTableError: + pass + else: + raise AssertionError('Table "%s" already exists' % table_name) + + def upgrade(self, *args, **kwargs): + self._migrate(*args, **kwargs) + + def downgrade(self, *args, **kwargs): + self._migrate(*args, downgrade=True, **kwargs) + + def _migrate(self, version, repository=None, downgrade=False, + current_schema=None): + repository = repository or self.repo_path + err = '' + version = versioning_api._migrate_version(self.schema, + version, + not downgrade, + err) + if not current_schema: + current_schema = self.schema + changeset = current_schema.changeset(version) + for ver, change in changeset: + self.schema.runchange(ver, change, changeset.step) + self.assertEqual(self.schema.version, version) + + def assertTableColumns(self, table_name, expected_cols): + """Asserts that the table contains the expected set of columns.""" + self.initialize_sql() + table = self.select_table(table_name) + actual_cols = [col.name for col in table.columns] + self.assertEqual(expected_cols, actual_cols, '%s table' % table_name) + + +class SqlUpgradeTests(SqlMigrateBase): def test_blank_db_to_start(self): self.assertTableDoesNotExist('user') @@ -108,13 +167,6 @@ class SqlUpgradeTests(test.TestCase): self.downgrade(x - 1) self.upgrade(x) - def assertTableColumns(self, table_name, expected_cols): - """Asserts that the table contains the expected set of columns.""" - self.initialize_sql() - table = self.select_table(table_name) - actual_cols = [col.name for col in table.columns] - self.assertEqual(expected_cols, actual_cols, '%s table' % table_name) - def test_upgrade_add_initial_tables(self): self.upgrade(1) self.assertTableColumns("user", ["id", "name", "extra"]) @@ -1186,6 +1238,24 @@ class SqlUpgradeTests(test.TestCase): self.assertEqual(cred.user_id, ec2_credential['user_id']) + def test_drop_credential_indexes(self): + self.upgrade(31) + table = sqlalchemy.Table('credential', self.metadata, autoload=True) + self.assertEqual(len(table.indexes), 0) + + def test_downgrade_30(self): + self.upgrade(31) + self.downgrade(30) + table = sqlalchemy.Table('credential', self.metadata, autoload=True) + index_data = [(idx.name, idx.columns.keys()) + for idx in table.indexes] + if self.engine.name == 'mysql': + self.assertIn(('user_id', ['user_id']), index_data) + self.assertIn(('credential_project_id_fkey', ['project_id']), + index_data) + else: + self.assertEqual(len(index_data), 0) + def populate_user_table(self, with_pass_enab=False, with_pass_enab_domain=False): # Populate the appropriate fields in the user @@ -1284,50 +1354,6 @@ class SqlUpgradeTests(test.TestCase): 'extra': json.dumps(extra)}) self.engine.execute(ins) - def select_table(self, name): - table = sqlalchemy.Table(name, - self.metadata, - autoload=True) - s = sqlalchemy.select([table]) - return s - - def assertTableExists(self, table_name): - try: - self.select_table(table_name) - except sqlalchemy.exc.NoSuchTableError: - raise AssertionError('Table "%s" does not exist' % table_name) - - def assertTableDoesNotExist(self, table_name): - """Asserts that a given table exists cannot be selected by name.""" - # Switch to a different metadata otherwise you might still - # detect renamed or dropped tables - try: - temp_metadata = sqlalchemy.MetaData() - temp_metadata.bind = self.engine - sqlalchemy.Table(table_name, temp_metadata, autoload=True) - except sqlalchemy.exc.NoSuchTableError: - pass - else: - raise AssertionError('Table "%s" already exists' % table_name) - - def upgrade(self, *args, **kwargs): - self._migrate(*args, **kwargs) - - def downgrade(self, *args, **kwargs): - self._migrate(*args, downgrade=True, **kwargs) - - def _migrate(self, version, repository=None, downgrade=False): - repository = repository or self.repo_path - err = '' - version = versioning_api._migrate_version(self.schema, - version, - not downgrade, - err) - changeset = self.schema.changeset(version) - for ver, change in changeset: - self.schema.runchange(ver, change, changeset.step) - self.assertEqual(self.schema.version, version) - def _mysql_check_all_tables_innodb(self): database = self.engine.url.database diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 003f7571..362df922 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -37,37 +37,6 @@ class BaseWSGITest(test.TestCase): req.environ['wsgiorg.routing_args'] = [None, args] return req - def test_mask_password(self): - message = ("test = 'password': 'aaaaaa', 'param1': 'value1', " - "\"new_password\": 'bbbbbb'") - self.assertEqual(wsgi.mask_password(message, True), - u"test = 'password': '***', 'param1': 'value1', " - "\"new_password\": '***'") - - message = "test = 'password' : 'aaaaaa'" - self.assertEqual(wsgi.mask_password(message, False, '111'), - "test = 'password' : '111'") - - message = u"test = u'password' : u'aaaaaa'" - self.assertEqual(wsgi.mask_password(message, True), - u"test = u'password' : u'***'") - - message = 'test = "password" : "aaaaaaaaa"' - self.assertEqual(wsgi.mask_password(message), - 'test = "password" : "***"') - - message = 'test = "original_password" : "aaaaaaaaa"' - self.assertEqual(wsgi.mask_password(message), - 'test = "original_password" : "***"') - - message = 'test = "original_password" : ""' - self.assertEqual(wsgi.mask_password(message), - 'test = "original_password" : "***"') - - message = 'test = "param1" : "value"' - self.assertEqual(wsgi.mask_password(message), - 'test = "param1" : "value"') - class ApplicationTest(BaseWSGITest): def test_response_content_type(self): @@ -210,3 +179,36 @@ class MiddlewareTest(BaseWSGITest): app = factory(self.app) self.assertIn("testkey", app.kwargs) self.assertEquals("test", app.kwargs["testkey"]) + + +class WSGIFunctionTest(test.TestCase): + def test_mask_password(self): + message = ("test = 'password': 'aaaaaa', 'param1': 'value1', " + "\"new_password\": 'bbbbbb'") + self.assertEqual(wsgi.mask_password(message, True), + u"test = 'password': '***', 'param1': 'value1', " + "\"new_password\": '***'") + + message = "test = 'password' : 'aaaaaa'" + self.assertEqual(wsgi.mask_password(message, False, '111'), + "test = 'password' : '111'") + + message = u"test = u'password' : u'aaaaaa'" + self.assertEqual(wsgi.mask_password(message, True), + u"test = u'password' : u'***'") + + message = 'test = "password" : "aaaaaaaaa"' + self.assertEqual(wsgi.mask_password(message), + 'test = "password" : "***"') + + message = 'test = "original_password" : "aaaaaaaaa"' + self.assertEqual(wsgi.mask_password(message), + 'test = "original_password" : "***"') + + message = 'test = "original_password" : ""' + self.assertEqual(wsgi.mask_password(message), + 'test = "original_password" : "***"') + + message = 'test = "param1" : "value"' + self.assertEqual(wsgi.mask_password(message), + 'test = "param1" : "value"') |