diff options
| author | Soren Hansen <soren.hansen@rackspace.com> | 2010-09-10 15:02:07 +0200 |
|---|---|---|
| committer | Soren Hansen <soren.hansen@rackspace.com> | 2010-09-10 15:02:07 +0200 |
| commit | d64adee4656a3044258c7dbfff93f5201c39560c (patch) | |
| tree | 26b1419d5fd8c8f98a55de70461dbe342b133fab /nova/db | |
| parent | c3dd0aa79d982d8f34172e6023d4b632ea23f2b9 (diff) | |
| parent | 33d832ee798bc9530be577e3234ff8bcdac4939e (diff) | |
| download | nova-d64adee4656a3044258c7dbfff93f5201c39560c.tar.gz nova-d64adee4656a3044258c7dbfff93f5201c39560c.tar.xz nova-d64adee4656a3044258c7dbfff93f5201c39560c.zip | |
Merge with orm_deux (fixing up style changes in my stuff at the same time).
Diffstat (limited to 'nova/db')
| -rw-r--r-- | nova/db/api.py | 13 | ||||
| -rw-r--r-- | nova/db/sqlalchemy/api.py | 472 | ||||
| -rw-r--r-- | nova/db/sqlalchemy/models.py | 139 | ||||
| -rw-r--r-- | nova/db/sqlalchemy/session.py | 44 |
4 files changed, 330 insertions, 338 deletions
diff --git a/nova/db/api.py b/nova/db/api.py index c7a6da183..2bcf0bd2b 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -30,10 +30,9 @@ flags.DEFINE_string('db_backend', 'sqlalchemy', IMPL = utils.LazyPluggable(FLAGS['db_backend'], - sqlalchemy='nova.db.sqlalchemy.api') + sqlalchemy='nova.db.sqlalchemy.api') -# TODO(vish): where should these exceptions go? class NoMoreAddresses(exception.Error): """No more available addresses""" pass @@ -87,9 +86,9 @@ def floating_ip_allocate_address(context, host, project_id): return IMPL.floating_ip_allocate_address(context, host, project_id) -def floating_ip_create(context, address, host): - """Create a floating ip for a given address on the specified host.""" - return IMPL.floating_ip_create(context, address, host) +def floating_ip_create(context, values): + """Create a floating ip from the values dictionary.""" + return IMPL.floating_ip_create(context, values) def floating_ip_disassociate(context, address): @@ -231,9 +230,9 @@ def instance_is_vpn(context, instance_id): return IMPL.instance_is_vpn(context, instance_id) -def instance_state(context, instance_id, state, description=None): +def instance_set_state(context, instance_id, state, description=None): """Set the state of an instance.""" - return IMPL.instance_state(context, instance_id, state, description) + return IMPL.instance_set_state(context, instance_id, state, description) def instance_update(context, instance_id, values): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 622e76cd7..1c95efd83 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -23,7 +23,7 @@ from nova import db from nova import exception from nova import flags from nova.db.sqlalchemy import models -from nova.db.sqlalchemy.session import managed_session +from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.orm import eagerload @@ -52,55 +52,58 @@ def service_create(_context, values): return service_ref.id -def service_update(context, service_id, values): - service_ref = service_get(context, service_id) - for (key, value) in values.iteritems(): - service_ref[key] = value - service_ref.save() +def service_update(_context, service_id, values): + session = get_session() + with session.begin(): + service_ref = models.Service.find(service_id, session=session) + for (key, value) in values.iteritems(): + service_ref[key] = value + service_ref.save(session=session) ################### def floating_ip_allocate_address(_context, host, project_id): - with managed_session(autocommit=False) as session: - floating_ip_ref = session.query(models.FloatingIp) \ - .filter_by(host=host) \ - .filter_by(fixed_ip_id=None) \ - .filter_by(deleted=False) \ - .with_lockmode('update') \ - .first() + session = get_session() + with session.begin(): + floating_ip_ref = session.query(models.FloatingIp + ).filter_by(host=host + ).filter_by(fixed_ip_id=None + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues if not floating_ip_ref: raise db.NoMoreAddresses() floating_ip_ref['project_id'] = project_id session.add(floating_ip_ref) - session.commit() - return floating_ip_ref['address'] + return floating_ip_ref['address'] -def floating_ip_create(_context, address, host): +def floating_ip_create(_context, values): floating_ip_ref = models.FloatingIp() - floating_ip_ref['address'] = address - floating_ip_ref['host'] = host + for (key, value) in values.iteritems(): + floating_ip_ref[key] = value floating_ip_ref.save() - return floating_ip_ref + return floating_ip_ref['address'] def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): floating_ip_ref = models.FloatingIp.find_by_str(floating_address, session=session) fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) - session.commit() def floating_ip_disassociate(_context, address): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): floating_ip_ref = models.FloatingIp.find_by_str(address, session=session) fixed_ip_ref = floating_ip_ref.fixed_ip @@ -110,12 +113,12 @@ def floating_ip_disassociate(_context, address): fixed_ip_address = None floating_ip_ref.fixed_ip = None floating_ip_ref.save(session=session) - session.commit() - return fixed_ip_address + return fixed_ip_address def floating_ip_deallocate(_context, address): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): floating_ip_ref = models.FloatingIp.find_by_str(address, session=session) floating_ip_ref['project_id'] = None @@ -127,7 +130,8 @@ def floating_ip_get_by_address(_context, address): def floating_ip_get_instance(_context, address): - with managed_session() as session: + session = get_session() + with session.begin(): floating_ip_ref = models.FloatingIp.find_by_str(address, session=session) return floating_ip_ref.fixed_ip.instance @@ -137,27 +141,28 @@ def floating_ip_get_instance(_context, address): def fixed_ip_allocate(_context, network_id): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, models.FixedIp.network_id == None) - fixed_ip_ref = session.query(models.FixedIp) \ - .filter(network_or_none) \ - .filter_by(reserved=False) \ - .filter_by(allocated=False) \ - .filter_by(leased=False) \ - .filter_by(deleted=False) \ - .with_lockmode('update') \ - .first() + fixed_ip_ref = session.query(models.FixedIp + ).filter(network_or_none + ).filter_by(reserved=False + ).filter_by(allocated=False + ).filter_by(leased=False + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id) + fixed_ip_ref.network = models.Network.find(network_id, + session=session) fixed_ip_ref['allocated'] = True session.add(fixed_ip_ref) - session.commit() - return fixed_ip_ref['address'] + return fixed_ip_ref['address'] def fixed_ip_create(_context, values): @@ -173,43 +178,45 @@ def fixed_ip_get_by_address(_context, address): def fixed_ip_get_instance(_context, address): - with managed_session() as session: + session = get_session() + with session.begin(): return models.FixedIp.find_by_str(address, session=session).instance def fixed_ip_get_network(_context, address): - with managed_session() as session: + session = get_session() + with session.begin(): return models.FixedIp.find_by_str(address, session=session).network def fixed_ip_deallocate(context, address): - fixed_ip_ref = fixed_ip_get_by_address(context, address) - fixed_ip_ref['allocated'] = False - fixed_ip_ref.save() + db.fixed_ip_update(context, address, {'allocated': False}) def fixed_ip_instance_associate(_context, address, instance_id): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) instance_ref = models.Instance.find(instance_id, session=session) fixed_ip_ref.instance = instance_ref fixed_ip_ref.save(session=session) - session.commit() def fixed_ip_instance_disassociate(_context, address): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) - session.commit() -def fixed_ip_update(context, address, values): - fixed_ip_ref = fixed_ip_get_by_address(context, address) - for (key, value) in values.iteritems(): - fixed_ip_ref[key] = value - fixed_ip_ref.save() +def fixed_ip_update(_context, address, values): + session = get_session() + with session.begin(): + fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + for (key, value) in values.iteritems(): + fixed_ip_ref[key] = value + fixed_ip_ref.save(session=session) ################### @@ -223,9 +230,11 @@ def instance_create(_context, values): return instance_ref.id -def instance_destroy(context, instance_id): - instance_ref = instance_get(context, instance_id) - instance_ref.delete() +def instance_destroy(_context, instance_id): + session = get_session() + with session.begin(): + instance_ref = models.Instance.find(instance_id, session=session) + instance_ref.delete(session=session) def instance_get(_context, instance_id): @@ -237,19 +246,19 @@ def instance_get_all(_context): def instance_get_by_project(_context, project_id): - with managed_session() as session: - return session.query(models.Instance) \ - .filter_by(project_id=project_id) \ - .filter_by(deleted=False) \ - .all() + session = get_session() + return session.query(models.Instance + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() def instance_get_by_reservation(_context, reservation_id): - with managed_session() as session: - return session.query(models.Instance) \ - .filter_by(reservation_id=reservation_id) \ - .filter_by(deleted=False) \ - .all() + session = get_session() + return session.query(models.Instance + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() def instance_get_by_str(_context, str_id): @@ -257,7 +266,8 @@ def instance_get_by_str(_context, str_id): def instance_get_fixed_address(_context, instance_id): - with managed_session() as session: + session = get_session() + with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) if not instance_ref.fixed_ip: return None @@ -265,7 +275,8 @@ def instance_get_fixed_address(_context, instance_id): def instance_get_floating_address(_context, instance_id): - with managed_session() as session: + session = get_session() + with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) if not instance_ref.fixed_ip: return None @@ -281,20 +292,29 @@ def instance_get_host(context, instance_id): def instance_is_vpn(context, instance_id): + # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id -def instance_state(context, instance_id, state, description=None): - instance_ref = instance_get(context, instance_id) - instance_ref.set_state(state, description) +def instance_set_state(context, instance_id, state, description=None): + # TODO(devcamcar): Move this out of models and into driver + from nova.compute import power_state + if not description: + description = power_state.name(state) + db.instance_update(context, + instance_id, + {'state': state, + 'state_description': description}) -def instance_update(context, instance_id, values): - instance_ref = instance_get(context, instance_id) - for (key, value) in values.iteritems(): - instance_ref[key] = value - instance_ref.save() +def instance_update(_context, instance_id, values): + session = get_session() + with session.begin(): + instance_ref = models.Instance.find(instance_id, session=session) + for (key, value) in values.iteritems(): + instance_ref[key] = value + instance_ref.save(session=session) ################### @@ -305,31 +325,31 @@ def network_count(_context): def network_count_allocated_ips(_context, network_id): - with managed_session() as session: - return session.query(models.FixedIp) \ - .filter_by(network_id=network_id) \ - .filter_by(allocated=True) \ - .filter_by(deleted=False) \ - .count() + session = get_session() + return session.query(models.FixedIp + ).filter_by(network_id=network_id + ).filter_by(allocated=True + ).filter_by(deleted=False + ).count() def network_count_available_ips(_context, network_id): - with managed_session() as session: - return session.query(models.FixedIp) \ - .filter_by(network_id=network_id) \ - .filter_by(allocated=False) \ - .filter_by(reserved=False) \ - .filter_by(deleted=False) \ - .count() + session = get_session() + return session.query(models.FixedIp + ).filter_by(network_id=network_id + ).filter_by(allocated=False + ).filter_by(reserved=False + ).filter_by(deleted=False + ).count() def network_count_reserved_ips(_context, network_id): - with managed_session() as session: - return session.query(models.FixedIp) \ - .filter_by(network_id=network_id) \ - .filter_by(reserved=True) \ - .filter_by(deleted=False) \ - .count() + session = get_session() + return session.query(models.FixedIp + ).filter_by(network_id=network_id + ).filter_by(reserved=True + ).filter_by(deleted=False + ).count() def network_create(_context, values): @@ -341,7 +361,8 @@ def network_create(_context, values): def network_destroy(_context, network_id): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): # TODO(vish): do we have to use sql here? session.execute('update networks set deleted=1 where id=:id', {'id': network_id}) @@ -355,32 +376,33 @@ def network_destroy(_context, network_id): session.execute('update network_indexes set network_id=NULL ' 'where network_id=:id', {'id': network_id}) - session.commit() def network_get(_context, network_id): return models.Network.find(network_id) +# NOTE(vish): pylint complains because of the long method name, but +# it fits with the names of the rest of the methods # pylint: disable-msg=C0103 def network_get_associated_fixed_ips(_context, network_id): - with managed_session() as session: - return session.query(models.FixedIp) \ - .filter_by(network_id=network_id) \ - .filter(models.FixedIp.instance_id != None) \ - .filter_by(deleted=False) \ - .all() + session = get_session() + return session.query(models.FixedIp + ).filter_by(network_id=network_id + ).filter(models.FixedIp.instance_id != None + ).filter_by(deleted=False + ).all() def network_get_by_bridge(_context, bridge): - with managed_session() as session: - rv = session.query(models.Network) \ - .filter_by(bridge=bridge) \ - .filter_by(deleted=False) \ - .first() - if not rv: - raise exception.NotFound('No network for bridge %s' % bridge) - return rv + session = get_session() + rv = session.query(models.Network + ).filter_by(bridge=bridge + ).filter_by(deleted=False + ).first() + if not rv: + raise exception.NotFound('No network for bridge %s' % bridge) + return rv def network_get_host(context, network_id): @@ -389,19 +411,19 @@ def network_get_host(context, network_id): def network_get_index(_context, network_id): - with managed_session(autocommit=False) as session: - network_index = session.query(models.NetworkIndex) \ - .filter_by(network_id=None) \ - .filter_by(deleted=False) \ - .with_lockmode('update') \ - .first() + session = get_session() + with session.begin(): + network_index = session.query(models.NetworkIndex + ).filter_by(network_id=None + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() if not network_index: raise db.NoMoreNetworks() network_index['network'] = models.Network.find(network_id, session=session) session.add(network_index) - session.commit() - return network_index['index'] + return network_index['index'] def network_index_count(_context): @@ -416,45 +438,45 @@ def network_index_create(_context, values): def network_set_host(_context, network_id, host_id): - with managed_session(autocommit=False) as session: - network = session.query(models.Network) \ - .filter_by(id=network_id) \ - .filter_by(deleted=False) \ - .with_lockmode('update') \ - .first() + session = get_session() + with session.begin(): + network = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() if not network: raise exception.NotFound("Couldn't find network with %s" % network_id) # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if network.host: - session.commit() - return network['host'] - network['host'] = host_id - session.add(network) - session.commit() - return network['host'] + if not network['host']: + network['host'] = host_id + session.add(network) + return network['host'] -def network_update(context, network_id, values): - network_ref = network_get(context, network_id) - for (key, value) in values.iteritems(): - network_ref[key] = value - network_ref.save() +def network_update(_context, network_id, values): + session = get_session() + with session.begin(): + network_ref = models.Network.find(network_id, session=session) + for (key, value) in values.iteritems(): + network_ref[key] = value + network_ref.save(session=session) ################### def project_get_network(_context, project_id): - with managed_session() as session: - rv = session.query(models.Network) \ - .filter_by(project_id=project_id) \ - .filter_by(deleted=False) \ - .first() - if not rv: - raise exception.NotFound('No network for project: %s' % project_id) - return rv + session = get_session() + rv = session.query(models.Network + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).first() + if not rv: + raise exception.NotFound('No network for project: %s' % project_id) + return rv ################### @@ -483,29 +505,32 @@ def export_device_create(_context, values): def volume_allocate_shelf_and_blade(_context, volume_id): - with managed_session(autocommit=False) as session: - export_device = session.query(models.ExportDevice) \ - .filter_by(volume=None) \ - .filter_by(deleted=False) \ - .with_lockmode('update') \ - .first() + session = get_session() + with session.begin(): + export_device = session.query(models.ExportDevice + ).filter_by(volume=None + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues if not export_device: raise db.NoMoreBlades() export_device.volume_id = volume_id session.add(export_device) - session.commit() - return (export_device.shelf_id, export_device.blade_id) + return (export_device.shelf_id, export_device.blade_id) -def volume_attached(context, volume_id, instance_id, mountpoint): - volume_ref = volume_get(context, volume_id) - volume_ref.instance_id = instance_id - volume_ref['status'] = 'in-use' - volume_ref['mountpoint'] = mountpoint - volume_ref['attach_status'] = 'attached' - volume_ref.save() +def volume_attached(_context, volume_id, instance_id, mountpoint): + session = get_session() + with session.begin(): + volume_ref = models.Volume.find(volume_id, session=session) + volume_ref['status'] = 'in-use' + volume_ref['mountpoint'] = mountpoint + volume_ref['attach_status'] = 'attached' + volume_ref.instance = models.Instance.find(instance_id, + session=session) + volume_ref.save(session=session) def volume_create(_context, values): @@ -517,23 +542,25 @@ def volume_create(_context, values): def volume_destroy(_context, volume_id): - with managed_session(autocommit=False) as session: + session = get_session() + with session.begin(): # TODO(vish): do we have to use sql here? session.execute('update volumes set deleted=1 where id=:id', {'id': volume_id}) session.execute('update export_devices set volume_id=NULL ' 'where volume_id=:id', {'id': volume_id}) - session.commit() -def volume_detached(context, volume_id): - volume_ref = volume_get(context, volume_id) - volume_ref['instance_id'] = None - volume_ref['mountpoint'] = None - volume_ref['status'] = 'available' - volume_ref['attach_status'] = 'detached' - volume_ref.save() +def volume_detached(_context, volume_id): + session = get_session() + with session.begin(): + volume_ref = models.Volume.find(volume_id, session=session) + volume_ref['status'] = 'available' + volume_ref['mountpoint'] = None + volume_ref['attach_status'] = 'detached' + volume_ref.instance = None + volume_ref.save(session=session) def volume_get(_context, volume_id): @@ -545,11 +572,11 @@ def volume_get_all(_context): def volume_get_by_project(_context, project_id): - with managed_session() as session: - return session.query(models.Volume) \ - .filter_by(project_id=project_id) \ - .filter_by(deleted=False) \ - .all() + session = get_session() + return session.query(models.Volume + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() def volume_get_by_str(_context, str_id): @@ -561,27 +588,29 @@ def volume_get_host(context, volume_id): return volume_ref['host'] -def volume_get_instance(context, volume_id): - volume_ref = db.volume_get(context, volume_id) - instance_ref = db.instance_get(context, volume_ref['instance_id']) - return instance_ref +def volume_get_instance(_context, volume_id): + session = get_session() + with session.begin(): + return models.Volume.find(volume_id, session=session).instance def volume_get_shelf_and_blade(_context, volume_id): - with managed_session() as session: - export_device = session.query(models.ExportDevice) \ - .filter_by(volume_id=volume_id) \ - .first() - if not export_device: - raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + session = get_session() + export_device = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + if not export_device: + raise exception.NotFound() + return (export_device.shelf_id, export_device.blade_id) -def volume_update(context, volume_id, values): - volume_ref = volume_get(context, volume_id) - for (key, value) in values.iteritems(): - volume_ref[key] = value - volume_ref.save() +def volume_update(_context, volume_id, values): + session = get_session() + with session.begin(): + volume_ref = models.Volume.find(volume_id, session=session) + for (key, value) in values.iteritems(): + volume_ref[key] = value + volume_ref.save(session=session) ################### @@ -596,48 +625,48 @@ def security_group_create(_context, values): def security_group_get_by_id(_context, security_group_id): - with managed_session() as session: - return session.query(models.SecurityGroup) \ - .options(eagerload('rules')) \ - .get(security_group_id) + session = get_session() + with session.begin(): + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).get(security_group_id) def security_group_get_by_instance(_context, instance_id): - with managed_session() as session: - return session.query(models.Instance) \ - .get(instance_id) \ - .security_groups \ + session = get_session() + with session.begin(): + return session.query(models.Instance + ).get(instance_id + ).security_groups \ .all() def security_group_get_by_user(_context, user_id): - with managed_session() as session: - return session.query(models.SecurityGroup) \ - .filter_by(user_id=user_id) \ - .filter_by(deleted=False) \ - .options(eagerload('rules')) \ - .all() + session = get_session() + with session.begin(): + return session.query(models.SecurityGroup + ).filter_by(user_id=user_id + ).filter_by(deleted=False + ).options(eagerload('rules') + ).all() def security_group_get_by_user_and_name(_context, user_id, name): - with managed_session() as session: - return session.query(models.SecurityGroup) \ - .filter_by(user_id=user_id) \ - .filter_by(name=name) \ - .filter_by(deleted=False) \ - .options(eagerload('rules')) \ - .one() + session = get_session() + with session.begin(): + return session.query(models.SecurityGroup + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=False + ).options(eagerload('rules') + ).one() def security_group_destroy(_context, security_group_id): - with managed_session() as session: - security_group = session.query(models.SecurityGroup) \ - .get(security_group_id) + session = get_session() + with session.begin(): + security_group = session.query(models.SecurityGroup + ).get(security_group_id) security_group.delete(session=session) -def security_group_get_all(_context): - return models.SecurityGroup.all() - - - ################### @@ -650,7 +679,8 @@ def security_group_rule_create(_context, values): return security_group_rule_ref def security_group_rule_destroy(_context, security_group_rule_id): - with managed_session() as session: - security_group_rule = session.query(models.SecurityGroupIngressRule) \ - .get(security_group_rule_id) + session = get_session() + with session.begin(): + security_group_rule = session.query(models.SecurityGroupIngressRule + ).get(security_group_rule_id) security_group_rule.delete(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 27c8e4d4c..f27520aa8 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -30,7 +30,7 @@ from sqlalchemy import Column, Integer, String, Table from sqlalchemy import ForeignKey, DateTime, Boolean, Text from sqlalchemy.ext.declarative import declarative_base -from nova.db.sqlalchemy.session import managed_session +from nova.db.sqlalchemy.session import get_session from nova import auth from nova import exception @@ -53,40 +53,34 @@ class NovaBase(object): @classmethod def all(cls, session=None): """Get all objects of this type""" - if session: - return session.query(cls) \ - .filter_by(deleted=False) \ - .all() - else: - with managed_session() as sess: - return cls.all(session=sess) + if not session: + session = get_session() + return session.query(cls + ).filter_by(deleted=False + ).all() @classmethod def count(cls, session=None): """Count objects of this type""" - if session: - return session.query(cls) \ - .filter_by(deleted=False) \ - .count() - else: - with managed_session() as sess: - return cls.count(session=sess) + if not session: + session = get_session() + return session.query(cls + ).filter_by(deleted=False + ).count() @classmethod def find(cls, obj_id, session=None): """Find object by id""" - if session: - try: - return session.query(cls) \ - .filter_by(id=obj_id) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - else: - with managed_session() as sess: - return cls.find(obj_id, session=sess) + if not session: + session = get_session() + try: + return session.query(cls + ).filter_by(id=obj_id + ).filter_by(deleted=False + ).one() + except exc.NoResultFound: + new_exc = exception.NotFound("No model for id %s" % obj_id) + raise new_exc.__class__, new_exc, sys.exc_info()[2] @classmethod def find_by_str(cls, str_id, session=None): @@ -101,12 +95,10 @@ class NovaBase(object): def save(self, session=None): """Save this object""" - if session: - session.add(self) - session.flush() - else: - with managed_session() as sess: - self.save(session=sess) + if not session: + session = get_session() + session.add(self) + session.flush() def delete(self, session=None): """Delete this object""" @@ -175,20 +167,18 @@ class Service(BASE, NovaBase): @classmethod def find_by_args(cls, host, binary, session=None): - if session: - try: - return session.query(cls) \ - .filter_by(host=host) \ - .filter_by(binary=binary) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - else: - with managed_session() as sess: - return cls.find_by_args(host, binary, session=sess) + if not session: + session = get_session() + try: + return session.query(cls + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=False + ).one() + except exc.NoResultFound: + new_exc = exception.NotFound("No model for %s, %s" % (host, + binary)) + raise new_exc.__class__, new_exc, sys.exc_info()[2] class Instance(BASE, NovaBase): @@ -240,16 +230,6 @@ class Instance(BASE, NovaBase): reservation_id = Column(String(255)) mac_address = Column(String(255)) - def set_state(self, state_code, state_description=None): - """Set the code and description of an instance""" - # TODO(devcamcar): Move this out of models and into driver - from nova.compute import power_state - self.state = state_code - if not state_description: - state_description = power_state.name(state_code) - self.state_description = state_description - self.save() - # TODO(vish): see Ewan's email about state improvements, probably # should be in a driver base class or some such # vmstate_state = running, halted, suspended, paused @@ -275,6 +255,7 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) + instance = relationship(Instance, backref=backref('volumes')) mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -405,18 +386,16 @@ class FixedIp(BASE, NovaBase): @classmethod def find_by_str(cls, str_id, session=None): - if session: - try: - return session.query(cls) \ - .filter_by(address=str_id) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - else: - with managed_session() as sess: - return cls.find_by_str(str_id, session=sess) + if not session: + session = get_session() + try: + return session.query(cls + ).filter_by(address=str_id + ).filter_by(deleted=False + ).one() + except exc.NoResultFound: + new_exc = exception.NotFound("No model for address %s" % str_id) + raise new_exc.__class__, new_exc, sys.exc_info()[2] class FloatingIp(BASE, NovaBase): @@ -436,18 +415,16 @@ class FloatingIp(BASE, NovaBase): @classmethod def find_by_str(cls, str_id, session=None): - if session: - try: - return session.query(cls) \ - .filter_by(address=str_id) \ - .filter_by(deleted=False) \ - .one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - else: - with managed_session() as sess: - return cls.find_by_str(str_id, session=sess) + if not session: + session = get_session() + try: + return session.query(cls + ).filter_by(address=str_id + ).filter_by(deleted=False + ).one() + except exc.NoResultFound: + new_exc = exception.NotFound("No model for address %s" % str_id) + raise new_exc.__class__, new_exc, sys.exc_info()[2] def register_models(): diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index adcc42293..69a205378 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -19,38 +19,24 @@ Session Handling for SQLAlchemy backend """ -import logging - from sqlalchemy import create_engine -from sqlalchemy.orm import create_session +from sqlalchemy.orm import sessionmaker from nova import flags FLAGS = flags.FLAGS - -def managed_session(autocommit=True): - """Helper method to grab session manager""" - return SessionExecutionManager(autocommit=autocommit) - - -class SessionExecutionManager: - """Session manager supporting with .. as syntax""" - _engine = None - _session = None - - def __init__(self, autocommit): - if not self._engine: - self._engine = create_engine(FLAGS.sql_connection, echo=False) - self._session = create_session(bind=self._engine, - autocommit=autocommit) - - def __enter__(self): - return self._session - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type: - logging.exception("Rolling back due to failed transaction: %s", - exc_type) - self._session.rollback() - self._session.close() +_ENGINE = None +_MAKER = None + +def get_session(autocommit=True, expire_on_commit=False): + """Helper method to grab session""" + global _ENGINE + global _MAKER + if not _MAKER: + if not _ENGINE: + _ENGINE = create_engine(FLAGS.sql_connection, echo=False) + _MAKER = sessionmaker(bind=_ENGINE, + autocommit=autocommit, + expire_on_commit=expire_on_commit) + return _MAKER() |
