From ab2bed9ed60c5333a0f9ba3e679df9893781b72f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 10:39:52 +0200 Subject: Apply IP configuration to bridge regardless of whether it existed before. The fixes a race condition on hosts running both compute and network where, if compute got there first, it would set up the bridge, but not do IP configuration (because that's meant to happen on the network host), and when network came around, it would see the interface already there and not configure it further. --- nova/network/linux_net.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 41aeb5da7..9d5bd8495 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -118,15 +118,16 @@ def ensure_bridge(bridge, interface, net_attrs=None): # _execute("sudo brctl setageing %s 10" % bridge) _execute("sudo brctl stp %s off" % bridge) _execute("sudo brctl addif %s %s" % (bridge, interface)) - if net_attrs: - _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ - (bridge, - net_attrs['gateway'], - net_attrs['broadcast'], - net_attrs['netmask'])) - _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge) - else: - _execute("sudo ifconfig %s up" % bridge) + + if net_attrs: + _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ + (bridge, + net_attrs['gateway'], + net_attrs['broadcast'], + net_attrs['netmask'])) + _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge) + else: + _execute("sudo ifconfig %s up" % bridge) def get_dhcp_hosts(context, network_id): -- cgit From b4dbc4efa576af61ddc26d1c277237ad4bcdfcfa Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:07:55 +0200 Subject: Add db api methods for retrieving the networks for which a host is the designated network host. --- nova/db/api.py | 12 ++++++++++++ nova/db/sqlalchemy/api.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/nova/db/api.py b/nova/db/api.py index c1cb1953a..4657408db 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -554,3 +554,15 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + + +################### + + +def host_get_networks(context, host): + """Return all networks for which the given host is the designated + network host + """ + return IMPL.host_get_networks(context, host) + + diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2b0dd6ea6..6e6b0e3fc 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -848,3 +848,15 @@ def volume_update(_context, volume_id, values): for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) + + +################### + + +def host_get_networks(context, host): + session = get_session() + with session.begin(): + return session.query(models.Network + ).filter_by(deleted=False + ).filter_by(host=host + ).all() -- cgit From e70948dbec0b21664739b2b7cdb1cc3da92bd01b Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:08:40 +0200 Subject: Set up network at manager instantiation time to ensure we're ready to handle the networks we're already supposed to handle. --- nova/network/manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nova/network/manager.py b/nova/network/manager.py index 191c1d364..c17823f1e 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -80,6 +80,10 @@ class NetworkManager(manager.Manager): network_driver = FLAGS.network_driver self.driver = utils.import_object(network_driver) super(NetworkManager, self).__init__(*args, **kwargs) + # Set up networking for the projects for which we're already + # the designated network host. + for network in self.db.host_get_networks(None, host=kwargs['host']): + self._on_set_network_host(None, network['id']) def set_network_host(self, context, project_id): """Safely sets the host of the projects network""" -- cgit From 47cccfc21dfd4c1acf74b6d84ced8abba8c40e76 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:14:20 +0200 Subject: Ensure dnsmasq can read updates to dnsmasq conffile. --- nova/network/linux_net.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 9d5bd8495..7d708968c 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -150,9 +150,14 @@ def update_dhcp(context, network_id): signal causing it to reload, otherwise spawn a new instance """ network_ref = db.network_get(context, network_id) - with open(_dhcp_file(network_ref['vlan'], 'conf'), 'w') as f: + + conffile = _dhcp_file(network_ref['vlan'], 'conf') + with open(conffile, 'w') as f: f.write(get_dhcp_hosts(context, network_id)) + # Make sure dnsmasq can actually read it (it setuid()s to "nobody") + os.chmod(conffile, 0644) + pid = _dnsmasq_pid_for(network_ref['vlan']) # if dnsmasq is already running, then tell it to reload -- cgit From 928df580e5973bc1fd3871a0aa31886302bb9268 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 13:03:29 +0200 Subject: Add a flag the specifies where to find nova-dhcpbridge. --- nova/network/linux_net.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 7d708968c..bfa73dca0 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -28,6 +28,11 @@ from nova import flags from nova import utils +def _bin_file(script): + """Return the absolute path to scipt in the bin directory""" + return os.path.abspath(os.path.join(__file__, "../../../bin", script)) + + FLAGS = flags.FLAGS flags.DEFINE_string('dhcpbridge_flagfile', '/etc/nova/nova-dhcpbridge.conf', @@ -39,6 +44,8 @@ flags.DEFINE_string('public_interface', 'vlan1', 'Interface for public IP addresses') flags.DEFINE_string('bridge_dev', 'eth0', 'network device for bridges') +flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'), + 'location of nova-dhcpbridge') DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)] @@ -222,7 +229,7 @@ def _dnsmasq_cmd(net): ' --except-interface=lo', ' --dhcp-range=%s,static,120s' % net['dhcp_start'], ' --dhcp-hostsfile=%s' % _dhcp_file(net['vlan'], 'conf'), - ' --dhcp-script=%s' % _bin_file('nova-dhcpbridge'), + ' --dhcp-script=%s' % FLAGS.dhcpbridge, ' --leasefile-ro'] return ''.join(cmd) @@ -244,11 +251,6 @@ def _dhcp_file(vlan, kind): return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) -def _bin_file(script): - """Return the absolute path to scipt in the bin directory""" - return os.path.abspath(os.path.join(__file__, "../../../bin", script)) - - def _dnsmasq_pid_for(vlan): """Returns he pid for prior dnsmasq instance for a vlan -- cgit From 04fa25e63bf37222d2b1cf88837f1c85cf944f54 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 13:23:39 +0200 Subject: Only call _on_set_network_host on nova-network hosts. --- nova/network/manager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nova/network/manager.py b/nova/network/manager.py index c17823f1e..2530f04b7 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -80,10 +80,13 @@ class NetworkManager(manager.Manager): network_driver = FLAGS.network_driver self.driver = utils.import_object(network_driver) super(NetworkManager, self).__init__(*args, **kwargs) - # Set up networking for the projects for which we're already - # the designated network host. - for network in self.db.host_get_networks(None, host=kwargs['host']): - self._on_set_network_host(None, network['id']) + # Host only gets passed if being instantiated as part of the network + # service. + if 'host' in kwargs: + # Set up networking for the projects for which we're already + # the designated network host. + for network in self.db.host_get_networks(None, host=kwargs['host']): + self._on_set_network_host(None, network['id']) def set_network_host(self, context, project_id): """Safely sets the host of the projects network""" -- cgit From afc782e0e80a71ac8d1eb2f1d70e67375ba62aca Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 10:59:55 +0200 Subject: Make sure we also start dnsmasq on startup if we're managing networks. --- nova/network/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nova/network/manager.py b/nova/network/manager.py index 2530f04b7..20d4fe0f7 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -358,6 +358,7 @@ class VlanManager(NetworkManager): self.driver.ensure_vlan_bridge(network_ref['vlan'], network_ref['bridge'], network_ref) + self.driver.update_dhcp(context, network_id) @property def _bottom_reserved_ips(self): -- cgit From 687a90d6a7ad947c4a5851b1766a19209bb5e46f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 11:09:40 +0200 Subject: Call out to 'sudo kill' instead of using os.kill. dnsmasq runs as root or nobody, nova may or may not be running as root, so os.kill won't work. --- nova/network/linux_net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index bfa73dca0..50d2831c3 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -172,7 +172,7 @@ def update_dhcp(context, network_id): # TODO(ja): use "/proc/%d/cmdline" % (pid) to determine if pid refers # correct dnsmasq process try: - os.kill(pid, signal.SIGHUP) + _execute('sudo kill -HUP %d' % pid) return except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Hupping dnsmasq threw %s", exc) @@ -240,7 +240,7 @@ def _stop_dnsmasq(network): if pid: try: - os.kill(pid, signal.SIGTERM) + _execute('sudo kill -TERM %d' % pid) except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Killing dnsmasq threw %s", exc) -- cgit From fe139bbdee60aadd720cb7a83d0846f2824c078f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 00:49:04 -0700 Subject: Began wiring up context authorization --- nova/db/sqlalchemy/api.py | 50 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..b5847d299 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,6 +19,7 @@ Implementation of SQLAlchemy backend """ +import logging import sys from nova import db @@ -48,6 +49,24 @@ def _deleted(context): return context.get('deleted', False) +def is_admin_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return True + if not context.user: + return True + return context.user.is_admin() + + +def is_user_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return False + if not context.user or not context.project: + return False + return True + + ################### @@ -869,14 +888,41 @@ def volume_detached(_context, volume_id): def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) + session = get_session() + + if is_admin_context(context): + volume_ref = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + if is_user_context(context): + volume_ref = session.query(models.Volume + ).filter_by(project_id=project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + raise exception.NotAuthorized() def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) + if is_admin_context(context): + return models.Volume.all(deleted=_deleted(context)) + raise exception.NotAuthorized() def volume_get_all_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id -- cgit From e258998923b7e8fa92656aa409f875b640df930c Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 13:26:14 -0700 Subject: Progress on volumes Fixed foreign keys to respect deleted flag --- nova/db/sqlalchemy/api.py | 130 ++++++++++++++++++++++++++++++------------- nova/db/sqlalchemy/models.py | 35 +++++++++--- 2 files changed, 118 insertions(+), 47 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b5847d299..28b937233 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -30,7 +30,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -811,6 +811,7 @@ def quota_destroy(_context, project_id): def volume_allocate_shelf_and_blade(_context, volume_id): + # TODO(devcamcar): Make admin only session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -839,7 +840,7 @@ def volume_attached(_context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -def volume_create(_context, values): +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value @@ -848,7 +849,7 @@ def volume_create(_context, values): with session.begin(): while volume_ref.ec2_id == None: ec2_id = utils.generate_uid(volume_ref.__prefix__) - if not volume_ec2_id_exists(_context, ec2_id, session=session): + if not volume_ec2_id_exists(context, ec2_id, session=session): volume_ref.ec2_id = ec2_id volume_ref.save(session=session) return volume_ref @@ -876,10 +877,10 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -887,27 +888,29 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - session = get_session() +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None if is_admin_context(context): - volume_ref = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - if is_user_context(context): - volume_ref = session.query(models.Volume - ).filter_by(project_id=project_id - ).filter_by(id=volume_id - ).filter_by(deleted=False - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) - raise exception.NotAuthorized() + return result def volume_get_all(context): @@ -916,6 +919,7 @@ def volume_get_all(context): raise exception.NotAuthorized() + def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -932,42 +936,92 @@ def volume_get_all_by_project(context, project_id): def volume_get_by_ec2_id(context, ec2_id): session = get_session() - volume_ref = session.query(models.Volume + result = None + + if is_admin_context(context): + result = session.query(models.Volume ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not volume_ref: - raise exception.NotFound('Volume %s not found' % (ec2_id)) + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - return volume_ref + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Volume.id==ec2_id)).one()[0] + + if is_admin_context(context) or is_user_context(context): + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] + else: + raise exception.NotAuthorized() -def volume_get_instance(_context, volume_id): +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).options(joinedload('instance') + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(deleted=False + ).options(joinedload('instance') + ).first() + else: + raise exception.NotAuthorized() + + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: + result = None + + if is_admin_context(context): + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + elif is_user_context(context): + result = session.query(models.ExportDevice + ).join(models.Volume + ).filter(models.Volume.project_id==context.project.id + ).filter_by(volume_id=volume_id + ).first() + else: + raise exception.NotAuthorized() + + if not result: raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): + +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..1b9edf475 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -282,7 +282,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -333,8 +337,11 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') class KeyPair(BASE, NovaBase): @@ -407,8 +414,12 @@ class NetworkIndex(BASE, NovaBase): id = Column(Integer, primary_key=True) index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + class AuthToken(BASE, NovaBase): """Represents an authorization token for all API transactions. Fields @@ -432,8 +443,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -462,8 +476,11 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) -- cgit From f4cf49ec3761bdd38dd1a6cb064875b90e65ad4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 14:27:31 -0700 Subject: Wired up context auth for services --- nova/db/sqlalchemy/api.py | 111 ++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 15 ------ 2 files changed, 89 insertions(+), 37 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 28b937233..01a5af38b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -71,16 +71,37 @@ def is_user_context(context): def service_destroy(context, service_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) + +def service_get(context, service_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result def service_get_all_by_topic(context, topic): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -89,7 +110,10 @@ def service_get_all_by_topic(context, topic): ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +def _service_get_all_topic_subquery(context, session, topic, subq, label): + if not is_admin_context(context): + raise exception.NotAuthorized() + sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -101,6 +125,9 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label): def service_get_all_compute_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -125,6 +152,9 @@ def service_get_all_compute_sorted(context): def service_get_all_network_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'network' @@ -142,6 +172,9 @@ def service_get_all_network_sorted(context): def service_get_all_volume_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'volume' @@ -158,11 +191,27 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +def service_get_by_args(context, host, binary): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result + +def service_create(context, values): + if not is_admin_context(context): + return exception.NotAuthorized() -def service_create(_context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -170,10 +219,13 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +def service_update(context, service_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = session_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -428,8 +480,8 @@ def instance_destroy(_context, instance_id): instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +def instance_get(context, instance_id, session=None): + return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): @@ -810,8 +862,10 @@ def quota_destroy(_context, project_id): ################### -def volume_allocate_shelf_and_blade(_context, volume_id): - # TODO(devcamcar): Make admin only +def volume_allocate_shelf_and_blade(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -828,15 +882,17 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +def volume_attached(context, volume_id, instance_id, mountpoint): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) @@ -855,7 +911,10 @@ def volume_create(context, values): return volume_ref -def volume_data_get_for_project(_context, project_id): +def volume_data_get_for_project(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -866,7 +925,10 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +def volume_destroy(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -878,6 +940,9 @@ def volume_destroy(_context, volume_id): def volume_detached(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -914,10 +979,12 @@ def volume_get(context, volume_id, session=None): def volume_get_all(context): - if is_admin_context(context): - return models.Volume.all(deleted=_deleted(context)) + if not is_admin_context(context): + raise exception.NotAuthorized() - raise exception.NotAuthorized() + return session.query(models.Volume + ).filter_by(deleted=_deleted(context) + ).all() def volume_get_all_by_project(context, project_id): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1b9edf475..b9bb8e4f2 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -176,21 +176,6 @@ class Service(BASE, NovaBase): report_count = Column(Integer, nullable=False, default=0) disabled = Column(Boolean, default=False) - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Instance(BASE, NovaBase): """Represents a guest vm""" -- cgit From 734df1fbad8195e7cd7072d0d0aeb5b94841f121 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 19:09:00 -0700 Subject: Made network tests pass again --- nova/db/api.py | 1 - nova/db/sqlalchemy/api.py | 233 +++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 26 ----- nova/network/manager.py | 3 +- nova/tests/network_unittest.py | 1 + 5 files changed, 170 insertions(+), 94 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..4cfdd788c 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,7 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): """Get an instance for a floating ip by address.""" return IMPL.floating_ip_get_instance(context, address) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 01a5af38b..d129df2be 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -234,7 +234,13 @@ def service_update(context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +def floating_ip_allocate_address(context, host, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -253,7 +259,10 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +def floating_ip_create(context, values): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -261,7 +270,13 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +def floating_ip_count_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -269,39 +284,63 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +#@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +#@require_context +def floating_ip_deallocate(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) +#@require_context +def floating_ip_destroy(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_destroy(_context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +def floating_ip_disassociate(context, address): + if not is_user_context(context) and is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -311,16 +350,22 @@ def floating_ip_disassociate(_context, address): floating_ip_ref.save(session=session) return fixed_ip_address +#@require_admin_context +def floating_ip_get_all(context): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all(_context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') ).filter_by(deleted=False ).all() +#@require_admin_context +def floating_ip_get_all_by_host(context, host): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all_by_host(_context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -328,7 +373,15 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_all_by_project(_context, project_id): +#@require_context +def floating_ip_get_all_by_project(context, project_id): + # TODO(devcamcar): Change to decorate and check project_id separately. + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -336,22 +389,38 @@ def floating_ip_get_all_by_project(_context, project_id): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) +#@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) + return result -def floating_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + + # floating_ip_ref = get_floating_ip_by_address(context, + # address, + # session=session) + # return floating_ip_ref.fixed_ip.instance ################### +#@require_context +def fixed_ip_associate(context, address, instance_id): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def fixed_ip_associate(_context, address, instance_id): session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -364,12 +433,17 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +#@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -386,14 +460,16 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] - +#@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -401,45 +477,56 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] - -def fixed_ip_disassociate(_context, address): +#@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) -def fixed_ip_get_by_address(_context, address): - session = get_session() - with session.begin(): - try: - return session.query(models.FixedIp - ).options(joinedload_all('instance') - ).filter_by(address=address - ).filter_by(deleted=False - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % address) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance +#@require_context +def fixed_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure floating ip belongs to user. + # Only possible if it is associated with an instance. + # May have to use system context for this always. + if not session: + session = get_session() + result = session.query(models.FixedIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('No floating ip for address %s' % address) + + return result + + +#@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +#@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): + +#@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -462,7 +549,9 @@ def instance_create(_context, values): instance_ref.save(session=session) return instance_ref + def instance_data_get_for_project(_context, project_id): + # TODO(devmcar): Admin only session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -474,6 +563,7 @@ def instance_data_get_for_project(_context, project_id): def instance_destroy(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -481,17 +571,21 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id, session=None): + # TODO(devcamcar): Support user context return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(deleted=_deleted(context) ).all() + def instance_get_all_by_user(context, user_id): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -499,7 +593,9 @@ def instance_get_all_by_user(context, user_id): ).filter_by(user_id=user_id ).all() + def instance_get_all_by_project(context, project_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -509,6 +605,7 @@ def instance_get_all_by_project(context, project_id): def instance_get_all_by_reservation(_context, reservation_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -518,6 +615,7 @@ def instance_get_all_by_reservation(_context, reservation_id): def instance_get_by_ec2_id(context, ec2_id): + # TODO(devcamcar): Support user context session = get_session() instance_ref = session.query(models.Instance ).filter_by(ec2_id=ec2_id @@ -536,6 +634,7 @@ def instance_ec2_id_exists(context, ec2_id, session=None): def instance_get_fixed_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -545,6 +644,7 @@ def instance_get_fixed_address(_context, instance_id): def instance_get_floating_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -557,6 +657,7 @@ def instance_get_floating_address(_context, instance_id): def instance_is_vpn(context, instance_id): + # TODO(devcamcar): Admin only # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id @@ -683,8 +784,8 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +def network_get(_context, network_id, session=None): + return models.Network.find(network_id, session=session) # NOTE(vish): pylint complains because of the long method name, but diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b9bb8e4f2..7a085c4df 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -441,19 +441,6 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" @@ -469,19 +456,6 @@ class FloatingIp(BASE, NovaBase): project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" diff --git a/nova/network/manager.py b/nova/network/manager.py index a7126ea4f..d125d28d8 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -232,7 +232,8 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + # TODO(devcamcar) Pass system context here. + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index da65b50a2..110e8430c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -84,6 +84,7 @@ class NetworkTestCase(test.TrialTestCase): def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: -- cgit From d32d95e08d67084ea04ccd1565ce6faffb1766ce Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:29:55 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 185 ++++++++++++++++++++++++++++++----------- nova/tests/compute_unittest.py | 2 + nova/tests/network_unittest.py | 4 +- 3 files changed, 141 insertions(+), 50 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d129df2be..9ab53b89b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -41,9 +41,10 @@ FLAGS = flags.FLAGS # pylint: disable-msg=C0111 def _deleted(context): """Calculates whether to include deleted objects based on context. - - Currently just looks for a flag called deleted in the context dict. + Currently just looks for a flag called deleted in the context dict. """ + if is_user_context(context): + return False if not hasattr(context, 'get'): return False return context.get('deleted', False) @@ -69,7 +70,7 @@ def is_user_context(context): ################### - +#@require_admin_context def service_destroy(context, service_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -80,6 +81,7 @@ def service_destroy(context, service_id): service_ref.delete(session=session) +#@require_admin_context def service_get(context, service_id, session=None): if not is_admin_context(context): raise exception.NotAuthorized() @@ -98,6 +100,7 @@ def service_get(context, service_id, session=None): return result +#@require_admin_context def service_get_all_by_topic(context, topic): if not is_admin_context(context): raise exception.NotAuthorized() @@ -110,6 +113,7 @@ def service_get_all_by_topic(context, topic): ).all() +#@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): if not is_admin_context(context): raise exception.NotAuthorized() @@ -124,6 +128,7 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() +#@require_admin_context def service_get_all_compute_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -151,6 +156,7 @@ def service_get_all_compute_sorted(context): label) +#@require_admin_context def service_get_all_network_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -171,6 +177,7 @@ def service_get_all_network_sorted(context): label) +#@require_admin_context def service_get_all_volume_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -191,6 +198,7 @@ def service_get_all_volume_sorted(context): label) +#@require_admin_context def service_get_by_args(context, host, binary): if not is_admin_context(context): raise exception.NotAuthorized() @@ -208,6 +216,7 @@ def service_get_by_args(context, host, binary): return result +#@require_admin_context def service_create(context, values): if not is_admin_context(context): return exception.NotAuthorized() @@ -219,6 +228,7 @@ def service_create(context, values): return service_ref +#@require_admin_context def service_update(context, service_id, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -234,12 +244,11 @@ def service_update(context, service_id, values): ################### +#@require_context def floating_ip_allocate_address(context, host, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() with session.begin(): @@ -259,6 +268,7 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] +#@require_context def floating_ip_create(context, values): if not is_user_context(context) and not is_admin_context(context): raise exception.NotAuthorized() @@ -270,12 +280,11 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] +#@require_context def floating_ip_count_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() return session.query(models.FloatingIp @@ -316,6 +325,7 @@ def floating_ip_deallocate(context, address): floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) + #@require_context def floating_ip_destroy(context, address): if not is_user_context(context) and not is_admin_context(context): @@ -330,6 +340,7 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) +#@require_context def floating_ip_disassociate(context, address): if not is_user_context(context) and is_admin_context(context): raise exception.NotAuthorized() @@ -350,6 +361,7 @@ def floating_ip_disassociate(context, address): floating_ip_ref.save(session=session) return fixed_ip_address + #@require_admin_context def floating_ip_get_all(context): if not is_admin_context(context): @@ -361,6 +373,7 @@ def floating_ip_get_all(context): ).filter_by(deleted=False ).all() + #@require_admin_context def floating_ip_get_all_by_host(context, host): if not is_admin_context(context): @@ -373,6 +386,7 @@ def floating_ip_get_all_by_host(context, host): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_all_by_project(context, project_id): # TODO(devcamcar): Change to decorate and check project_id separately. @@ -389,6 +403,7 @@ def floating_ip_get_all_by_project(context, project_id): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. @@ -408,14 +423,9 @@ def floating_ip_get_by_address(context, address, session=None): return result - # floating_ip_ref = get_floating_ip_by_address(context, - # address, - # session=session) - # return floating_ip_ref.fixed_ip.instance - - ################### + #@require_context def fixed_ip_associate(context, address, instance_id): if not is_user_context(context) and not is_admin_context(context): @@ -469,6 +479,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): session.add(fixed_ip_ref) return fixed_ip_ref['address'] + #@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() @@ -477,6 +488,7 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] + #@require_context def fixed_ip_disassociate(context, address): session = get_session() @@ -535,7 +547,8 @@ def fixed_ip_update(context, address, values): ################### -def instance_create(_context, values): +#@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -544,14 +557,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): - # TODO(devmcar): Admin only +#@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -562,21 +575,42 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) +#@require_context def instance_get(context, instance_id, session=None): - # TODO(devcamcar): Support user context - return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +#@require_admin_context def instance_get_all(context): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -584,8 +618,11 @@ def instance_get_all(context): ).all() +#@require_admin_context def instance_get_all_by_user(context, user_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -594,8 +631,12 @@ def instance_get_all_by_user(context, user_id): ).all() +#@require_context def instance_get_all_by_project(context, project_id): - # TODO(devcamcar): Support user context + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -604,50 +645,68 @@ def instance_get_all_by_project(context, project_id): ).all() -def instance_get_all_by_reservation(_context, reservation_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() +#@require_context def instance_get_by_ec2_id(context, ec2_id): - # TODO(devcamcar): Support user context session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not instance_ref: + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +#@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -656,14 +715,20 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +#@require_admin_context def instance_is_vpn(context, instance_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +#@require_admin_context def instance_set_state(context, instance_id, state, description=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -674,10 +739,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +#@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -686,6 +752,7 @@ def instance_update(_context, instance_id, values): ################### +#@require_context def key_pair_create(_context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -694,7 +761,8 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +#@require_context +def key_pair_destroy(context, user_id, name): session = get_session() with session.begin(): key_pair_ref = models.KeyPair.find_by_args(user_id, @@ -784,8 +852,27 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id, session=None): - return models.Network.find(network_id, session=session) +#@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but @@ -1066,7 +1153,7 @@ def volume_get(context, volume_id, session=None): ).first() elif is_user_context(context): result = session.query(models.Volume - ).filter_by(project_id=context.project.project_id + ).filter_by(project_id=context.project.id ).filter_by(id=volume_id ).filter_by(deleted=False ).first() diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..e705c2552 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -96,6 +96,8 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) + # TODO(devcamcar): Pass deleted in using system context. + # context.read_deleted ? instance_ref = db.instance_get({'deleted': True}, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 110e8430c..ca6a4bbc2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,7 +56,9 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = db.instance_create(None, {'mac_address': utils.generate_mac()}) self.instance_id = instance_ref['id'] -- cgit From ea5dcda819f2656589df177331f693f945d98f4a Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:35:24 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 32 +++++++++++++++++++++++++++++--- nova/tests/network_unittest.py | 1 + 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9ab53b89b..2d553d98d 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -794,11 +794,21 @@ def key_pair_get_all_by_user(_context, user_id): ################### -def network_count(_context): - return models.Network.count() +#@require_admin_context +def network_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + return session.query(models.Network + ).filter_by(deleted=deleted + ).count() + +#@require_admin_context def network_count_allocated_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -807,7 +817,11 @@ def network_count_allocated_ips(_context, network_id): ).count() +#@require_admin_context def network_count_available_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -817,7 +831,11 @@ def network_count_available_ips(_context, network_id): ).count() +#@require_admin_context def network_count_reserved_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -826,7 +844,11 @@ def network_count_reserved_ips(_context, network_id): ).count() +#@require_admin_context def network_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -834,7 +856,11 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +#@require_admin_context +def network_destroy(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index ca6a4bbc2..e01d7cff9 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,6 +49,7 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) + # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i -- cgit From e716990fd58521f8c0166330ec9bc62c7cd91b7e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:54:15 -0700 Subject: Finished context auth for network --- nova/db/sqlalchemy/api.py | 103 ++++++++++++++++++++++++++++++++-------------- nova/network/manager.py | 3 +- 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2d553d98d..23589b7d8 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -799,13 +799,14 @@ def network_count(context): if not is_admin_context(context): raise exception.NotAuthorized() + session = get_session() return session.query(models.Network - ).filter_by(deleted=deleted + ).filter_by(deleted=_deleted(context) ).count() #@require_admin_context -def network_count_allocated_ips(_context, network_id): +def network_count_allocated_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -818,7 +819,7 @@ def network_count_allocated_ips(_context, network_id): #@require_admin_context -def network_count_available_ips(_context, network_id): +def network_count_available_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -832,7 +833,7 @@ def network_count_available_ips(_context, network_id): #@require_admin_context -def network_count_reserved_ips(_context, network_id): +def network_count_reserved_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -845,7 +846,7 @@ def network_count_reserved_ips(_context, network_id): #@require_admin_context -def network_create(_context, values): +def network_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -904,7 +905,11 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +#@require_admin_context +def network_get_associated_fixed_ips(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -914,18 +919,28 @@ def network_get_associated_fixed_ips(_context, network_id): ).all() -def network_get_by_bridge(_context, bridge): +#@require_admin_context +def network_get_by_bridge(context, bridge): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + + return result -def network_get_index(_context, network_id): +#@require_admin_context +def network_get_index(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -933,19 +948,34 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +#@require_admin_context +def network_index_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def network_index_create_safe(context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def network_index_create_safe(_context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -955,29 +985,35 @@ def network_index_create_safe(_context, values): pass -def network_set_host(_context, network_id, host_id): +#@require_admin_context +def network_set_host(context, network_id, host_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + + return network_ref['host'] -def network_update(_context, network_id, values): +#@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -985,7 +1021,10 @@ def network_update(_context, network_id, values): ################### - +# YOU ARE HERE. +# random idea for system user: +# ctx = context.system_user(on_behalf_of=user, read_deleted=False) +# TODO(devcamcar): Rename to network_get_all_by_project def project_get_network(_context, project_id): session = get_session() rv = session.query(models.Network diff --git a/nova/network/manager.py b/nova/network/manager.py index d125d28d8..ecf2fa2c2 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,8 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + # TODO(devcamcar): Replace with system context + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) -- cgit From 98cac90592658773791eb15b19ed60adf0a57d96 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 00:36:10 -0700 Subject: Completed quota context auth --- nova/db/sqlalchemy/api.py | 103 +++++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 12 ----- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 23589b7d8..b225a6a88 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1021,19 +1021,22 @@ def network_update(context, network_id, values): ################### -# YOU ARE HERE. -# random idea for system user: -# ctx = context.system_user(on_behalf_of=user, read_deleted=False) -# TODO(devcamcar): Rename to network_get_all_by_project -def project_get_network(_context, project_id): + +#@require_context +def project_get_network(context, project_id): + if not is_admin_context(context) and not is_user_context(context): + raise error.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -1043,14 +1046,26 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +#@require_admin_context +def export_device_count(context): + if not is_admin_context(context): + raise exception.notauthorized() + + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def export_device_create(context, values): + if not is_admin_context(context): + raise exception.notauthorized() -def export_device_create(_context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1084,7 +1099,29 @@ def auth_create_token(_context, token): ################### +#@require_admin_context +def quota_get(context, project_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +#@require_admin_context def quota_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1092,29 +1129,34 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - +#@require_admin_context +def quota_update(context, project_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def quota_update(_context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +#@require_admin_context +def quota_destroy(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### +#@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1135,6 +1177,7 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) +#@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1149,6 +1192,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) +#@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1164,6 +1208,7 @@ def volume_create(context, values): return volume_ref +#@require_admin_context def volume_data_get_for_project(context, project_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1178,6 +1223,7 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) +#@require_admin_context def volume_destroy(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1192,6 +1238,7 @@ def volume_destroy(context, volume_id): {'id': volume_id}) +#@require_admin_context def volume_detached(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1206,6 +1253,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) +#@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1222,15 +1270,13 @@ def volume_get(context, volume_id, session=None): ).filter_by(id=volume_id ).filter_by(deleted=False ).first() - else: - raise exception.NotAuthorized() - if not result: raise exception.NotFound('No volume for id %s' % volume_id) return result +#@require_admin_context def volume_get_all(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1239,7 +1285,7 @@ def volume_get_all(context): ).filter_by(deleted=_deleted(context) ).all() - +#@require_context def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -1254,6 +1300,7 @@ def volume_get_all_by_project(context, project_id): ).all() +#@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1278,6 +1325,7 @@ def volume_get_by_ec2_id(context, ec2_id): return result +#@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() @@ -1286,10 +1334,9 @@ def volume_ec2_id_exists(context, ec2_id, session=None): return session.query(exists( ).where(models.Volume.id==ec2_id) ).one()[0] - else: - raise exception.NotAuthorized() +#@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1315,6 +1362,7 @@ def volume_get_instance(context, volume_id): return result.instance +#@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1329,15 +1377,14 @@ def volume_get_shelf_and_blade(context, volume_id): ).filter(models.Volume.project_id==context.project.id ).filter_by(volume_id=volume_id ).first() - else: - raise exception.NotAuthorized() - if not result: - raise exception.NotFound() + raise exception.NotFound('No export device found for volume %s' % + volume_id) return (result.shelf_id, result.blade_id) +#@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 7a085c4df..76444127f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -302,18 +302,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" -- cgit From 30541d48b17ab4626791d969388871b3a1b7758f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:07:05 -0700 Subject: Wired up context auth for keypairs --- nova/db/sqlalchemy/api.py | 46 +++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 20 ------------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b225a6a88..302322979 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -753,7 +753,7 @@ def instance_update(context, instance_id, values): #@require_context -def key_pair_create(_context, values): +def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): key_pair_ref[key] = value @@ -763,15 +763,22 @@ def key_pair_create(_context, values): #@require_context def key_pair_destroy(context, user_id, name): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): - key_pair_ref = models.KeyPair.find_by_args(user_id, - name, - session=session) + key_pair_ref = key_pair_get(context, user_id, name, session=session) key_pair_ref.delete(session=session) -def key_pair_destroy_all_by_user(_context, user_id): +#@require_context +def key_pair_destroy_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -779,11 +786,32 @@ def key_pair_destroy_all_by_user(_context, user_id): {'id': user_id}) -def key_pair_get(_context, user_id, name): - return models.KeyPair.find_by_args(user_id, name) +#@require_context +def key_pair_get(context, user_id, name, session=None): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result + +#@require_context +def key_pair_get_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() -def key_pair_get_all_by_user(_context, user_id): session = get_session() return session.query(models.KeyPair ).filter_by(user_id=user_id @@ -1118,7 +1146,7 @@ def quota_get(context, project_id, session=None): #@require_admin_context -def quota_create(_context, values): +def quota_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 76444127f..1f5bdf9f5 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -332,26 +332,6 @@ class KeyPair(BASE, NovaBase): def str_id(self): return '%s.%s' % (self.user_id, self.name) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - user_id, _sep, name = str_id.partition('.') - return cls.find_by_str(user_id, name, session, deleted) - - @classmethod - def find_by_args(cls, user_id, name, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for user %s, name %s" % - (user_id, name)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Network(BASE, NovaBase): """Represents a network""" -- cgit From 336523b36ceb8f5302acd443b7f1171b67575f73 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:11:16 -0700 Subject: Removed deprecated bits from NovaBase --- nova/db/sqlalchemy/models.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1f5bdf9f5..a29090c60 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -50,44 +50,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" -- cgit From 8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:02:14 -0700 Subject: Simplified authorization with decorators" " --- nova/db/sqlalchemy/api.py | 408 ++++++++++++++++------------------------------ 1 file changed, 142 insertions(+), 266 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 302322979..0e7d2e664 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -51,6 +51,7 @@ def _deleted(context): def is_admin_context(context): + """Indicates if the request context is an administrator.""" if not context: logging.warning('Use of empty request context is deprecated') return True @@ -60,6 +61,7 @@ def is_admin_context(context): def is_user_context(context): + """Indicates if the request context is a normal user.""" if not context: logging.warning('Use of empty request context is deprecated') return False @@ -68,24 +70,62 @@ def is_user_context(context): return True +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. + """ + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() + + +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. + """ + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + ################### -#@require_admin_context +@require_admin_context def service_destroy(context, service_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -#@require_admin_context +@require_admin_context def service_get(context, service_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -100,11 +140,8 @@ def service_get(context, service_id, session=None): return result -#@require_admin_context +@require_admin_context def service_get_all_by_topic(context, topic): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -113,11 +150,8 @@ def service_get_all_by_topic(context, topic): ).all() -#@require_admin_context +@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): - if not is_admin_context(context): - raise exception.NotAuthorized() - sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -128,11 +162,8 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() -#@require_admin_context +@require_admin_context def service_get_all_compute_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -156,11 +187,8 @@ def service_get_all_compute_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_network_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'network' @@ -177,11 +205,8 @@ def service_get_all_network_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_volume_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'volume' @@ -198,11 +223,8 @@ def service_get_all_volume_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_by_args(context, host, binary): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Service ).filter_by(host=host @@ -216,11 +238,8 @@ def service_get_by_args(context, host, binary): return result -#@require_admin_context +@require_admin_context def service_create(context, values): - if not is_admin_context(context): - return exception.NotAuthorized() - service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -228,11 +247,8 @@ def service_create(context, values): return service_ref -#@require_admin_context +@require_admin_context def service_update(context, service_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = session_get(context, service_id, session=session) @@ -244,11 +260,9 @@ def service_update(context, service_id, values): ################### -#@require_context +@require_context def floating_ip_allocate_address(context, host, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() with session.begin(): @@ -268,11 +282,8 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_create(context, values): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -280,11 +291,9 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_count_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -293,11 +302,8 @@ def floating_ip_count_by_project(context, project_id): ).count() -#@require_context +@require_context def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating_id belongs to user? @@ -311,11 +317,8 @@ def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_deallocate(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating id belongs to user? @@ -326,11 +329,8 @@ def floating_ip_deallocate(context, address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_destroy(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -340,11 +340,8 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) -#@require_context +@require_context def floating_ip_disassociate(context, address): - if not is_user_context(context) and is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -362,11 +359,8 @@ def floating_ip_disassociate(context, address): return fixed_ip_address -#@require_admin_context +@require_admin_context def floating_ip_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -374,11 +368,8 @@ def floating_ip_get_all(context): ).all() -#@require_admin_context +@require_admin_context def floating_ip_get_all_by_host(context, host): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -387,14 +378,9 @@ def floating_ip_get_all_by_host(context, host): ).all() -#@require_context +@require_context def floating_ip_get_all_by_project(context, project_id): - # TODO(devcamcar): Change to decorate and check project_id separately. - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -404,12 +390,9 @@ def floating_ip_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -426,11 +409,8 @@ def floating_ip_get_by_address(context, address, session=None): ################### -#@require_context +@require_context def fixed_ip_associate(context, address, instance_id): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -449,11 +429,8 @@ def fixed_ip_associate(context, address, instance_id): session.add(fixed_ip_ref) -#@require_admin_context +@require_admin_context def fixed_ip_associate_pool(context, network_id, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -480,7 +457,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -489,7 +466,7 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): @@ -500,7 +477,7 @@ def fixed_ip_disassociate(context, address): fixed_ip_ref.save(session=session) -#@require_context +@require_context def fixed_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure floating ip belongs to user. # Only possible if it is associated with an instance. @@ -520,19 +497,19 @@ def fixed_ip_get_by_address(context, address, session=None): return result -#@require_context +@require_context def fixed_ip_get_instance(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.instance -#@require_admin_context +@require_admin_context def fixed_ip_get_network(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.network -#@require_context +@require_context def fixed_ip_update(context, address, values): session = get_session() with session.begin(): @@ -547,7 +524,7 @@ def fixed_ip_update(context, address, values): ################### -#@require_context +@require_context def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): @@ -563,7 +540,7 @@ def instance_create(context, values): return instance_ref -#@require_admin_context +@require_admin_context def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), @@ -575,7 +552,7 @@ def instance_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_context +@require_context def instance_destroy(context, instance_id): session = get_session() with session.begin(): @@ -583,7 +560,7 @@ def instance_destroy(context, instance_id): instance_ref.delete(session=session) -#@require_context +@require_context def instance_get(context, instance_id, session=None): if not session: session = get_session() @@ -606,11 +583,8 @@ def instance_get(context, instance_id, session=None): return result -#@require_admin_context +@require_admin_context def instance_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -618,11 +592,8 @@ def instance_get_all(context): ).all() -#@require_admin_context +@require_admin_context def instance_get_all_by_user(context, user_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -631,11 +602,9 @@ def instance_get_all_by_user(context, user_id): ).all() -#@require_context +@require_context def instance_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Instance @@ -645,7 +614,7 @@ def instance_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def instance_get_all_by_reservation(context, reservation_id): session = get_session() @@ -664,7 +633,7 @@ def instance_get_all_by_reservation(context, reservation_id): ).all() -#@require_context +@require_context def instance_get_by_ec2_id(context, ec2_id): session = get_session() @@ -685,14 +654,14 @@ def instance_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -#@require_context +@require_context def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): @@ -702,7 +671,7 @@ def instance_get_fixed_address(context, instance_id): return instance_ref.fixed_ip['address'] -#@require_context +@require_context def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): @@ -715,20 +684,15 @@ def instance_get_floating_address(context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] -#@require_admin_context +@require_admin_context def instance_is_vpn(context, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id -#@require_admin_context +@require_admin_context def instance_set_state(context, instance_id, state, description=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -739,7 +703,7 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -#@require_context +@require_context def instance_update(context, instance_id, values): session = get_session() with session.begin(): @@ -752,7 +716,7 @@ def instance_update(context, instance_id, values): ################### -#@require_context +@require_context def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -761,11 +725,9 @@ def key_pair_create(context, values): return key_pair_ref -#@require_context +@require_context def key_pair_destroy(context, user_id, name): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -773,11 +735,9 @@ def key_pair_destroy(context, user_id, name): key_pair_ref.delete(session=session) -#@require_context +@require_context def key_pair_destroy_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -786,11 +746,9 @@ def key_pair_destroy_all_by_user(context, user_id): {'id': user_id}) -#@require_context +@require_context def key_pair_get(context, user_id, name, session=None): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) if not session: session = get_session() @@ -806,11 +764,9 @@ def key_pair_get(context, user_id, name, session=None): return result -#@require_context +@require_context def key_pair_get_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() return session.query(models.KeyPair @@ -822,22 +778,16 @@ def key_pair_get_all_by_user(context, user_id): ################### -#@require_admin_context +@require_admin_context def network_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Network ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_count_allocated_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -846,11 +796,8 @@ def network_count_allocated_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_available_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -860,11 +807,8 @@ def network_count_available_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_reserved_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -873,11 +817,8 @@ def network_count_reserved_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -885,11 +826,8 @@ def network_create(context, values): return network_ref -#@require_admin_context +@require_admin_context def network_destroy(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -907,7 +845,7 @@ def network_destroy(context, network_id): {'id': network_id}) -#@require_context +@require_context def network_get(context, network_id, session=None): if not session: session = get_session() @@ -933,11 +871,8 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -#@require_admin_context +@require_admin_context def network_get_associated_fixed_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -947,11 +882,8 @@ def network_get_associated_fixed_ips(context, network_id): ).all() -#@require_admin_context +@require_admin_context def network_get_by_bridge(context, bridge): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Network ).filter_by(bridge=bridge @@ -964,11 +896,8 @@ def network_get_by_bridge(context, bridge): return result -#@require_admin_context +@require_admin_context def network_get_index(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -988,22 +917,16 @@ def network_get_index(context, network_id): return network_index['index'] -#@require_admin_context +@require_admin_context def network_index_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.NetworkIndex ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_index_create_safe(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -1013,11 +936,8 @@ def network_index_create_safe(context, values): pass -#@require_admin_context +@require_admin_context def network_set_host(context, network_id, host_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_ref = session.query(models.Network @@ -1037,7 +957,7 @@ def network_set_host(context, network_id, host_id): return network_ref['host'] -#@require_context +@require_context def network_update(context, network_id, values): session = get_session() with session.begin(): @@ -1050,11 +970,8 @@ def network_update(context, network_id, values): ################### -#@require_context +@require_context def project_get_network(context, project_id): - if not is_admin_context(context) and not is_user_context(context): - raise error.NotAuthorized() - session = get_session() result= session.query(models.Network ).filter_by(project_id=project_id @@ -1078,22 +995,16 @@ def queue_get_for(_context, topic, physical_node_id): ################### -#@require_admin_context +@require_admin_context def export_device_count(context): - if not is_admin_context(context): - raise exception.notauthorized() - session = get_session() return session.query(models.ExportDevice ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def export_device_create(context, values): - if not is_admin_context(context): - raise exception.notauthorized() - export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1127,11 +1038,8 @@ def auth_create_token(_context, token): ################### -#@require_admin_context +@require_admin_context def quota_get(context, project_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -1145,11 +1053,8 @@ def quota_get(context, project_id, session=None): return result -#@require_admin_context +@require_admin_context def quota_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1157,11 +1062,8 @@ def quota_create(context, values): return quota_ref -#@require_admin_context +@require_admin_context def quota_update(context, project_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1170,11 +1072,8 @@ def quota_update(context, project_id, values): quota_ref.save(session=session) -#@require_admin_context +@require_admin_context def quota_destroy(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1184,11 +1083,8 @@ def quota_destroy(context, project_id): ################### -#@require_admin_context +@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -1205,11 +1101,8 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) -#@require_admin_context +@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1220,7 +1113,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -#@require_context +@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1236,11 +1129,8 @@ def volume_create(context, values): return volume_ref -#@require_admin_context +@require_admin_context def volume_data_get_for_project(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -1251,11 +1141,8 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_admin_context +@require_admin_context def volume_destroy(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -1266,11 +1153,8 @@ def volume_destroy(context, volume_id): {'id': volume_id}) -#@require_admin_context +@require_admin_context def volume_detached(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1281,7 +1165,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) -#@require_context +@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1304,22 +1188,15 @@ def volume_get(context, volume_id, session=None): return result -#@require_admin_context +@require_admin_context def volume_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - return session.query(models.Volume ).filter_by(deleted=_deleted(context) ).all() -#@require_context +@require_context def volume_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Volume @@ -1328,7 +1205,7 @@ def volume_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1353,18 +1230,17 @@ def volume_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - if is_admin_context(context) or is_user_context(context): - return session.query(exists( - ).where(models.Volume.id==ec2_id) - ).one()[0] + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -#@require_context +@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1390,7 +1266,7 @@ def volume_get_instance(context, volume_id): return result.instance -#@require_context +@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1412,7 +1288,7 @@ def volume_get_shelf_and_blade(context, volume_id): return (result.shelf_id, result.blade_id) -#@require_context +@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): -- cgit From cf456bdb2a767644d95599aa1c8f580279959a4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:47:05 -0700 Subject: Refactored APIRequestContext --- nova/api/context.py | 46 +++++++++++++++++++++++++++ nova/api/ec2/__init__.py | 8 ++--- nova/api/ec2/context.py | 33 -------------------- nova/db/sqlalchemy/api.py | 71 +++++++++++++++++++----------------------- nova/network/manager.py | 2 -- nova/tests/compute_unittest.py | 8 ++--- 6 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 nova/api/context.py delete mode 100644 nova/api/ec2/context.py diff --git a/nova/api/context.py b/nova/api/context.py new file mode 100644 index 000000000..b66cfe468 --- /dev/null +++ b/nova/api/context.py @@ -0,0 +1,46 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + + +class APIRequestContext(object): + def __init__(self, user, project): + self.user = user + self.project = project + self.request_id = ''.join( + [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') + for x in xrange(20)] + ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -27,8 +27,8 @@ import webob.exc from nova import exception from nova import flags from nova import wsgi +from nova.api import context from nova.api.ec2 import apirequest -from nova.api.ec2 import context from nova.api.ec2 import admin from nova.api.ec2 import cloud from nova.auth import manager @@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user.id, role) for role in roles) - + class Executor(wsgi.Application): """Execute an EC2 API request. - Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and 'ec2.action_args' (all variables in WSGI environ.) Returns an XML response, or a 400 upon failure. """ diff --git a/nova/api/ec2/context.py b/nova/api/ec2/context.py deleted file mode 100644 index c53ba98d9..000000000 --- a/nova/api/ec2/context.py +++ /dev/null @@ -1,33 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -APIRequestContext -""" - -import random - - -class APIRequestContext(object): - def __init__(self, user, project): - self.user = user - self.project = project - self.request_id = ''.join( - [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') - for x in xrange(20)] - ) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 0e7d2e664..fc5ee2235 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -21,6 +21,7 @@ Implementation of SQLAlchemy backend import logging import sys +import warnings from nova import db from nova import exception @@ -36,28 +37,13 @@ from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. - Currently just looks for a flag called deleted in the context dict. - """ - if is_user_context(context): - return False - if not hasattr(context, 'get'): - return False - return context.get('deleted', False) - - def is_admin_context(context): """Indicates if the request context is an administrator.""" if not context: - logging.warning('Use of empty request context is deprecated') - return True - if not context.user: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) return True - return context.user.is_admin() + return context.is_admin def is_user_context(context): @@ -92,6 +78,13 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() +def use_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: + return False + return context.read_deleted + + def require_admin_context(f): """Decorator used to indicate that the method requires an administrator context. @@ -131,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -229,7 +222,7 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -398,7 +391,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -487,7 +480,7 @@ def fixed_ip_get_by_address(context, address, session=None): result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() @@ -569,7 +562,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -588,7 +581,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -597,7 +590,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).filter_by(user_id=user_id ).all() @@ -610,7 +603,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -622,7 +615,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -640,7 +633,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -756,7 +749,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -782,7 +775,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -854,7 +847,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -921,7 +914,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -999,7 +992,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -1045,7 +1038,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1174,7 +1167,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1191,7 +1184,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @require_context @@ -1201,7 +1194,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -1213,7 +1206,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1248,7 +1241,7 @@ def volume_get_instance(context, volume_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('instance') ).first() elif is_user_context(context): diff --git a/nova/network/manager.py b/nova/network/manager.py index ecf2fa2c2..265c0d742 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,6 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - # TODO(devcamcar): Replace with system context host = self.db.network_set_host(None, network_id, self.host) @@ -233,7 +232,6 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - # TODO(devcamcar) Pass system context here. address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index e705c2552..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,9 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - # TODO(devcamcar): Pass deleted in using system context. - # context.read_deleted ? - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) -- cgit From ab948224a5c6ea976def30927ac7668dd765dbca Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 03:13:47 -0700 Subject: Cleaned up db/api.py --- nova/db/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 4cfdd788c..290b460a6 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,10 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### -- cgit From c9e14d6257f0b488bd892c09d284091c0f612dd7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 1 Oct 2010 01:44:17 -0700 Subject: Locked down fixed ips and improved network tests --- nova/db/sqlalchemy/api.py | 98 ++++++++++++++++-------------------------- nova/tests/network_unittest.py | 44 ++++++++++--------- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index fc5ee2235..860723516 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -78,7 +78,7 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() -def use_deleted(context): +def can_read_deleted(context): """Indicates if the context has access to deleted objects.""" if not context: return False @@ -124,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: @@ -222,9 +222,8 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() - if not result: raise exception.NotFound('No service for %s, %s' % (host, binary)) @@ -256,7 +255,6 @@ def service_update(context, service_id, values): @require_context def floating_ip_allocate_address(context, host, project_id): authorize_project_context(context, project_id) - session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -287,7 +285,6 @@ def floating_ip_create(context, values): @require_context def floating_ip_count_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -374,7 +371,6 @@ def floating_ip_get_all_by_host(context, host): @require_context def floating_ip_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -391,7 +387,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -406,6 +402,7 @@ def floating_ip_get_by_address(context, address, session=None): def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -416,9 +413,7 @@ def fixed_ip_associate(context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = instance_get(context, - instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) @@ -472,21 +467,21 @@ def fixed_ip_disassociate(context, address): @require_context def fixed_ip_get_by_address(context, address, session=None): - # TODO(devcamcar): Ensure floating ip belongs to user. - # Only possible if it is associated with an instance. - # May have to use system context for this always. if not session: session = get_session() result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() if not result: raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) + return result @@ -562,7 +557,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -581,7 +576,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -590,7 +585,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -603,7 +598,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -615,7 +610,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -633,7 +628,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -749,7 +744,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -775,7 +770,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -847,7 +842,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -914,7 +909,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -992,7 +987,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -1038,7 +1033,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1167,7 +1162,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1184,7 +1179,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @require_context @@ -1194,7 +1189,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1206,7 +1201,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1233,47 +1228,26 @@ def volume_ec2_id_exists(context, ec2_id, session=None): ).one()[0] -@require_context +@require_admin_context def volume_get_instance(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) - ).options(joinedload('instance') - ).first() - elif is_user_context(context): - result = session.query(models.Volume - ).filter_by(project_id=context.project.id - ).filter_by(deleted=False - ).options(joinedload('instance') - ).first() - else: - raise exception.NotAuthorized() - + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() if not result: raise exception.NotFound('Volume %s not found' % ec2_id) return result.instance -@require_context +@require_admin_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - elif is_user_context(context): - result = session.query(models.ExportDevice - ).join(models.Volume - ).filter(models.Volume.project_id==context.project.id - ).filter_by(volume_id=volume_id - ).first() + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() if not result: raise exception.NotFound('No export device found for volume %s' % volume_id) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index e01d7cff9..e601c480c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,7 +49,6 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) - # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i @@ -60,11 +59,9 @@ class NetworkTestCase(test.TrialTestCase): user_context = context.APIRequestContext(project=self.projects[i], user=self.user) self.network.set_network_host(user_context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -77,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -84,6 +90,11 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips @@ -134,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -152,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -185,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -246,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) -- cgit From 033c464882c3d74ecd863abde767f37e7ad6a956 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Sat, 2 Oct 2010 12:39:47 +0200 Subject: Make _dhcp_file ensure the existence of the directory containing the files it returns. --- nova/network/linux_net.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 95f7fe2d0..37f9c8253 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -274,6 +274,9 @@ def _stop_dnsmasq(network): def _dhcp_file(vlan, kind): """Return path to a pid, leases or conf file for a vlan""" + if not os.path.exists(FLAGS.networks_path): + os.makedirs(FLAGS.networks_path) + return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) -- cgit From 5945291281f239bd928cea1833ee5a5b6c3df523 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sat, 2 Oct 2010 12:42:09 +0100 Subject: Bug #653534: NameError on session_get in sqlalchemy.api.service_update Fix function call: session_get was meant to be service_get. --- nova/db/sqlalchemy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 7f72f66b9..9a7c71a70 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -240,7 +240,7 @@ def service_create(context, values): def service_update(context, service_id, values): session = get_session() with session.begin(): - service_ref = session_get(context, service_id, session=session) + service_ref = service_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) -- cgit From c66d550d208544799fdaf4646a846e9f9c0b6bc5 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sat, 2 Oct 2010 13:11:33 +0100 Subject: Bug #653560: AttributeError in VlanManager.periodic_tasks Pass the correct context to db.fixed_ip_disassociate_all_by_timeout. --- nova/network/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/network/manager.py b/nova/network/manager.py index ef1d01138..9580479e5 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -230,7 +230,7 @@ class VlanManager(NetworkManager): now = datetime.datetime.utcnow() timeout = FLAGS.fixed_ip_disassociate_timeout time = now - datetime.timedelta(seconds=timeout) - num = self.db.fixed_ip_disassociate_all_by_timeout(self, + num = self.db.fixed_ip_disassociate_all_by_timeout(context, self.host, time) if num: -- cgit From 4e45f9472a95207153d32c88df8396c633c67a5d Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Sun, 3 Oct 2010 20:22:35 +0200 Subject: s/APIRequestContext/get_admin_context/ <-- sudo for request contexts. --- nova/tests/network_unittest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 5370966d2..59b0a36e4 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,8 +56,8 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - user_context = context.APIRequestContext(project=self.projects[i], - user=self.user) + user_context = context.get_admin_context(user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] -- cgit