summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDevin Carlen <devin.carlen@gmail.com>2010-09-29 20:29:55 -0700
committerDevin Carlen <devin.carlen@gmail.com>2010-09-29 20:29:55 -0700
commitd32d95e08d67084ea04ccd1565ce6faffb1766ce (patch)
treec4cafa539833beed656b971bdfef9164925073d3
parent734df1fbad8195e7cd7072d0d0aeb5b94841f121 (diff)
downloadnova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.tar.gz
nova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.tar.xz
nova-d32d95e08d67084ea04ccd1565ce6faffb1766ce.zip
Finished instance context auth
-rw-r--r--nova/db/sqlalchemy/api.py185
-rw-r--r--nova/tests/compute_unittest.py2
-rw-r--r--nova/tests/network_unittest.py4
3 files changed, 141 insertions, 50 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index d129df2be..9ab53b89b 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -41,9 +41,10 @@ FLAGS = flags.FLAGS
# pylint: disable-msg=C0111
def _deleted(context):
"""Calculates whether to include deleted objects based on context.
-
- Currently just looks for a flag called deleted in the context dict.
+ Currently just looks for a flag called deleted in the context dict.
"""
+ if is_user_context(context):
+ return False
if not hasattr(context, 'get'):
return False
return context.get('deleted', False)
@@ -69,7 +70,7 @@ def is_user_context(context):
###################
-
+#@require_admin_context
def service_destroy(context, service_id):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -80,6 +81,7 @@ def service_destroy(context, service_id):
service_ref.delete(session=session)
+#@require_admin_context
def service_get(context, service_id, session=None):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -98,6 +100,7 @@ def service_get(context, service_id, session=None):
return result
+#@require_admin_context
def service_get_all_by_topic(context, topic):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -110,6 +113,7 @@ def service_get_all_by_topic(context, topic):
).all()
+#@require_admin_context
def _service_get_all_topic_subquery(context, session, topic, subq, label):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -124,6 +128,7 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label):
).all()
+#@require_admin_context
def service_get_all_compute_sorted(context):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -151,6 +156,7 @@ def service_get_all_compute_sorted(context):
label)
+#@require_admin_context
def service_get_all_network_sorted(context):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -171,6 +177,7 @@ def service_get_all_network_sorted(context):
label)
+#@require_admin_context
def service_get_all_volume_sorted(context):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -191,6 +198,7 @@ def service_get_all_volume_sorted(context):
label)
+#@require_admin_context
def service_get_by_args(context, host, binary):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -208,6 +216,7 @@ def service_get_by_args(context, host, binary):
return result
+#@require_admin_context
def service_create(context, values):
if not is_admin_context(context):
return exception.NotAuthorized()
@@ -219,6 +228,7 @@ def service_create(context, values):
return service_ref
+#@require_admin_context
def service_update(context, service_id, values):
if not is_admin_context(context):
raise exception.NotAuthorized()
@@ -234,12 +244,11 @@ def service_update(context, service_id, values):
###################
+#@require_context
def floating_ip_allocate_address(context, host, project_id):
if is_user_context(context):
if context.project.id != project_id:
raise exception.NotAuthorized()
- elif not is_admin_context(context):
- raise exception.NotAuthorized()
session = get_session()
with session.begin():
@@ -259,6 +268,7 @@ def floating_ip_allocate_address(context, host, project_id):
return floating_ip_ref['address']
+#@require_context
def floating_ip_create(context, values):
if not is_user_context(context) and not is_admin_context(context):
raise exception.NotAuthorized()
@@ -270,12 +280,11 @@ def floating_ip_create(context, values):
return floating_ip_ref['address']
+#@require_context
def floating_ip_count_by_project(context, project_id):
if is_user_context(context):
if context.project.id != project_id:
raise exception.NotAuthorized()
- elif not is_admin_context(context):
- raise exception.NotAuthorized()
session = get_session()
return session.query(models.FloatingIp
@@ -316,6 +325,7 @@ def floating_ip_deallocate(context, address):
floating_ip_ref['project_id'] = None
floating_ip_ref.save(session=session)
+
#@require_context
def floating_ip_destroy(context, address):
if not is_user_context(context) and not is_admin_context(context):
@@ -330,6 +340,7 @@ def floating_ip_destroy(context, address):
floating_ip_ref.delete(session=session)
+#@require_context
def floating_ip_disassociate(context, address):
if not is_user_context(context) and is_admin_context(context):
raise exception.NotAuthorized()
@@ -350,6 +361,7 @@ def floating_ip_disassociate(context, address):
floating_ip_ref.save(session=session)
return fixed_ip_address
+
#@require_admin_context
def floating_ip_get_all(context):
if not is_admin_context(context):
@@ -361,6 +373,7 @@ def floating_ip_get_all(context):
).filter_by(deleted=False
).all()
+
#@require_admin_context
def floating_ip_get_all_by_host(context, host):
if not is_admin_context(context):
@@ -373,6 +386,7 @@ def floating_ip_get_all_by_host(context, host):
).filter_by(deleted=False
).all()
+
#@require_context
def floating_ip_get_all_by_project(context, project_id):
# TODO(devcamcar): Change to decorate and check project_id separately.
@@ -389,6 +403,7 @@ def floating_ip_get_all_by_project(context, project_id):
).filter_by(deleted=False
).all()
+
#@require_context
def floating_ip_get_by_address(context, address, session=None):
# TODO(devcamcar): Ensure the address belongs to user.
@@ -408,14 +423,9 @@ def floating_ip_get_by_address(context, address, session=None):
return result
- # floating_ip_ref = get_floating_ip_by_address(context,
- # address,
- # session=session)
- # return floating_ip_ref.fixed_ip.instance
-
-
###################
+
#@require_context
def fixed_ip_associate(context, address, instance_id):
if not is_user_context(context) and not is_admin_context(context):
@@ -469,6 +479,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id):
session.add(fixed_ip_ref)
return fixed_ip_ref['address']
+
#@require_context
def fixed_ip_create(_context, values):
fixed_ip_ref = models.FixedIp()
@@ -477,6 +488,7 @@ def fixed_ip_create(_context, values):
fixed_ip_ref.save()
return fixed_ip_ref['address']
+
#@require_context
def fixed_ip_disassociate(context, address):
session = get_session()
@@ -535,7 +547,8 @@ def fixed_ip_update(context, address, values):
###################
-def instance_create(_context, values):
+#@require_context
+def instance_create(context, values):
instance_ref = models.Instance()
for (key, value) in values.iteritems():
instance_ref[key] = value
@@ -544,14 +557,14 @@ def instance_create(_context, values):
with session.begin():
while instance_ref.ec2_id == None:
ec2_id = utils.generate_uid(instance_ref.__prefix__)
- if not instance_ec2_id_exists(_context, ec2_id, session=session):
+ if not instance_ec2_id_exists(context, ec2_id, session=session):
instance_ref.ec2_id = ec2_id
instance_ref.save(session=session)
return instance_ref
-def instance_data_get_for_project(_context, project_id):
- # TODO(devmcar): Admin only
+#@require_admin_context
+def instance_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Instance.id),
func.sum(models.Instance.vcpus)
@@ -562,21 +575,42 @@ def instance_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def instance_destroy(_context, instance_id):
- # TODO(devcamcar): Support user context
+#@require_context
+def instance_destroy(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
instance_ref.delete(session=session)
+#@require_context
def instance_get(context, instance_id, session=None):
- # TODO(devcamcar): Support user context
- return models.Instance.find(instance_id, session=session, deleted=_deleted(context))
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No instance for id %s' % instance_id)
+
+ return result
+#@require_admin_context
def instance_get_all(context):
- # TODO(devcamcar): Admin only
+ if not is_admin_context(context):
+ raise exception.NotAuthorized()
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
@@ -584,8 +618,11 @@ def instance_get_all(context):
).all()
+#@require_admin_context
def instance_get_all_by_user(context, user_id):
- # TODO(devcamcar): Admin only
+ if not is_admin_context(context):
+ raise exception.NotAuthorized()
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
@@ -594,8 +631,12 @@ def instance_get_all_by_user(context, user_id):
).all()
+#@require_context
def instance_get_all_by_project(context, project_id):
- # TODO(devcamcar): Support user context
+ if is_user_context(context):
+ if context.project.id != project_id:
+ raise exception.NotAuthorized()
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
@@ -604,50 +645,68 @@ def instance_get_all_by_project(context, project_id):
).all()
-def instance_get_all_by_reservation(_context, reservation_id):
- # TODO(devcamcar): Support user context
+#@require_context
+def instance_get_all_by_reservation(context, reservation_id):
session = get_session()
- return session.query(models.Instance
- ).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(reservation_id=reservation_id
- ).filter_by(deleted=False
- ).all()
+
+ if is_admin_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=_deleted(context)
+ ).all()
+ elif is_user_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=False
+ ).all()
+#@require_context
def instance_get_by_ec2_id(context, ec2_id):
- # TODO(devcamcar): Support user context
session = get_session()
- instance_ref = session.query(models.Instance
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
).filter_by(ec2_id=ec2_id
).filter_by(deleted=_deleted(context)
).first()
- if not instance_ref:
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
raise exception.NotFound('Instance %s not found' % (ec2_id))
- return instance_ref
+ return result
+#@require_context
def instance_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
return session.query(exists().where(models.Instance.id==ec2_id)).one()[0]
-def instance_get_fixed_address(_context, instance_id):
- # TODO(devcamcar): Support user context
+#@require_context
+def instance_get_fixed_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
return instance_ref.fixed_ip['address']
-def instance_get_floating_address(_context, instance_id):
- # TODO(devcamcar): Support user context
+#@require_context
+def instance_get_floating_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
if not instance_ref.fixed_ip.floating_ips:
@@ -656,14 +715,20 @@ def instance_get_floating_address(_context, instance_id):
return instance_ref.fixed_ip.floating_ips[0]['address']
+#@require_admin_context
def instance_is_vpn(context, instance_id):
- # TODO(devcamcar): Admin only
+ if not is_admin_context(context):
+ raise exception.NotAuthorized()
# TODO(vish): Move this into image code somewhere
instance_ref = instance_get(context, instance_id)
return instance_ref['image_id'] == FLAGS.vpn_image_id
+#@require_admin_context
def instance_set_state(context, instance_id, state, description=None):
+ if not is_admin_context(context):
+ raise exception.NotAuthorized()
+
# TODO(devcamcar): Move this out of models and into driver
from nova.compute import power_state
if not description:
@@ -674,10 +739,11 @@ def instance_set_state(context, instance_id, state, description=None):
'state_description': description})
-def instance_update(_context, instance_id, values):
+#@require_context
+def instance_update(context, instance_id, values):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
for (key, value) in values.iteritems():
instance_ref[key] = value
instance_ref.save(session=session)
@@ -686,6 +752,7 @@ def instance_update(_context, instance_id, values):
###################
+#@require_context
def key_pair_create(_context, values):
key_pair_ref = models.KeyPair()
for (key, value) in values.iteritems():
@@ -694,7 +761,8 @@ def key_pair_create(_context, values):
return key_pair_ref
-def key_pair_destroy(_context, user_id, name):
+#@require_context
+def key_pair_destroy(context, user_id, name):
session = get_session()
with session.begin():
key_pair_ref = models.KeyPair.find_by_args(user_id,
@@ -784,8 +852,27 @@ def network_destroy(_context, network_id):
{'id': network_id})
-def network_get(_context, network_id, session=None):
- return models.Network.find(network_id, session=session)
+#@require_context
+def network_get(context, network_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Network
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ return result
# NOTE(vish): pylint complains because of the long method name, but
@@ -1066,7 +1153,7 @@ def volume_get(context, volume_id, session=None):
).first()
elif is_user_context(context):
result = session.query(models.Volume
- ).filter_by(project_id=context.project.project_id
+ ).filter_by(project_id=context.project.id
).filter_by(id=volume_id
).filter_by(deleted=False
).first()
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index f5c0f1c09..e705c2552 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -96,6 +96,8 @@ class ComputeTestCase(test.TrialTestCase):
self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow()
yield self.compute.terminate_instance(self.context, instance_id)
+ # TODO(devcamcar): Pass deleted in using system context.
+ # context.read_deleted ?
instance_ref = db.instance_get({'deleted': True}, instance_id)
self.assert_(instance_ref['launched_at'] < terminate)
self.assert_(instance_ref['deleted_at'] > terminate)
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index 110e8430c..ca6a4bbc2 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -56,7 +56,9 @@ class NetworkTestCase(test.TrialTestCase):
'netuser',
name))
# create the necessary network data for the project
- self.network.set_network_host(self.context, self.projects[i].id)
+ user_context = context.APIRequestContext(project=self.projects[i],
+ user=self.user)
+ self.network.set_network_host(user_context, self.projects[i].id)
instance_ref = db.instance_create(None,
{'mac_address': utils.generate_mac()})
self.instance_id = instance_ref['id']