diff options
| author | Devin Carlen <devin.carlen@gmail.com> | 2010-09-29 20:29:55 -0700 |
|---|---|---|
| committer | Devin Carlen <devin.carlen@gmail.com> | 2010-09-29 20:29:55 -0700 |
| commit | d32d95e08d67084ea04ccd1565ce6faffb1766ce (patch) | |
| tree | c4cafa539833beed656b971bdfef9164925073d3 | |
| parent | 734df1fbad8195e7cd7072d0d0aeb5b94841f121 (diff) | |
| download | nova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.tar.gz nova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.tar.xz nova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.zip | |
Finished instance context auth
| -rw-r--r-- | nova/db/sqlalchemy/api.py | 185 | ||||
| -rw-r--r-- | nova/tests/compute_unittest.py | 2 | ||||
| -rw-r--r-- | nova/tests/network_unittest.py | 4 |
3 files changed, 141 insertions, 50 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d129df2be..9ab53b89b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -41,9 +41,10 @@ FLAGS = flags.FLAGS # pylint: disable-msg=C0111 def _deleted(context): """Calculates whether to include deleted objects based on context. - - Currently just looks for a flag called deleted in the context dict. + Currently just looks for a flag called deleted in the context dict. """ + if is_user_context(context): + return False if not hasattr(context, 'get'): return False return context.get('deleted', False) @@ -69,7 +70,7 @@ def is_user_context(context): ################### - +#@require_admin_context def service_destroy(context, service_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -80,6 +81,7 @@ def service_destroy(context, service_id): service_ref.delete(session=session) +#@require_admin_context def service_get(context, service_id, session=None): if not is_admin_context(context): raise exception.NotAuthorized() @@ -98,6 +100,7 @@ def service_get(context, service_id, session=None): return result +#@require_admin_context def service_get_all_by_topic(context, topic): if not is_admin_context(context): raise exception.NotAuthorized() @@ -110,6 +113,7 @@ def service_get_all_by_topic(context, topic): ).all() +#@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): if not is_admin_context(context): raise exception.NotAuthorized() @@ -124,6 +128,7 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() +#@require_admin_context def service_get_all_compute_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -151,6 +156,7 @@ def service_get_all_compute_sorted(context): label) +#@require_admin_context def service_get_all_network_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -171,6 +177,7 @@ def service_get_all_network_sorted(context): label) +#@require_admin_context def service_get_all_volume_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -191,6 +198,7 @@ def service_get_all_volume_sorted(context): label) +#@require_admin_context def service_get_by_args(context, host, binary): if not is_admin_context(context): raise exception.NotAuthorized() @@ -208,6 +216,7 @@ def service_get_by_args(context, host, binary): return result +#@require_admin_context def service_create(context, values): if not is_admin_context(context): return exception.NotAuthorized() @@ -219,6 +228,7 @@ def service_create(context, values): return service_ref +#@require_admin_context def service_update(context, service_id, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -234,12 +244,11 @@ def service_update(context, service_id, values): ################### +#@require_context def floating_ip_allocate_address(context, host, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() with session.begin(): @@ -259,6 +268,7 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] +#@require_context def floating_ip_create(context, values): if not is_user_context(context) and not is_admin_context(context): raise exception.NotAuthorized() @@ -270,12 +280,11 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] +#@require_context def floating_ip_count_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() return session.query(models.FloatingIp @@ -316,6 +325,7 @@ def floating_ip_deallocate(context, address): floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) + #@require_context def floating_ip_destroy(context, address): if not is_user_context(context) and not is_admin_context(context): @@ -330,6 +340,7 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) +#@require_context def floating_ip_disassociate(context, address): if not is_user_context(context) and is_admin_context(context): raise exception.NotAuthorized() @@ -350,6 +361,7 @@ def floating_ip_disassociate(context, address): floating_ip_ref.save(session=session) return fixed_ip_address + #@require_admin_context def floating_ip_get_all(context): if not is_admin_context(context): @@ -361,6 +373,7 @@ def floating_ip_get_all(context): ).filter_by(deleted=False ).all() + #@require_admin_context def floating_ip_get_all_by_host(context, host): if not is_admin_context(context): @@ -373,6 +386,7 @@ def floating_ip_get_all_by_host(context, host): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_all_by_project(context, project_id): # TODO(devcamcar): Change to decorate and check project_id separately. @@ -389,6 +403,7 @@ def floating_ip_get_all_by_project(context, project_id): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. @@ -408,14 +423,9 @@ def floating_ip_get_by_address(context, address, session=None): return result - # floating_ip_ref = get_floating_ip_by_address(context, - # address, - # session=session) - # return floating_ip_ref.fixed_ip.instance - - ################### + #@require_context def fixed_ip_associate(context, address, instance_id): if not is_user_context(context) and not is_admin_context(context): @@ -469,6 +479,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): session.add(fixed_ip_ref) return fixed_ip_ref['address'] + #@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() @@ -477,6 +488,7 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] + #@require_context def fixed_ip_disassociate(context, address): session = get_session() @@ -535,7 +547,8 @@ def fixed_ip_update(context, address, values): ################### -def instance_create(_context, values): +#@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -544,14 +557,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): - # TODO(devmcar): Admin only +#@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -562,21 +575,42 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) +#@require_context def instance_get(context, instance_id, session=None): - # TODO(devcamcar): Support user context - return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +#@require_admin_context def instance_get_all(context): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -584,8 +618,11 @@ def instance_get_all(context): ).all() +#@require_admin_context def instance_get_all_by_user(context, user_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -594,8 +631,12 @@ def instance_get_all_by_user(context, user_id): ).all() +#@require_context def instance_get_all_by_project(context, project_id): - # TODO(devcamcar): Support user context + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -604,50 +645,68 @@ def instance_get_all_by_project(context, project_id): ).all() -def instance_get_all_by_reservation(_context, reservation_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() +#@require_context def instance_get_by_ec2_id(context, ec2_id): - # TODO(devcamcar): Support user context session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not instance_ref: + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +#@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -656,14 +715,20 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +#@require_admin_context def instance_is_vpn(context, instance_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +#@require_admin_context def instance_set_state(context, instance_id, state, description=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -674,10 +739,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +#@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -686,6 +752,7 @@ def instance_update(_context, instance_id, values): ################### +#@require_context def key_pair_create(_context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -694,7 +761,8 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +#@require_context +def key_pair_destroy(context, user_id, name): session = get_session() with session.begin(): key_pair_ref = models.KeyPair.find_by_args(user_id, @@ -784,8 +852,27 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id, session=None): - return models.Network.find(network_id, session=session) +#@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but @@ -1066,7 +1153,7 @@ def volume_get(context, volume_id, session=None): ).first() elif is_user_context(context): result = session.query(models.Volume - ).filter_by(project_id=context.project.project_id + ).filter_by(project_id=context.project.id ).filter_by(id=volume_id ).filter_by(deleted=False ).first() diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..e705c2552 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -96,6 +96,8 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) + # TODO(devcamcar): Pass deleted in using system context. + # context.read_deleted ? instance_ref = db.instance_get({'deleted': True}, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 110e8430c..ca6a4bbc2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,7 +56,9 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = db.instance_create(None, {'mac_address': utils.generate_mac()}) self.instance_id = instance_ref['id'] |
