diff options
| author | Jenkins <jenkins@review.openstack.org> | 2012-08-13 16:17:08 +0000 |
|---|---|---|
| committer | Gerrit Code Review <review@openstack.org> | 2012-08-13 16:17:08 +0000 |
| commit | dafa79ff7880af907b5836d832e6c5dbfd7e0529 (patch) | |
| tree | 2ed64a97450761726b0fc722f4e9c06b7aa40fb0 /nova | |
| parent | 2db23f6ba4db8fe95d05874dc3c6bef017a412ae (diff) | |
| parent | 56a6fa2bb68a8738f8d02f41f92983ee3115a19d (diff) | |
Merge "Move results filtering to db."
Diffstat (limited to 'nova')
| -rw-r--r-- | nova/compute/api.py | 1 | ||||
| -rw-r--r-- | nova/db/sqlalchemy/api.py | 91 | ||||
| -rw-r--r-- | nova/db/sqlalchemy/session.py | 12 | ||||
| -rw-r--r-- | nova/tests/compute/test_compute.py | 35 | ||||
| -rw-r--r-- | nova/tests/test_db_api.py | 28 |
5 files changed, 87 insertions, 80 deletions
diff --git a/nova/compute/api.py b/nova/compute/api.py index 1e1c82a8f..cd3ffdf56 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -1082,7 +1082,6 @@ class API(base.Base): filter_mapping = { 'image': 'image_ref', 'name': 'display_name', - 'instance_name': 'name', 'tenant_id': 'project_id', 'flavor': _remap_flavor_filter, 'fixed_ip': _remap_fixed_ip_filter} diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index ba34a08d8..c54388f07 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -23,7 +23,6 @@ import collections import copy import datetime import functools -import re import warnings from nova import block_device @@ -44,7 +43,6 @@ from sqlalchemy.orm import joinedload_all from sqlalchemy.sql.expression import asc from sqlalchemy.sql.expression import desc from sqlalchemy.sql.expression import literal_column -from sqlalchemy.sql.expression import or_ from sqlalchemy.sql import func FLAGS = flags.FLAGS @@ -247,7 +245,19 @@ def exact_filter(query, model, filters, legal_keys): # OK, filtering on this key; what value do we search for? value = filters.pop(key) - if isinstance(value, (list, tuple, set, frozenset)): + if key == 'metadata': + column_attr = getattr(model, key) + if isinstance(value, list): + for item in value: + for k, v in item.iteritems(): + query = query.filter(column_attr.any(key=k)) + query = query.filter(column_attr.any(value=v)) + + else: + for k, v in value.iteritems(): + query = query.filter(column_attr.any(key=k)) + query = query.filter(column_attr.any(value=v)) + elif isinstance(value, (list, tuple, set, frozenset)): # Looking for values in a list; apply to query directly column_attr = getattr(model, key) query = query.filter(column_attr.in_(value)) @@ -1517,28 +1527,6 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir): will be returned by default, unless there's a filter that says otherwise""" - def _regexp_filter_by_metadata(instance, meta): - inst_metadata = [{node['key']: node['value']} - for node in instance['metadata']] - if isinstance(meta, list): - for node in meta: - if node not in inst_metadata: - return False - elif isinstance(meta, dict): - for k, v in meta.iteritems(): - if {k: v} not in inst_metadata: - return False - return True - - def _regexp_filter_by_column(instance, filter_name, filter_re): - try: - v = getattr(instance, filter_name) - except AttributeError: - return True - if v and filter_re.match(unicode(v)): - return True - return False - sort_fn = {'desc': desc, 'asc': asc} session = get_session() @@ -1580,37 +1568,46 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir): # Filters for exact matches that we can do along with the SQL query... # For other filters that don't match this, we will do regexp matching exact_match_filter_names = ['project_id', 'user_id', 'image_ref', - 'vm_state', 'instance_type_id', 'uuid'] + 'vm_state', 'instance_type_id', 'uuid', + 'metadata'] # Filter the query query_prefix = exact_filter(query_prefix, models.Instance, filters, exact_match_filter_names) + query_prefix = regex_filter(query_prefix, models.Instance, filters) instances = query_prefix.all() - if not instances: - return [] + return instances - # Now filter on everything else for regexp matching.. - # For filters not in the list, we'll attempt to use the filter_name - # as a column name in Instance.. - regexp_filter_funcs = {} - for filter_name in filters.iterkeys(): - filter_func = regexp_filter_funcs.get(filter_name, None) - filter_re = re.compile(str(filters[filter_name])) - if filter_func: - filter_l = lambda instance: filter_func(instance, filter_re) - elif filter_name == 'metadata': - filter_l = lambda instance: _regexp_filter_by_metadata(instance, - filters[filter_name]) - else: - filter_l = lambda instance: _regexp_filter_by_column(instance, - filter_name, filter_re) - instances = filter(filter_l, instances) - if not instances: - break +def regex_filter(query, model, filters): + """Applies regular expression filtering to a query. - return instances + Returns the updated query. + + :param query: query to apply filters to + :param model: model object the query applies to + :param filters: dictionary of filters with regex values + """ + + regexp_op_map = { + 'postgresql': '~', + 'mysql': 'REGEXP', + 'oracle': 'REGEXP_LIKE', + 'sqlite': 'REGEXP' + } + db_string = FLAGS.sql_connection.split(':')[0].split('+')[0] + db_regexp_op = regexp_op_map.get(db_string, 'LIKE') + for filter_name in filters.iterkeys(): + try: + column_attr = getattr(model, filter_name) + except AttributeError: + continue + if 'property' == type(column_attr).__name__: + continue + query = query.filter(column_attr.op(db_regexp_op)( + str(filters[filter_name]))) + return query @require_context diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 6aa5050e4..cada9d79a 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -18,6 +18,7 @@ """Session Handling for SQLAlchemy backend.""" +import re import time from sqlalchemy.exc import DisconnectionError, OperationalError @@ -85,6 +86,16 @@ def is_db_connection_error(args): return False +def regexp(expr, item): + reg = re.compile(expr) + return reg.search(unicode(item)) is not None + + +class AddRegexFactory(sqlalchemy.interfaces.PoolListener): + def connect(delf, dbapi_con, con_record): + dbapi_con.create_function('REGEXP', 2, regexp) + + def get_engine(): """Return a SQLAlchemy engine.""" global _ENGINE @@ -109,6 +120,7 @@ def get_engine(): if FLAGS.sql_connection == "sqlite://": engine_args["poolclass"] = StaticPool engine_args["connect_args"] = {'check_same_thread': False} + engine_args['listeners'] = [AddRegexFactory()] _ENGINE = sqlalchemy.create_engine(FLAGS.sql_connection, **engine_args) diff --git a/nova/tests/compute/test_compute.py b/nova/tests/compute/test_compute.py index da3f8c1e1..e54fd44cf 100644 --- a/nova/tests/compute/test_compute.py +++ b/nova/tests/compute/test_compute.py @@ -3206,14 +3206,14 @@ class ComputeAPITestCase(BaseTestCase): 'display_name': 'not-woot'}) instances = self.compute_api.get_all(c, - search_opts={'name': 'woo.*'}) + search_opts={'name': '^woo.*'}) self.assertEqual(len(instances), 2) instance_uuids = [instance['uuid'] for instance in instances] self.assertTrue(instance1['uuid'] in instance_uuids) self.assertTrue(instance2['uuid'] in instance_uuids) instances = self.compute_api.get_all(c, - search_opts={'name': 'woot.*'}) + search_opts={'name': '^woot.*'}) instance_uuids = [instance['uuid'] for instance in instances] self.assertEqual(len(instances), 1) self.assertTrue(instance1['uuid'] in instance_uuids) @@ -3226,7 +3226,7 @@ class ComputeAPITestCase(BaseTestCase): self.assertTrue(instance3['uuid'] in instance_uuids) instances = self.compute_api.get_all(c, - search_opts={'name': 'n.*'}) + search_opts={'name': '^n.*'}) self.assertEqual(len(instances), 1) instance_uuids = [instance['uuid'] for instance in instances] self.assertTrue(instance3['uuid'] in instance_uuids) @@ -3239,35 +3239,6 @@ class ComputeAPITestCase(BaseTestCase): db.instance_destroy(c, instance2['uuid']) db.instance_destroy(c, instance3['uuid']) - def test_get_all_by_instance_name_regexp(self): - """Test searching instances by name""" - self.flags(instance_name_template='instance-%d') - - c = context.get_admin_context() - instance1 = self._create_fake_instance() - instance2 = self._create_fake_instance({'id': 2}) - instance3 = self._create_fake_instance({'id': 10}) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': 'instance.*'}) - self.assertEqual(len(instances), 3) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': '.*\-\d$'}) - self.assertEqual(len(instances), 2) - instance_uuids = [instance['uuid'] for instance in instances] - self.assertTrue(instance1['uuid'] in instance_uuids) - self.assertTrue(instance2['uuid'] in instance_uuids) - - instances = self.compute_api.get_all(c, - search_opts={'instance_name': 'i.*2'}) - self.assertEqual(len(instances), 1) - self.assertEqual(instances[0]['uuid'], instance2['uuid']) - - db.instance_destroy(c, instance1['uuid']) - db.instance_destroy(c, instance2['uuid']) - db.instance_destroy(c, instance3['uuid']) - def test_get_all_by_multiple_options_at_once(self): """Test searching by multiple options at once""" c = context.get_admin_context() diff --git a/nova/tests/test_db_api.py b/nova/tests/test_db_api.py index b664f54e4..9acc964cf 100644 --- a/nova/tests/test_db_api.py +++ b/nova/tests/test_db_api.py @@ -66,6 +66,34 @@ class DbApiTestCase(test.TestCase): result = db.instance_get_all_by_filters(self.context, {}) self.assertEqual(2, len(result)) + def test_instance_get_all_by_filters_regex(self): + self.create_instances_with_args(display_name='test1') + self.create_instances_with_args(display_name='teeeest2') + self.create_instances_with_args(display_name='diff') + result = db.instance_get_all_by_filters(self.context, + {'display_name': 't.*st.'}) + self.assertEqual(2, len(result)) + + def test_instance_get_all_by_filters_regex_unsupported_db(self): + """Ensure that the 'LIKE' operator is used for unsupported dbs.""" + self.flags(sql_connection="notdb://") + self.create_instances_with_args(display_name='test1') + self.create_instances_with_args(display_name='test.*') + self.create_instances_with_args(display_name='diff') + result = db.instance_get_all_by_filters(self.context, + {'display_name': 'test.*'}) + self.assertEqual(1, len(result)) + result = db.instance_get_all_by_filters(self.context, + {'display_name': '%test%'}) + self.assertEqual(2, len(result)) + + def test_instance_get_all_by_filters_metadata(self): + self.create_instances_with_args(metadata={'foo': 'bar'}) + self.create_instances_with_args() + result = db.instance_get_all_by_filters(self.context, + {'metadata': {'foo': 'bar'}}) + self.assertEqual(1, len(result)) + def test_instance_get_all_by_filters_unicode_value(self): self.create_instances_with_args(display_name=u'test♥') result = db.instance_get_all_by_filters(self.context, |
