From de5b1ce17a44e824f1f29ead19dac45db4e0086c Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Mon, 30 Aug 2010 15:11:46 -0700 Subject: all tests pass again --- nova/db/api.py | 19 ++++--- nova/db/sqlalchemy/api.py | 121 ++++++++++++++++++++++++++--------------- nova/db/sqlalchemy/models.py | 40 ++++++++------ nova/db/sqlalchemy/session.py | 9 ++- nova/endpoint/cloud.py | 4 +- nova/tests/compute_unittest.py | 30 +++++----- nova/tests/network_unittest.py | 7 +-- nova/tests/volume_unittest.py | 7 ++- 8 files changed, 144 insertions(+), 93 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 91d7b8415..9b8c48934 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -108,10 +108,15 @@ def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): def floating_ip_get_by_address(context, address): - """Get a floating ip by address.""" + """Get a floating ip by address or raise if it doesn't exist.""" return _impl.floating_ip_get_by_address(context, address) +def floating_ip_get_instance(context, address): + """Get an instance for a floating ip by address.""" + return _impl.floating_ip_get_instance(context, address) + + #################### @@ -134,10 +139,15 @@ def fixed_ip_deallocate(context, address): def fixed_ip_get_by_address(context, address): - """Get a fixed ip by address.""" + """Get a fixed ip by address or raise if it does not exist.""" return _impl.fixed_ip_get_by_address(context, address) +def fixed_ip_get_instance(context, address): + """Get an instance for a fixed ip by address.""" + return _impl.fixed_ip_get_instance(context, address) + + def fixed_ip_get_network(context, address): """Get a network for a fixed ip by address.""" return _impl.fixed_ip_get_network(context, address) @@ -181,11 +191,6 @@ def instance_get_all(context): return _impl.instance_get_all(context) -def instance_get_by_address(context, address): - """Gets an instance by fixed ip address or raise if it does not exist.""" - return _impl.instance_get_by_address(context, address) - - def instance_get_by_project(context, project_id): """Get all instance belonging to a project.""" return _impl.instance_get_by_project(context, project_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index cef77cc50..a4b0ba545 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -79,30 +79,50 @@ def floating_ip_create(context, address, host): def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): - floating_ip_ref = db.floating_ip_get_by_address(context, floating_address) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address) - floating_ip_ref.fixed_ip = fixed_ip_ref - floating_ip_ref.save() + with managed_session(autocommit=False) as session: + floating_ip_ref = models.FloatingIp.find_by_str(floating_address, + session=session) + fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, + session=session) + floating_ip_ref.fixed_ip = fixed_ip_ref + floating_ip_ref.save(session=session) + session.commit() def floating_ip_disassociate(context, address): - floating_ip_ref = db.floating_ip_get_by_address(context, address) - fixed_ip_address = floating_ip_ref.fixed_ip['str_id'] - floating_ip_ref['fixed_ip'] = None - floating_ip_ref.save() - return fixed_ip_address + with managed_session(autocommit=False) as session: + floating_ip_ref = models.FloatingIp.find_by_str(address, + session=session) + fixed_ip_ref = floating_ip_ref.fixed_ip + if fixed_ip_ref: + fixed_ip_address = fixed_ip_ref['str_id'] + else: + fixed_ip_address = None + floating_ip_ref.fixed_ip = None + floating_ip_ref.save(session=session) + session.commit() + return fixed_ip_address def floating_ip_deallocate(context, address): - floating_ip_ref = db.floating_ip_get_by_address(context, address) - floating_ip_ref['project_id'] = None - floating_ip_ref.save() + with managed_session(autocommit=False) as session: + floating_ip_ref = models.FloatingIp.find_by_str(address, + session=session) + floating_ip_ref['project_id'] = None + floating_ip_ref.save(session=session) def floating_ip_get_by_address(context, address): return models.FloatingIp.find_by_str(address) +def floating_ip_get_instance(context, address): + with managed_session() as session: + floating_ip_ref = models.FloatingIp.find_by_str(address, + session=session) + return floating_ip_ref.fixed_ip.instance + + ################### @@ -139,8 +159,14 @@ def fixed_ip_get_by_address(context, address): return models.FixedIp.find_by_str(address) +def fixed_ip_get_instance(context, address): + with managed_session() as session: + return models.FixedIp.find_by_str(address, session=session).instance + + def fixed_ip_get_network(context, address): - return models.FixedIp.find_by_str(address).network + with managed_session() as session: + return models.FixedIp.find_by_str(address, session=session).network def fixed_ip_deallocate(context, address): @@ -150,15 +176,20 @@ def fixed_ip_deallocate(context, address): def fixed_ip_instance_associate(context, address, instance_id): - fixed_ip_ref = fixed_ip_get_by_address(context, address) - fixed_ip_ref.instance = instance_get(context, instance_id) - fixed_ip_ref.save() + with managed_session(autocommit=False) as session: + fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + instance_ref = models.Instance.find(instance_id, session=session) + fixed_ip_ref.instance = instance_ref + fixed_ip_ref.save(session=session) + session.commit() def fixed_ip_instance_disassociate(context, address): - fixed_ip_ref = fixed_ip_get_by_address(context, address) - fixed_ip_ref.instance = None - fixed_ip_ref.save() + with managed_session(autocommit=False) as session: + fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref.instance = None + fixed_ip_ref.save(session=session) + session.commit() def fixed_ip_update(context, address, values): @@ -192,13 +223,6 @@ def instance_get_all(context): return models.Instance.all() -def instance_get_by_address(context, address): - fixed_ip_ref = db.fixed_ip_get_by_address(address) - if not fixed_ip_ref.instance: - raise exception.NotFound("No instance found for address %s" % address) - return fixed_ip_ref.instance - - def instance_get_by_project(context, project_id): with managed_session() as session: return session.query(models.Instance) \ @@ -220,20 +244,22 @@ def instance_get_by_str(context, str_id): def instance_get_fixed_address(context, instance_id): - instance_ref = instance_get(context, instance_id) - if not instance_ref.fixed_ip: - return None - return instance_ref.fixed_ip['str_id'] + with managed_session() as session: + instance_ref = models.Instance.find(instance_id, session=session) + if not instance_ref.fixed_ip: + return None + return instance_ref.fixed_ip['str_id'] def instance_get_floating_address(context, instance_id): - instance_ref = instance_get(context, instance_id) - if not instance_ref.fixed_ip: - return None - if not instance_ref.fixed_ip.floating_ips: - return None - # NOTE(vish): this just returns the first floating ip - return instance_ref.fixed_ip.floating_ips[0]['str_id'] + with managed_session() as session: + instance_ref = models.Instance.find(instance_id, session=session) + if not instance_ref.fixed_ip: + return None + if not instance_ref.fixed_ip.floating_ips: + return None + # NOTE(vish): this just returns the first floating ip + return instance_ref.fixed_ip.floating_ips[0]['str_id'] def instance_get_host(context, instance_id): @@ -307,6 +333,13 @@ def network_destroy(context, network_id): # TODO(vish): do we have to use sql here? session.execute('update networks set deleted=1 where id=:id', {'id': network_id}) + session.execute('update fixed_ips set deleted=1 where network_id=:id', + {'id': network_id}) + session.execute('update floating_ips set deleted=1 ' + 'where fixed_ip_id in ' + '(select id from fixed_ips ' + 'where network_id=:id)', + {'id': network_id}) session.execute('update network_indexes set network_id=NULL where network_id=:id', {'id': network_id}) session.commit() @@ -472,7 +505,7 @@ def volume_destroy(context, volume_id): # TODO(vish): do we have to use sql here? session.execute('update volumes set deleted=1 where id=:id', {'id': volume_id}) - session.execute('update export_devices set volume_id=NULL where network_id=:id', + session.execute('update export_devices set volume_id=NULL where volume_id=:id', {'id': volume_id}) session.commit() @@ -512,11 +545,13 @@ def volume_get_host(context, volume_id): def volume_get_shelf_and_blade(context, volume_id): - volume_ref = volume_get(context, volume_id) - export_device = volume_ref.export_device - if not export_device: - raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + with managed_session() as session: + export_device = session.query(models.ExportDevice) \ + .filter_by(volume_id=volume_id) \ + .first() + if not export_device: + raise exception.NotFound() + return (export_device.shelf_id, export_device.blade_id) def volume_update(context, volume_id, values): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b7031eec0..b6077a583 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -274,14 +274,18 @@ class FixedIp(Base, NovaBase): return self.ip_str @classmethod - def find_by_str(cls, session, str_id): - try: - return session.query(cls) \ - .filter_by(ip_str=str_id) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - raise exception.NotFound("No model for ip str %s" % str_id) + def find_by_str(cls, str_id, session=None): + if session: + try: + return session.query(cls) \ + .filter_by(ip_str=str_id) \ + .filter_by(deleted=False) \ + .one() + except exc.NoResultFound: + raise exception.NotFound("No model for ip_str %s" % str_id) + else: + with managed_session() as s: + return cls.find_by_str(str_id, session=s) class FloatingIp(Base, NovaBase): @@ -299,14 +303,18 @@ class FloatingIp(Base, NovaBase): return self.ip_str @classmethod - def find_by_str(cls, session, str_id): - try: - return session.query(cls) \ - .filter_by(ip_str=str_id) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - raise exception.NotFound("No model for ip str %s" % str_id) + def find_by_str(cls, str_id, session=None): + if session: + try: + return session.query(cls) \ + .filter_by(ip_str=str_id) \ + .filter_by(deleted=False) \ + .one() + except exc.NoResultFound: + raise exception.NotFound("No model for ip_str %s" % str_id) + else: + with managed_session() as s: + return cls.find_by_str(str_id, session=s) class Network(Base, NovaBase): diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 2b088170b..99270433a 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -44,9 +44,8 @@ class SessionExecutionManager: def __enter__(self): return self._session - def __exit__(self, type, value, traceback): - if type: - logging.exception("Error in database transaction") + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + logging.exception("Rolling back due to failed transaction") self._session.rollback() - if self._session: - self._session.close() + self._session.close() diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 0f3ecb3b0..4f7f1c605 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -94,7 +94,7 @@ class CloudController(object): return result def get_metadata(self, ipaddress): - i = db.instance_get_by_address(ipaddress) + i = db.fixed_ip_get_instance(ipaddress) if i is None: return None mpi = self._get_mpi_data(i['project_id']) @@ -421,7 +421,7 @@ class CloudController(object): context.project.id) for floating_ip_ref in iterator: address = floating_ip_ref['id_str'] - instance_ref = db.instance_get_by_address(address) + instance_ref = db.floating_ip_get_instance(address) address_rv = { 'public_ip': address, 'instance_id': instance_ref['id_str'] diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index 28e51f387..a8d644c84 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -40,7 +40,7 @@ class InstanceXmlTestCase(test.TrialTestCase): # instance_id = 'foo' # first_node = node.Node() - # inst = yield first_node.run_instance(instance_id) + # inst = yield first_node.run_instance(self.context, instance_id) # # # force the state so that we can verify that it changes # inst._s['state'] = node.Instance.NOSTATE @@ -50,7 +50,7 @@ class InstanceXmlTestCase(test.TrialTestCase): # second_node = node.Node() # new_inst = node.Instance.fromXml(second_node._conn, pool=second_node._pool, xml=xml) # self.assertEqual(new_inst.state, node.Instance.RUNNING) - # rv = yield first_node.terminate_instance(instance_id) + # rv = yield first_node.terminate_instance(self.context, instance_id) class ComputeConnectionTestCase(test.TrialTestCase): @@ -63,6 +63,7 @@ class ComputeConnectionTestCase(test.TrialTestCase): self.manager = manager.AuthManager() user = self.manager.create_user('fake', 'fake', 'fake') project = self.manager.create_project('fake', 'fake', 'fake') + self.context = None def tearDown(self): self.manager.delete_user('fake') @@ -84,13 +85,13 @@ class ComputeConnectionTestCase(test.TrialTestCase): def test_run_describe_terminate(self): instance_id = self._create_instance() - yield self.compute.run_instance(instance_id) + yield self.compute.run_instance(self.context, instance_id) instances = db.instance_get_all(None) logging.info("Running instances: %s", instances) self.assertEqual(len(instances), 1) - yield self.compute.terminate_instance(instance_id) + yield self.compute.terminate_instance(self.context, instance_id) instances = db.instance_get_all(None) logging.info("After terminating instances: %s", instances) @@ -99,22 +100,25 @@ class ComputeConnectionTestCase(test.TrialTestCase): @defer.inlineCallbacks def test_reboot(self): instance_id = self._create_instance() - yield self.compute.run_instance(instance_id) - yield self.compute.reboot_instance(instance_id) - yield self.compute.terminate_instance(instance_id) + yield self.compute.run_instance(self.context, instance_id) + yield self.compute.reboot_instance(self.context, instance_id) + yield self.compute.terminate_instance(self.context, instance_id) @defer.inlineCallbacks def test_console_output(self): instance_id = self._create_instance() - rv = yield self.compute.run_instance(instance_id) + rv = yield self.compute.run_instance(self.context, instance_id) - console = yield self.compute.get_console_output(instance_id) + console = yield self.compute.get_console_output(self.context, + instance_id) self.assert_(console) - rv = yield self.compute.terminate_instance(instance_id) + rv = yield self.compute.terminate_instance(self.context, instance_id) @defer.inlineCallbacks def test_run_instance_existing(self): instance_id = self._create_instance() - yield self.compute.run_instance(instance_id) - self.assertFailure(self.compute.run_instance(instance_id), exception.Error) - yield self.compute.terminate_instance(instance_id) + yield self.compute.run_instance(self.context, instance_id) + self.assertFailure(self.compute.run_instance(self.context, + instance_id), + exception.Error) + yield self.compute.terminate_instance(self.context, instance_id) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index e3fe01fa2..b479f2fa4 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -253,12 +253,11 @@ class NetworkTestCase(test.TrialTestCase): def is_allocated_in_project(address, project_id): """Returns true if address is in specified project""" - fixed_ip = db.fixed_ip_get_by_address(None, address) project_net = db.project_get_network(None, project_id) + network = db.fixed_ip_get_network(None, address) + instance = db.fixed_ip_get_instance(None, address) # instance exists until release - logging.debug('fixed_ip.instance: %s', fixed_ip.instance) - logging.debug('project_net: %s', project_net) - return fixed_ip.instance is not None and fixed_ip.network == project_net + return instance is not None and network['id'] == project_net['id'] def binpath(script): diff --git a/nova/tests/volume_unittest.py b/nova/tests/volume_unittest.py index 4504276e2..6573e9876 100644 --- a/nova/tests/volume_unittest.py +++ b/nova/tests/volume_unittest.py @@ -117,6 +117,7 @@ class VolumeTestCase(test.TrialTestCase): else: rv = yield self.compute.detach_volume(instance_id, volume_id) + vol = db.volume_get(None, volume_id) self.assertEqual(vol['status'], "available") rv = self.volume.delete_volume(self.context, volume_id) @@ -134,9 +135,9 @@ class VolumeTestCase(test.TrialTestCase): volume_ids = [] def _check(volume_id): volume_ids.append(volume_id) - vol = db.volume_get(None, volume_id) - shelf_blade = '%s.%s' % (vol.export_device.shelf_id, - vol.export_device.blade_id) + (shelf_id, blade_id) = db.volume_get_shelf_and_blade(None, + volume_id) + shelf_blade = '%s.%s' % (shelf_id, blade_id) self.assert_(shelf_blade not in shelf_blades) shelf_blades.append(shelf_blade) logging.debug("got %s" % shelf_blade) -- cgit