diff options
| -rw-r--r-- | nova/tests/integrated/test_api_samples.py | 121 |
1 files changed, 80 insertions, 41 deletions
diff --git a/nova/tests/integrated/test_api_samples.py b/nova/tests/integrated/test_api_samples.py index a15573252..89c16c6fb 100644 --- a/nova/tests/integrated/test_api_samples.py +++ b/nova/tests/integrated/test_api_samples.py @@ -21,6 +21,7 @@ from lxml import etree from nova import flags from nova.openstack.common import jsonutils from nova.openstack.common.log import logging +from nova import test from nova.tests import fake_network from nova.tests.image import fake from nova.tests.integrated import integrated_helpers @@ -29,6 +30,10 @@ FLAGS = flags.FLAGS LOG = logging.getLogger(__name__) +class NoMatch(test.TestingException): + pass + + class ApiSampleTestBase(integrated_helpers._IntegratedTestBase): ctype = 'json' all_extensions = False @@ -44,23 +49,37 @@ class ApiSampleTestBase(integrated_helpers._IntegratedTestBase): fake_network.stub_compute_with_ips(self.stubs) self.generate_samples = os.getenv('GENERATE_SAMPLES') is not None - def _pretty_data(self, data, strip_text=True): + def _pretty_data(self, data): if self.ctype == 'json': data = jsonutils.dumps(jsonutils.loads(data), sort_keys=True, indent=4) else: xml = etree.XML(data) - # NOTE(vish): strip newlines from text blobs for matching - if strip_text: - for text in xml.xpath('//text()'): - parent = text.getparent() - parent.text = parent.text.replace('\n', '') - data = etree.tostring(xml, encoding="UTF-8", xml_declaration=True, pretty_print=True) return '\n'.join(line.rstrip() for line in data.split('\n')).strip() + def _objectify(self, data): + if self.ctype == 'json': + return jsonutils.loads(data) + else: + def to_dict(node): + ret = {} + if node.items(): + ret.update(dict(node.items())) + if node.text: + ret['__content__'] = node.text + if node.tag: + ret['__tag__'] = node.tag + if node.nsmap: + ret['__nsmap__'] = node.nsmap + for element in node: + ret.setdefault(node.tag, []) + ret[node.tag].append(to_dict(element)) + return ret + return to_dict(etree.fromstring(data)) + @classmethod def _get_sample(cls, name, suffix=''): parts = [os.path.dirname(os.path.abspath(__file__))] @@ -72,49 +91,69 @@ class ApiSampleTestBase(integrated_helpers._IntegratedTestBase): parts.append(name + "." + cls.ctype + suffix) return os.path.join(*parts) - @classmethod - def _get_template(cls, name): - return cls._get_sample(name, suffix='.tpl') - def _read_template(self, name): - with open(self._get_template(name)) as inf: + template = self._get_sample(name, suffix='.tpl') + if self.generate_samples and not os.path.exists(template): + with open(template, 'w') as outf: + pass + with open(template) as inf: return inf.read().strip() def _write_sample(self, name, data): with open(self._get_sample(name), 'w') as outf: - outf.write(self._pretty_data(data, False)) - - def _verify_response(self, name, subs, response): - expected = self._read_template(name) - - # NOTE(vish): escape stuff for regex - for char in ['[', ']', '<', '>', '?']: - expected = expected.replace(char, '\%s' % char) - - expected = expected % subs - data = response.read() - result = self._pretty_data(data) - if self.generate_samples: - self._write_sample(name, data) - result_lines = result.split('\n') - expected_lines = expected.split('\n') - if len(result_lines) != len(expected_lines): - LOG.info(expected) - LOG.info(result) - self.fail('Number of lines (%s) incorrect' % (len(expected_lines))) - result = None - for line, result_line in zip(expected_lines, result_lines): + outf.write(data) + + def _compare_result(self, subs, expected, result): + matched_value = None + if isinstance(expected, dict): + if not isinstance(result, dict): + raise NoMatch( + _('Result: %(result)s is not a dict.') % locals()) + ex_keys = sorted(expected.keys()) + res_keys = sorted(result.keys()) + if ex_keys != res_keys: + raise NoMatch(_('Key mismatch:\n' + '%(ex_keys)s\n%(res_keys)s') % locals()) + for key in ex_keys: + res = self._compare_result(subs, expected[key], result[key]) + matched_value = res or matched_value + elif isinstance(expected, list): + if not isinstance(result, list): + raise NoMatch( + _('Result: %(result)s is not a list.') % locals()) + for ex_obj, res_obj in zip(sorted(expected), sorted(result)): + res = self._compare_result(subs, ex_obj, res_obj) + matched_value = res or matched_value + + elif isinstance(expected, basestring) and '%' in expected: try: - match = re.match(line, result_line) + # NOTE(vish): escape stuff for regex + for char in ['[', ']', '<', '>', '?']: + expected = expected.replace(char, '\%s' % char) + expected = expected % subs + match = re.match(expected, result) except Exception as exc: - self.fail(_('Response error on line:\n' - '%(line)s\n%(result_line)s') % locals()) + raise NoMatch(_('Values do not match:\n' + '%(expected)s\n%(result)s') % locals()) if not match: - self.fail(_('Response error on line:\n' - '%(line)s\n%(result_line)s') % locals()) + raise NoMatch(_('Values do not match:\n' + '%(expected)s\n%(result)s') % locals()) if match.groups(): - result = match.groups()[0] - return result + matched_value = match.groups()[0] + else: + if expected != result: + raise NoMatch(_('Values do not match:\n' + '%(expected)s\n%(result)s') % locals()) + return matched_value + + def _verify_response(self, name, subs, response): + expected = self._read_template(name) + expected = self._objectify(expected) + result = self._pretty_data(response.read()) + if self.generate_samples: + self._write_sample(name, result) + result = self._objectify(result) + return self._compare_result(subs, expected, result) def _get_host(self): return 'http://openstack.example.com' |
