From fe139bbdee60aadd720cb7a83d0846f2824c078f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 00:49:04 -0700 Subject: Began wiring up context authorization --- nova/db/sqlalchemy/api.py | 50 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..b5847d299 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,6 +19,7 @@ Implementation of SQLAlchemy backend """ +import logging import sys from nova import db @@ -48,6 +49,24 @@ def _deleted(context): return context.get('deleted', False) +def is_admin_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return True + if not context.user: + return True + return context.user.is_admin() + + +def is_user_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return False + if not context.user or not context.project: + return False + return True + + ################### @@ -869,14 +888,41 @@ def volume_detached(_context, volume_id): def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) + session = get_session() + + if is_admin_context(context): + volume_ref = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + if is_user_context(context): + volume_ref = session.query(models.Volume + ).filter_by(project_id=project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + raise exception.NotAuthorized() def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) + if is_admin_context(context): + return models.Volume.all(deleted=_deleted(context)) + raise exception.NotAuthorized() def volume_get_all_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id -- cgit From e258998923b7e8fa92656aa409f875b640df930c Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 13:26:14 -0700 Subject: Progress on volumes Fixed foreign keys to respect deleted flag --- nova/db/sqlalchemy/api.py | 130 ++++++++++++++++++++++++++++++------------- nova/db/sqlalchemy/models.py | 35 +++++++++--- 2 files changed, 118 insertions(+), 47 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b5847d299..28b937233 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -30,7 +30,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -811,6 +811,7 @@ def quota_destroy(_context, project_id): def volume_allocate_shelf_and_blade(_context, volume_id): + # TODO(devcamcar): Make admin only session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -839,7 +840,7 @@ def volume_attached(_context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -def volume_create(_context, values): +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value @@ -848,7 +849,7 @@ def volume_create(_context, values): with session.begin(): while volume_ref.ec2_id == None: ec2_id = utils.generate_uid(volume_ref.__prefix__) - if not volume_ec2_id_exists(_context, ec2_id, session=session): + if not volume_ec2_id_exists(context, ec2_id, session=session): volume_ref.ec2_id = ec2_id volume_ref.save(session=session) return volume_ref @@ -876,10 +877,10 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -887,27 +888,29 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - session = get_session() +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None if is_admin_context(context): - volume_ref = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - if is_user_context(context): - volume_ref = session.query(models.Volume - ).filter_by(project_id=project_id - ).filter_by(id=volume_id - ).filter_by(deleted=False - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) - raise exception.NotAuthorized() + return result def volume_get_all(context): @@ -916,6 +919,7 @@ def volume_get_all(context): raise exception.NotAuthorized() + def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -932,42 +936,92 @@ def volume_get_all_by_project(context, project_id): def volume_get_by_ec2_id(context, ec2_id): session = get_session() - volume_ref = session.query(models.Volume + result = None + + if is_admin_context(context): + result = session.query(models.Volume ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not volume_ref: - raise exception.NotFound('Volume %s not found' % (ec2_id)) + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - return volume_ref + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Volume.id==ec2_id)).one()[0] + + if is_admin_context(context) or is_user_context(context): + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] + else: + raise exception.NotAuthorized() -def volume_get_instance(_context, volume_id): +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).options(joinedload('instance') + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(deleted=False + ).options(joinedload('instance') + ).first() + else: + raise exception.NotAuthorized() + + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: + result = None + + if is_admin_context(context): + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + elif is_user_context(context): + result = session.query(models.ExportDevice + ).join(models.Volume + ).filter(models.Volume.project_id==context.project.id + ).filter_by(volume_id=volume_id + ).first() + else: + raise exception.NotAuthorized() + + if not result: raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): + +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..1b9edf475 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -282,7 +282,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -333,8 +337,11 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') class KeyPair(BASE, NovaBase): @@ -407,8 +414,12 @@ class NetworkIndex(BASE, NovaBase): id = Column(Integer, primary_key=True) index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + class AuthToken(BASE, NovaBase): """Represents an authorization token for all API transactions. Fields @@ -432,8 +443,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -462,8 +476,11 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) -- cgit From f4cf49ec3761bdd38dd1a6cb064875b90e65ad4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 14:27:31 -0700 Subject: Wired up context auth for services --- nova/db/sqlalchemy/api.py | 111 ++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 15 ------ 2 files changed, 89 insertions(+), 37 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 28b937233..01a5af38b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -71,16 +71,37 @@ def is_user_context(context): def service_destroy(context, service_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) + +def service_get(context, service_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result def service_get_all_by_topic(context, topic): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -89,7 +110,10 @@ def service_get_all_by_topic(context, topic): ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +def _service_get_all_topic_subquery(context, session, topic, subq, label): + if not is_admin_context(context): + raise exception.NotAuthorized() + sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -101,6 +125,9 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label): def service_get_all_compute_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -125,6 +152,9 @@ def service_get_all_compute_sorted(context): def service_get_all_network_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'network' @@ -142,6 +172,9 @@ def service_get_all_network_sorted(context): def service_get_all_volume_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'volume' @@ -158,11 +191,27 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +def service_get_by_args(context, host, binary): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result + +def service_create(context, values): + if not is_admin_context(context): + return exception.NotAuthorized() -def service_create(_context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -170,10 +219,13 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +def service_update(context, service_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = session_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -428,8 +480,8 @@ def instance_destroy(_context, instance_id): instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +def instance_get(context, instance_id, session=None): + return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): @@ -810,8 +862,10 @@ def quota_destroy(_context, project_id): ################### -def volume_allocate_shelf_and_blade(_context, volume_id): - # TODO(devcamcar): Make admin only +def volume_allocate_shelf_and_blade(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -828,15 +882,17 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +def volume_attached(context, volume_id, instance_id, mountpoint): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) @@ -855,7 +911,10 @@ def volume_create(context, values): return volume_ref -def volume_data_get_for_project(_context, project_id): +def volume_data_get_for_project(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -866,7 +925,10 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +def volume_destroy(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -878,6 +940,9 @@ def volume_destroy(_context, volume_id): def volume_detached(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -914,10 +979,12 @@ def volume_get(context, volume_id, session=None): def volume_get_all(context): - if is_admin_context(context): - return models.Volume.all(deleted=_deleted(context)) + if not is_admin_context(context): + raise exception.NotAuthorized() - raise exception.NotAuthorized() + return session.query(models.Volume + ).filter_by(deleted=_deleted(context) + ).all() def volume_get_all_by_project(context, project_id): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1b9edf475..b9bb8e4f2 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -176,21 +176,6 @@ class Service(BASE, NovaBase): report_count = Column(Integer, nullable=False, default=0) disabled = Column(Boolean, default=False) - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Instance(BASE, NovaBase): """Represents a guest vm""" -- cgit From 734df1fbad8195e7cd7072d0d0aeb5b94841f121 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 19:09:00 -0700 Subject: Made network tests pass again --- nova/db/api.py | 1 - nova/db/sqlalchemy/api.py | 233 +++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 26 ----- nova/network/manager.py | 3 +- nova/tests/network_unittest.py | 1 + 5 files changed, 170 insertions(+), 94 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..4cfdd788c 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,7 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): """Get an instance for a floating ip by address.""" return IMPL.floating_ip_get_instance(context, address) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 01a5af38b..d129df2be 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -234,7 +234,13 @@ def service_update(context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +def floating_ip_allocate_address(context, host, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -253,7 +259,10 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +def floating_ip_create(context, values): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -261,7 +270,13 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +def floating_ip_count_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -269,39 +284,63 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +#@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +#@require_context +def floating_ip_deallocate(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) +#@require_context +def floating_ip_destroy(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_destroy(_context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +def floating_ip_disassociate(context, address): + if not is_user_context(context) and is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -311,16 +350,22 @@ def floating_ip_disassociate(_context, address): floating_ip_ref.save(session=session) return fixed_ip_address +#@require_admin_context +def floating_ip_get_all(context): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all(_context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') ).filter_by(deleted=False ).all() +#@require_admin_context +def floating_ip_get_all_by_host(context, host): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all_by_host(_context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -328,7 +373,15 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_all_by_project(_context, project_id): +#@require_context +def floating_ip_get_all_by_project(context, project_id): + # TODO(devcamcar): Change to decorate and check project_id separately. + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -336,22 +389,38 @@ def floating_ip_get_all_by_project(_context, project_id): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) +#@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) + return result -def floating_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + + # floating_ip_ref = get_floating_ip_by_address(context, + # address, + # session=session) + # return floating_ip_ref.fixed_ip.instance ################### +#@require_context +def fixed_ip_associate(context, address, instance_id): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def fixed_ip_associate(_context, address, instance_id): session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -364,12 +433,17 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +#@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -386,14 +460,16 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] - +#@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -401,45 +477,56 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] - -def fixed_ip_disassociate(_context, address): +#@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) -def fixed_ip_get_by_address(_context, address): - session = get_session() - with session.begin(): - try: - return session.query(models.FixedIp - ).options(joinedload_all('instance') - ).filter_by(address=address - ).filter_by(deleted=False - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % address) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance +#@require_context +def fixed_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure floating ip belongs to user. + # Only possible if it is associated with an instance. + # May have to use system context for this always. + if not session: + session = get_session() + result = session.query(models.FixedIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('No floating ip for address %s' % address) + + return result + + +#@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +#@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): + +#@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -462,7 +549,9 @@ def instance_create(_context, values): instance_ref.save(session=session) return instance_ref + def instance_data_get_for_project(_context, project_id): + # TODO(devmcar): Admin only session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -474,6 +563,7 @@ def instance_data_get_for_project(_context, project_id): def instance_destroy(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -481,17 +571,21 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id, session=None): + # TODO(devcamcar): Support user context return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(deleted=_deleted(context) ).all() + def instance_get_all_by_user(context, user_id): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -499,7 +593,9 @@ def instance_get_all_by_user(context, user_id): ).filter_by(user_id=user_id ).all() + def instance_get_all_by_project(context, project_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -509,6 +605,7 @@ def instance_get_all_by_project(context, project_id): def instance_get_all_by_reservation(_context, reservation_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -518,6 +615,7 @@ def instance_get_all_by_reservation(_context, reservation_id): def instance_get_by_ec2_id(context, ec2_id): + # TODO(devcamcar): Support user context session = get_session() instance_ref = session.query(models.Instance ).filter_by(ec2_id=ec2_id @@ -536,6 +634,7 @@ def instance_ec2_id_exists(context, ec2_id, session=None): def instance_get_fixed_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -545,6 +644,7 @@ def instance_get_fixed_address(_context, instance_id): def instance_get_floating_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -557,6 +657,7 @@ def instance_get_floating_address(_context, instance_id): def instance_is_vpn(context, instance_id): + # TODO(devcamcar): Admin only # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id @@ -683,8 +784,8 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +def network_get(_context, network_id, session=None): + return models.Network.find(network_id, session=session) # NOTE(vish): pylint complains because of the long method name, but diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b9bb8e4f2..7a085c4df 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -441,19 +441,6 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" @@ -469,19 +456,6 @@ class FloatingIp(BASE, NovaBase): project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" diff --git a/nova/network/manager.py b/nova/network/manager.py index a7126ea4f..d125d28d8 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -232,7 +232,8 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + # TODO(devcamcar) Pass system context here. + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index da65b50a2..110e8430c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -84,6 +84,7 @@ class NetworkTestCase(test.TrialTestCase): def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: -- cgit From d32d95e08d67084ea04ccd1565ce6faffb1766ce Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:29:55 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 185 ++++++++++++++++++++++++++++++----------- nova/tests/compute_unittest.py | 2 + nova/tests/network_unittest.py | 4 +- 3 files changed, 141 insertions(+), 50 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d129df2be..9ab53b89b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -41,9 +41,10 @@ FLAGS = flags.FLAGS # pylint: disable-msg=C0111 def _deleted(context): """Calculates whether to include deleted objects based on context. - - Currently just looks for a flag called deleted in the context dict. + Currently just looks for a flag called deleted in the context dict. """ + if is_user_context(context): + return False if not hasattr(context, 'get'): return False return context.get('deleted', False) @@ -69,7 +70,7 @@ def is_user_context(context): ################### - +#@require_admin_context def service_destroy(context, service_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -80,6 +81,7 @@ def service_destroy(context, service_id): service_ref.delete(session=session) +#@require_admin_context def service_get(context, service_id, session=None): if not is_admin_context(context): raise exception.NotAuthorized() @@ -98,6 +100,7 @@ def service_get(context, service_id, session=None): return result +#@require_admin_context def service_get_all_by_topic(context, topic): if not is_admin_context(context): raise exception.NotAuthorized() @@ -110,6 +113,7 @@ def service_get_all_by_topic(context, topic): ).all() +#@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): if not is_admin_context(context): raise exception.NotAuthorized() @@ -124,6 +128,7 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() +#@require_admin_context def service_get_all_compute_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -151,6 +156,7 @@ def service_get_all_compute_sorted(context): label) +#@require_admin_context def service_get_all_network_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -171,6 +177,7 @@ def service_get_all_network_sorted(context): label) +#@require_admin_context def service_get_all_volume_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -191,6 +198,7 @@ def service_get_all_volume_sorted(context): label) +#@require_admin_context def service_get_by_args(context, host, binary): if not is_admin_context(context): raise exception.NotAuthorized() @@ -208,6 +216,7 @@ def service_get_by_args(context, host, binary): return result +#@require_admin_context def service_create(context, values): if not is_admin_context(context): return exception.NotAuthorized() @@ -219,6 +228,7 @@ def service_create(context, values): return service_ref +#@require_admin_context def service_update(context, service_id, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -234,12 +244,11 @@ def service_update(context, service_id, values): ################### +#@require_context def floating_ip_allocate_address(context, host, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() with session.begin(): @@ -259,6 +268,7 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] +#@require_context def floating_ip_create(context, values): if not is_user_context(context) and not is_admin_context(context): raise exception.NotAuthorized() @@ -270,12 +280,11 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] +#@require_context def floating_ip_count_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() return session.query(models.FloatingIp @@ -316,6 +325,7 @@ def floating_ip_deallocate(context, address): floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) + #@require_context def floating_ip_destroy(context, address): if not is_user_context(context) and not is_admin_context(context): @@ -330,6 +340,7 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) +#@require_context def floating_ip_disassociate(context, address): if not is_user_context(context) and is_admin_context(context): raise exception.NotAuthorized() @@ -350,6 +361,7 @@ def floating_ip_disassociate(context, address): floating_ip_ref.save(session=session) return fixed_ip_address + #@require_admin_context def floating_ip_get_all(context): if not is_admin_context(context): @@ -361,6 +373,7 @@ def floating_ip_get_all(context): ).filter_by(deleted=False ).all() + #@require_admin_context def floating_ip_get_all_by_host(context, host): if not is_admin_context(context): @@ -373,6 +386,7 @@ def floating_ip_get_all_by_host(context, host): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_all_by_project(context, project_id): # TODO(devcamcar): Change to decorate and check project_id separately. @@ -389,6 +403,7 @@ def floating_ip_get_all_by_project(context, project_id): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. @@ -408,14 +423,9 @@ def floating_ip_get_by_address(context, address, session=None): return result - # floating_ip_ref = get_floating_ip_by_address(context, - # address, - # session=session) - # return floating_ip_ref.fixed_ip.instance - - ################### + #@require_context def fixed_ip_associate(context, address, instance_id): if not is_user_context(context) and not is_admin_context(context): @@ -469,6 +479,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): session.add(fixed_ip_ref) return fixed_ip_ref['address'] + #@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() @@ -477,6 +488,7 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] + #@require_context def fixed_ip_disassociate(context, address): session = get_session() @@ -535,7 +547,8 @@ def fixed_ip_update(context, address, values): ################### -def instance_create(_context, values): +#@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -544,14 +557,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): - # TODO(devmcar): Admin only +#@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -562,21 +575,42 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) +#@require_context def instance_get(context, instance_id, session=None): - # TODO(devcamcar): Support user context - return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +#@require_admin_context def instance_get_all(context): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -584,8 +618,11 @@ def instance_get_all(context): ).all() +#@require_admin_context def instance_get_all_by_user(context, user_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -594,8 +631,12 @@ def instance_get_all_by_user(context, user_id): ).all() +#@require_context def instance_get_all_by_project(context, project_id): - # TODO(devcamcar): Support user context + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -604,50 +645,68 @@ def instance_get_all_by_project(context, project_id): ).all() -def instance_get_all_by_reservation(_context, reservation_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() +#@require_context def instance_get_by_ec2_id(context, ec2_id): - # TODO(devcamcar): Support user context session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not instance_ref: + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +#@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -656,14 +715,20 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +#@require_admin_context def instance_is_vpn(context, instance_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +#@require_admin_context def instance_set_state(context, instance_id, state, description=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -674,10 +739,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +#@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -686,6 +752,7 @@ def instance_update(_context, instance_id, values): ################### +#@require_context def key_pair_create(_context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -694,7 +761,8 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +#@require_context +def key_pair_destroy(context, user_id, name): session = get_session() with session.begin(): key_pair_ref = models.KeyPair.find_by_args(user_id, @@ -784,8 +852,27 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id, session=None): - return models.Network.find(network_id, session=session) +#@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but @@ -1066,7 +1153,7 @@ def volume_get(context, volume_id, session=None): ).first() elif is_user_context(context): result = session.query(models.Volume - ).filter_by(project_id=context.project.project_id + ).filter_by(project_id=context.project.id ).filter_by(id=volume_id ).filter_by(deleted=False ).first() diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..e705c2552 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -96,6 +96,8 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) + # TODO(devcamcar): Pass deleted in using system context. + # context.read_deleted ? instance_ref = db.instance_get({'deleted': True}, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 110e8430c..ca6a4bbc2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,7 +56,9 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = db.instance_create(None, {'mac_address': utils.generate_mac()}) self.instance_id = instance_ref['id'] -- cgit From ea5dcda819f2656589df177331f693f945d98f4a Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:35:24 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 32 +++++++++++++++++++++++++++++--- nova/tests/network_unittest.py | 1 + 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9ab53b89b..2d553d98d 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -794,11 +794,21 @@ def key_pair_get_all_by_user(_context, user_id): ################### -def network_count(_context): - return models.Network.count() +#@require_admin_context +def network_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + return session.query(models.Network + ).filter_by(deleted=deleted + ).count() + +#@require_admin_context def network_count_allocated_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -807,7 +817,11 @@ def network_count_allocated_ips(_context, network_id): ).count() +#@require_admin_context def network_count_available_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -817,7 +831,11 @@ def network_count_available_ips(_context, network_id): ).count() +#@require_admin_context def network_count_reserved_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -826,7 +844,11 @@ def network_count_reserved_ips(_context, network_id): ).count() +#@require_admin_context def network_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -834,7 +856,11 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +#@require_admin_context +def network_destroy(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index ca6a4bbc2..e01d7cff9 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,6 +49,7 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) + # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i -- cgit From e716990fd58521f8c0166330ec9bc62c7cd91b7e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:54:15 -0700 Subject: Finished context auth for network --- nova/db/sqlalchemy/api.py | 103 ++++++++++++++++++++++++++++++++-------------- nova/network/manager.py | 3 +- 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2d553d98d..23589b7d8 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -799,13 +799,14 @@ def network_count(context): if not is_admin_context(context): raise exception.NotAuthorized() + session = get_session() return session.query(models.Network - ).filter_by(deleted=deleted + ).filter_by(deleted=_deleted(context) ).count() #@require_admin_context -def network_count_allocated_ips(_context, network_id): +def network_count_allocated_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -818,7 +819,7 @@ def network_count_allocated_ips(_context, network_id): #@require_admin_context -def network_count_available_ips(_context, network_id): +def network_count_available_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -832,7 +833,7 @@ def network_count_available_ips(_context, network_id): #@require_admin_context -def network_count_reserved_ips(_context, network_id): +def network_count_reserved_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -845,7 +846,7 @@ def network_count_reserved_ips(_context, network_id): #@require_admin_context -def network_create(_context, values): +def network_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -904,7 +905,11 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +#@require_admin_context +def network_get_associated_fixed_ips(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -914,18 +919,28 @@ def network_get_associated_fixed_ips(_context, network_id): ).all() -def network_get_by_bridge(_context, bridge): +#@require_admin_context +def network_get_by_bridge(context, bridge): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + + return result -def network_get_index(_context, network_id): +#@require_admin_context +def network_get_index(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -933,19 +948,34 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +#@require_admin_context +def network_index_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def network_index_create_safe(context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def network_index_create_safe(_context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -955,29 +985,35 @@ def network_index_create_safe(_context, values): pass -def network_set_host(_context, network_id, host_id): +#@require_admin_context +def network_set_host(context, network_id, host_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + + return network_ref['host'] -def network_update(_context, network_id, values): +#@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -985,7 +1021,10 @@ def network_update(_context, network_id, values): ################### - +# YOU ARE HERE. +# random idea for system user: +# ctx = context.system_user(on_behalf_of=user, read_deleted=False) +# TODO(devcamcar): Rename to network_get_all_by_project def project_get_network(_context, project_id): session = get_session() rv = session.query(models.Network diff --git a/nova/network/manager.py b/nova/network/manager.py index d125d28d8..ecf2fa2c2 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,8 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + # TODO(devcamcar): Replace with system context + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) -- cgit From 98cac90592658773791eb15b19ed60adf0a57d96 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 00:36:10 -0700 Subject: Completed quota context auth --- nova/db/sqlalchemy/api.py | 103 +++++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 12 ----- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 23589b7d8..b225a6a88 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1021,19 +1021,22 @@ def network_update(context, network_id, values): ################### -# YOU ARE HERE. -# random idea for system user: -# ctx = context.system_user(on_behalf_of=user, read_deleted=False) -# TODO(devcamcar): Rename to network_get_all_by_project -def project_get_network(_context, project_id): + +#@require_context +def project_get_network(context, project_id): + if not is_admin_context(context) and not is_user_context(context): + raise error.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -1043,14 +1046,26 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +#@require_admin_context +def export_device_count(context): + if not is_admin_context(context): + raise exception.notauthorized() + + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def export_device_create(context, values): + if not is_admin_context(context): + raise exception.notauthorized() -def export_device_create(_context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1084,7 +1099,29 @@ def auth_create_token(_context, token): ################### +#@require_admin_context +def quota_get(context, project_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +#@require_admin_context def quota_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1092,29 +1129,34 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - +#@require_admin_context +def quota_update(context, project_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def quota_update(_context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +#@require_admin_context +def quota_destroy(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### +#@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1135,6 +1177,7 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) +#@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1149,6 +1192,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) +#@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1164,6 +1208,7 @@ def volume_create(context, values): return volume_ref +#@require_admin_context def volume_data_get_for_project(context, project_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1178,6 +1223,7 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) +#@require_admin_context def volume_destroy(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1192,6 +1238,7 @@ def volume_destroy(context, volume_id): {'id': volume_id}) +#@require_admin_context def volume_detached(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1206,6 +1253,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) +#@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1222,15 +1270,13 @@ def volume_get(context, volume_id, session=None): ).filter_by(id=volume_id ).filter_by(deleted=False ).first() - else: - raise exception.NotAuthorized() - if not result: raise exception.NotFound('No volume for id %s' % volume_id) return result +#@require_admin_context def volume_get_all(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1239,7 +1285,7 @@ def volume_get_all(context): ).filter_by(deleted=_deleted(context) ).all() - +#@require_context def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -1254,6 +1300,7 @@ def volume_get_all_by_project(context, project_id): ).all() +#@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1278,6 +1325,7 @@ def volume_get_by_ec2_id(context, ec2_id): return result +#@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() @@ -1286,10 +1334,9 @@ def volume_ec2_id_exists(context, ec2_id, session=None): return session.query(exists( ).where(models.Volume.id==ec2_id) ).one()[0] - else: - raise exception.NotAuthorized() +#@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1315,6 +1362,7 @@ def volume_get_instance(context, volume_id): return result.instance +#@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1329,15 +1377,14 @@ def volume_get_shelf_and_blade(context, volume_id): ).filter(models.Volume.project_id==context.project.id ).filter_by(volume_id=volume_id ).first() - else: - raise exception.NotAuthorized() - if not result: - raise exception.NotFound() + raise exception.NotFound('No export device found for volume %s' % + volume_id) return (result.shelf_id, result.blade_id) +#@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 7a085c4df..76444127f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -302,18 +302,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" -- cgit From 30541d48b17ab4626791d969388871b3a1b7758f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:07:05 -0700 Subject: Wired up context auth for keypairs --- nova/db/sqlalchemy/api.py | 46 +++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 20 ------------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b225a6a88..302322979 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -753,7 +753,7 @@ def instance_update(context, instance_id, values): #@require_context -def key_pair_create(_context, values): +def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): key_pair_ref[key] = value @@ -763,15 +763,22 @@ def key_pair_create(_context, values): #@require_context def key_pair_destroy(context, user_id, name): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): - key_pair_ref = models.KeyPair.find_by_args(user_id, - name, - session=session) + key_pair_ref = key_pair_get(context, user_id, name, session=session) key_pair_ref.delete(session=session) -def key_pair_destroy_all_by_user(_context, user_id): +#@require_context +def key_pair_destroy_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -779,11 +786,32 @@ def key_pair_destroy_all_by_user(_context, user_id): {'id': user_id}) -def key_pair_get(_context, user_id, name): - return models.KeyPair.find_by_args(user_id, name) +#@require_context +def key_pair_get(context, user_id, name, session=None): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result + +#@require_context +def key_pair_get_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() -def key_pair_get_all_by_user(_context, user_id): session = get_session() return session.query(models.KeyPair ).filter_by(user_id=user_id @@ -1118,7 +1146,7 @@ def quota_get(context, project_id, session=None): #@require_admin_context -def quota_create(_context, values): +def quota_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 76444127f..1f5bdf9f5 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -332,26 +332,6 @@ class KeyPair(BASE, NovaBase): def str_id(self): return '%s.%s' % (self.user_id, self.name) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - user_id, _sep, name = str_id.partition('.') - return cls.find_by_str(user_id, name, session, deleted) - - @classmethod - def find_by_args(cls, user_id, name, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for user %s, name %s" % - (user_id, name)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Network(BASE, NovaBase): """Represents a network""" -- cgit From 336523b36ceb8f5302acd443b7f1171b67575f73 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:11:16 -0700 Subject: Removed deprecated bits from NovaBase --- nova/db/sqlalchemy/models.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1f5bdf9f5..a29090c60 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -50,44 +50,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" -- cgit From 8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:02:14 -0700 Subject: Simplified authorization with decorators" " --- nova/db/sqlalchemy/api.py | 408 ++++++++++++++++------------------------------ 1 file changed, 142 insertions(+), 266 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 302322979..0e7d2e664 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -51,6 +51,7 @@ def _deleted(context): def is_admin_context(context): + """Indicates if the request context is an administrator.""" if not context: logging.warning('Use of empty request context is deprecated') return True @@ -60,6 +61,7 @@ def is_admin_context(context): def is_user_context(context): + """Indicates if the request context is a normal user.""" if not context: logging.warning('Use of empty request context is deprecated') return False @@ -68,24 +70,62 @@ def is_user_context(context): return True +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. + """ + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() + + +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. + """ + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + ################### -#@require_admin_context +@require_admin_context def service_destroy(context, service_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -#@require_admin_context +@require_admin_context def service_get(context, service_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -100,11 +140,8 @@ def service_get(context, service_id, session=None): return result -#@require_admin_context +@require_admin_context def service_get_all_by_topic(context, topic): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -113,11 +150,8 @@ def service_get_all_by_topic(context, topic): ).all() -#@require_admin_context +@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): - if not is_admin_context(context): - raise exception.NotAuthorized() - sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -128,11 +162,8 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() -#@require_admin_context +@require_admin_context def service_get_all_compute_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -156,11 +187,8 @@ def service_get_all_compute_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_network_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'network' @@ -177,11 +205,8 @@ def service_get_all_network_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_volume_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'volume' @@ -198,11 +223,8 @@ def service_get_all_volume_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_by_args(context, host, binary): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Service ).filter_by(host=host @@ -216,11 +238,8 @@ def service_get_by_args(context, host, binary): return result -#@require_admin_context +@require_admin_context def service_create(context, values): - if not is_admin_context(context): - return exception.NotAuthorized() - service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -228,11 +247,8 @@ def service_create(context, values): return service_ref -#@require_admin_context +@require_admin_context def service_update(context, service_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = session_get(context, service_id, session=session) @@ -244,11 +260,9 @@ def service_update(context, service_id, values): ################### -#@require_context +@require_context def floating_ip_allocate_address(context, host, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() with session.begin(): @@ -268,11 +282,8 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_create(context, values): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -280,11 +291,9 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_count_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -293,11 +302,8 @@ def floating_ip_count_by_project(context, project_id): ).count() -#@require_context +@require_context def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating_id belongs to user? @@ -311,11 +317,8 @@ def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_deallocate(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating id belongs to user? @@ -326,11 +329,8 @@ def floating_ip_deallocate(context, address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_destroy(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -340,11 +340,8 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) -#@require_context +@require_context def floating_ip_disassociate(context, address): - if not is_user_context(context) and is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -362,11 +359,8 @@ def floating_ip_disassociate(context, address): return fixed_ip_address -#@require_admin_context +@require_admin_context def floating_ip_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -374,11 +368,8 @@ def floating_ip_get_all(context): ).all() -#@require_admin_context +@require_admin_context def floating_ip_get_all_by_host(context, host): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -387,14 +378,9 @@ def floating_ip_get_all_by_host(context, host): ).all() -#@require_context +@require_context def floating_ip_get_all_by_project(context, project_id): - # TODO(devcamcar): Change to decorate and check project_id separately. - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -404,12 +390,9 @@ def floating_ip_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -426,11 +409,8 @@ def floating_ip_get_by_address(context, address, session=None): ################### -#@require_context +@require_context def fixed_ip_associate(context, address, instance_id): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -449,11 +429,8 @@ def fixed_ip_associate(context, address, instance_id): session.add(fixed_ip_ref) -#@require_admin_context +@require_admin_context def fixed_ip_associate_pool(context, network_id, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -480,7 +457,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -489,7 +466,7 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): @@ -500,7 +477,7 @@ def fixed_ip_disassociate(context, address): fixed_ip_ref.save(session=session) -#@require_context +@require_context def fixed_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure floating ip belongs to user. # Only possible if it is associated with an instance. @@ -520,19 +497,19 @@ def fixed_ip_get_by_address(context, address, session=None): return result -#@require_context +@require_context def fixed_ip_get_instance(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.instance -#@require_admin_context +@require_admin_context def fixed_ip_get_network(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.network -#@require_context +@require_context def fixed_ip_update(context, address, values): session = get_session() with session.begin(): @@ -547,7 +524,7 @@ def fixed_ip_update(context, address, values): ################### -#@require_context +@require_context def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): @@ -563,7 +540,7 @@ def instance_create(context, values): return instance_ref -#@require_admin_context +@require_admin_context def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), @@ -575,7 +552,7 @@ def instance_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_context +@require_context def instance_destroy(context, instance_id): session = get_session() with session.begin(): @@ -583,7 +560,7 @@ def instance_destroy(context, instance_id): instance_ref.delete(session=session) -#@require_context +@require_context def instance_get(context, instance_id, session=None): if not session: session = get_session() @@ -606,11 +583,8 @@ def instance_get(context, instance_id, session=None): return result -#@require_admin_context +@require_admin_context def instance_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -618,11 +592,8 @@ def instance_get_all(context): ).all() -#@require_admin_context +@require_admin_context def instance_get_all_by_user(context, user_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -631,11 +602,9 @@ def instance_get_all_by_user(context, user_id): ).all() -#@require_context +@require_context def instance_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Instance @@ -645,7 +614,7 @@ def instance_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def instance_get_all_by_reservation(context, reservation_id): session = get_session() @@ -664,7 +633,7 @@ def instance_get_all_by_reservation(context, reservation_id): ).all() -#@require_context +@require_context def instance_get_by_ec2_id(context, ec2_id): session = get_session() @@ -685,14 +654,14 @@ def instance_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -#@require_context +@require_context def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): @@ -702,7 +671,7 @@ def instance_get_fixed_address(context, instance_id): return instance_ref.fixed_ip['address'] -#@require_context +@require_context def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): @@ -715,20 +684,15 @@ def instance_get_floating_address(context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] -#@require_admin_context +@require_admin_context def instance_is_vpn(context, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id -#@require_admin_context +@require_admin_context def instance_set_state(context, instance_id, state, description=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -739,7 +703,7 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -#@require_context +@require_context def instance_update(context, instance_id, values): session = get_session() with session.begin(): @@ -752,7 +716,7 @@ def instance_update(context, instance_id, values): ################### -#@require_context +@require_context def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -761,11 +725,9 @@ def key_pair_create(context, values): return key_pair_ref -#@require_context +@require_context def key_pair_destroy(context, user_id, name): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -773,11 +735,9 @@ def key_pair_destroy(context, user_id, name): key_pair_ref.delete(session=session) -#@require_context +@require_context def key_pair_destroy_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -786,11 +746,9 @@ def key_pair_destroy_all_by_user(context, user_id): {'id': user_id}) -#@require_context +@require_context def key_pair_get(context, user_id, name, session=None): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) if not session: session = get_session() @@ -806,11 +764,9 @@ def key_pair_get(context, user_id, name, session=None): return result -#@require_context +@require_context def key_pair_get_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() return session.query(models.KeyPair @@ -822,22 +778,16 @@ def key_pair_get_all_by_user(context, user_id): ################### -#@require_admin_context +@require_admin_context def network_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Network ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_count_allocated_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -846,11 +796,8 @@ def network_count_allocated_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_available_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -860,11 +807,8 @@ def network_count_available_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_reserved_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -873,11 +817,8 @@ def network_count_reserved_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -885,11 +826,8 @@ def network_create(context, values): return network_ref -#@require_admin_context +@require_admin_context def network_destroy(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -907,7 +845,7 @@ def network_destroy(context, network_id): {'id': network_id}) -#@require_context +@require_context def network_get(context, network_id, session=None): if not session: session = get_session() @@ -933,11 +871,8 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -#@require_admin_context +@require_admin_context def network_get_associated_fixed_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -947,11 +882,8 @@ def network_get_associated_fixed_ips(context, network_id): ).all() -#@require_admin_context +@require_admin_context def network_get_by_bridge(context, bridge): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Network ).filter_by(bridge=bridge @@ -964,11 +896,8 @@ def network_get_by_bridge(context, bridge): return result -#@require_admin_context +@require_admin_context def network_get_index(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -988,22 +917,16 @@ def network_get_index(context, network_id): return network_index['index'] -#@require_admin_context +@require_admin_context def network_index_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.NetworkIndex ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_index_create_safe(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -1013,11 +936,8 @@ def network_index_create_safe(context, values): pass -#@require_admin_context +@require_admin_context def network_set_host(context, network_id, host_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_ref = session.query(models.Network @@ -1037,7 +957,7 @@ def network_set_host(context, network_id, host_id): return network_ref['host'] -#@require_context +@require_context def network_update(context, network_id, values): session = get_session() with session.begin(): @@ -1050,11 +970,8 @@ def network_update(context, network_id, values): ################### -#@require_context +@require_context def project_get_network(context, project_id): - if not is_admin_context(context) and not is_user_context(context): - raise error.NotAuthorized() - session = get_session() result= session.query(models.Network ).filter_by(project_id=project_id @@ -1078,22 +995,16 @@ def queue_get_for(_context, topic, physical_node_id): ################### -#@require_admin_context +@require_admin_context def export_device_count(context): - if not is_admin_context(context): - raise exception.notauthorized() - session = get_session() return session.query(models.ExportDevice ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def export_device_create(context, values): - if not is_admin_context(context): - raise exception.notauthorized() - export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1127,11 +1038,8 @@ def auth_create_token(_context, token): ################### -#@require_admin_context +@require_admin_context def quota_get(context, project_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -1145,11 +1053,8 @@ def quota_get(context, project_id, session=None): return result -#@require_admin_context +@require_admin_context def quota_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1157,11 +1062,8 @@ def quota_create(context, values): return quota_ref -#@require_admin_context +@require_admin_context def quota_update(context, project_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1170,11 +1072,8 @@ def quota_update(context, project_id, values): quota_ref.save(session=session) -#@require_admin_context +@require_admin_context def quota_destroy(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1184,11 +1083,8 @@ def quota_destroy(context, project_id): ################### -#@require_admin_context +@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -1205,11 +1101,8 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) -#@require_admin_context +@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1220,7 +1113,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -#@require_context +@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1236,11 +1129,8 @@ def volume_create(context, values): return volume_ref -#@require_admin_context +@require_admin_context def volume_data_get_for_project(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -1251,11 +1141,8 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_admin_context +@require_admin_context def volume_destroy(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -1266,11 +1153,8 @@ def volume_destroy(context, volume_id): {'id': volume_id}) -#@require_admin_context +@require_admin_context def volume_detached(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1281,7 +1165,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) -#@require_context +@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1304,22 +1188,15 @@ def volume_get(context, volume_id, session=None): return result -#@require_admin_context +@require_admin_context def volume_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - return session.query(models.Volume ).filter_by(deleted=_deleted(context) ).all() -#@require_context +@require_context def volume_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Volume @@ -1328,7 +1205,7 @@ def volume_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1353,18 +1230,17 @@ def volume_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - if is_admin_context(context) or is_user_context(context): - return session.query(exists( - ).where(models.Volume.id==ec2_id) - ).one()[0] + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -#@require_context +@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1390,7 +1266,7 @@ def volume_get_instance(context, volume_id): return result.instance -#@require_context +@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1412,7 +1288,7 @@ def volume_get_shelf_and_blade(context, volume_id): return (result.shelf_id, result.blade_id) -#@require_context +@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): -- cgit From cf456bdb2a767644d95599aa1c8f580279959a4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:47:05 -0700 Subject: Refactored APIRequestContext --- nova/api/context.py | 46 +++++++++++++++++++++++++++ nova/api/ec2/__init__.py | 8 ++--- nova/api/ec2/context.py | 33 -------------------- nova/db/sqlalchemy/api.py | 71 +++++++++++++++++++----------------------- nova/network/manager.py | 2 -- nova/tests/compute_unittest.py | 8 ++--- 6 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 nova/api/context.py delete mode 100644 nova/api/ec2/context.py diff --git a/nova/api/context.py b/nova/api/context.py new file mode 100644 index 000000000..b66cfe468 --- /dev/null +++ b/nova/api/context.py @@ -0,0 +1,46 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + + +class APIRequestContext(object): + def __init__(self, user, project): + self.user = user + self.project = project + self.request_id = ''.join( + [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') + for x in xrange(20)] + ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -27,8 +27,8 @@ import webob.exc from nova import exception from nova import flags from nova import wsgi +from nova.api import context from nova.api.ec2 import apirequest -from nova.api.ec2 import context from nova.api.ec2 import admin from nova.api.ec2 import cloud from nova.auth import manager @@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user.id, role) for role in roles) - + class Executor(wsgi.Application): """Execute an EC2 API request. - Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and 'ec2.action_args' (all variables in WSGI environ.) Returns an XML response, or a 400 upon failure. """ diff --git a/nova/api/ec2/context.py b/nova/api/ec2/context.py deleted file mode 100644 index c53ba98d9..000000000 --- a/nova/api/ec2/context.py +++ /dev/null @@ -1,33 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -APIRequestContext -""" - -import random - - -class APIRequestContext(object): - def __init__(self, user, project): - self.user = user - self.project = project - self.request_id = ''.join( - [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') - for x in xrange(20)] - ) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 0e7d2e664..fc5ee2235 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -21,6 +21,7 @@ Implementation of SQLAlchemy backend import logging import sys +import warnings from nova import db from nova import exception @@ -36,28 +37,13 @@ from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. - Currently just looks for a flag called deleted in the context dict. - """ - if is_user_context(context): - return False - if not hasattr(context, 'get'): - return False - return context.get('deleted', False) - - def is_admin_context(context): """Indicates if the request context is an administrator.""" if not context: - logging.warning('Use of empty request context is deprecated') - return True - if not context.user: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) return True - return context.user.is_admin() + return context.is_admin def is_user_context(context): @@ -92,6 +78,13 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() +def use_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: + return False + return context.read_deleted + + def require_admin_context(f): """Decorator used to indicate that the method requires an administrator context. @@ -131,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -229,7 +222,7 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -398,7 +391,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -487,7 +480,7 @@ def fixed_ip_get_by_address(context, address, session=None): result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() @@ -569,7 +562,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -588,7 +581,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -597,7 +590,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).filter_by(user_id=user_id ).all() @@ -610,7 +603,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -622,7 +615,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -640,7 +633,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -756,7 +749,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -782,7 +775,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -854,7 +847,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -921,7 +914,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -999,7 +992,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -1045,7 +1038,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1174,7 +1167,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1191,7 +1184,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @require_context @@ -1201,7 +1194,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -1213,7 +1206,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1248,7 +1241,7 @@ def volume_get_instance(context, volume_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('instance') ).first() elif is_user_context(context): diff --git a/nova/network/manager.py b/nova/network/manager.py index ecf2fa2c2..265c0d742 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,6 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - # TODO(devcamcar): Replace with system context host = self.db.network_set_host(None, network_id, self.host) @@ -233,7 +232,6 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - # TODO(devcamcar) Pass system context here. address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index e705c2552..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,9 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - # TODO(devcamcar): Pass deleted in using system context. - # context.read_deleted ? - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) -- cgit From ab948224a5c6ea976def30927ac7668dd765dbca Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 03:13:47 -0700 Subject: Cleaned up db/api.py --- nova/db/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 4cfdd788c..290b460a6 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,10 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### -- cgit From c9e14d6257f0b488bd892c09d284091c0f612dd7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 1 Oct 2010 01:44:17 -0700 Subject: Locked down fixed ips and improved network tests --- nova/db/sqlalchemy/api.py | 98 ++++++++++++++++-------------------------- nova/tests/network_unittest.py | 44 ++++++++++--------- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index fc5ee2235..860723516 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -78,7 +78,7 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() -def use_deleted(context): +def can_read_deleted(context): """Indicates if the context has access to deleted objects.""" if not context: return False @@ -124,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: @@ -222,9 +222,8 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() - if not result: raise exception.NotFound('No service for %s, %s' % (host, binary)) @@ -256,7 +255,6 @@ def service_update(context, service_id, values): @require_context def floating_ip_allocate_address(context, host, project_id): authorize_project_context(context, project_id) - session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -287,7 +285,6 @@ def floating_ip_create(context, values): @require_context def floating_ip_count_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -374,7 +371,6 @@ def floating_ip_get_all_by_host(context, host): @require_context def floating_ip_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -391,7 +387,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -406,6 +402,7 @@ def floating_ip_get_by_address(context, address, session=None): def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -416,9 +413,7 @@ def fixed_ip_associate(context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = instance_get(context, - instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) @@ -472,21 +467,21 @@ def fixed_ip_disassociate(context, address): @require_context def fixed_ip_get_by_address(context, address, session=None): - # TODO(devcamcar): Ensure floating ip belongs to user. - # Only possible if it is associated with an instance. - # May have to use system context for this always. if not session: session = get_session() result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() if not result: raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) + return result @@ -562,7 +557,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -581,7 +576,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -590,7 +585,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -603,7 +598,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -615,7 +610,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -633,7 +628,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -749,7 +744,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -775,7 +770,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -847,7 +842,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -914,7 +909,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -992,7 +987,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -1038,7 +1033,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1167,7 +1162,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1184,7 +1179,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @require_context @@ -1194,7 +1189,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1206,7 +1201,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1233,47 +1228,26 @@ def volume_ec2_id_exists(context, ec2_id, session=None): ).one()[0] -@require_context +@require_admin_context def volume_get_instance(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) - ).options(joinedload('instance') - ).first() - elif is_user_context(context): - result = session.query(models.Volume - ).filter_by(project_id=context.project.id - ).filter_by(deleted=False - ).options(joinedload('instance') - ).first() - else: - raise exception.NotAuthorized() - + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() if not result: raise exception.NotFound('Volume %s not found' % ec2_id) return result.instance -@require_context +@require_admin_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - elif is_user_context(context): - result = session.query(models.ExportDevice - ).join(models.Volume - ).filter(models.Volume.project_id==context.project.id - ).filter_by(volume_id=volume_id - ).first() + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() if not result: raise exception.NotFound('No export device found for volume %s' % volume_id) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index e01d7cff9..e601c480c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,7 +49,6 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) - # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i @@ -60,11 +59,9 @@ class NetworkTestCase(test.TrialTestCase): user_context = context.APIRequestContext(project=self.projects[i], user=self.user) self.network.set_network_host(user_context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -77,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -84,6 +90,11 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips @@ -134,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -152,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -185,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -246,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) -- cgit