summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDevin Carlen <devin.carlen@gmail.com>2010-09-30 02:02:14 -0700
committerDevin Carlen <devin.carlen@gmail.com>2010-09-30 02:02:14 -0700
commit8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7 (patch)
tree5e548bb2dddebe5e5a8cc9b1c481fd37ef095519
parent336523b36ceb8f5302acd443b7f1171b67575f73 (diff)
downloadnova-8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7.tar.gz
nova-8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7.tar.xz
nova-8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7.zip
Simplified authorization with decorators"
"
-rw-r--r--nova/db/sqlalchemy/api.py408
1 files changed, 142 insertions, 266 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 302322979..0e7d2e664 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -51,6 +51,7 @@ def _deleted(context):
def is_admin_context(context):
+ """Indicates if the request context is an administrator."""
if not context:
logging.warning('Use of empty request context is deprecated')
return True
@@ -60,6 +61,7 @@ def is_admin_context(context):
def is_user_context(context):
+ """Indicates if the request context is a normal user."""
if not context:
logging.warning('Use of empty request context is deprecated')
return False
@@ -68,24 +70,62 @@ def is_user_context(context):
return True
+def authorize_project_context(context, project_id):
+ """Ensures that the request context has permission to access the
+ given project.
+ """
+ if is_user_context(context):
+ if not context.project:
+ raise exception.NotAuthorized()
+ elif context.project.id != project_id:
+ raise exception.NotAuthorized()
+
+
+def authorize_user_context(context, user_id):
+ """Ensures that the request context has permission to access the
+ given user.
+ """
+ if is_user_context(context):
+ if not context.user:
+ raise exception.NotAuthorized()
+ elif context.user.id != user_id:
+ raise exception.NotAuthorized()
+
+
+def require_admin_context(f):
+ """Decorator used to indicate that the method requires an
+ administrator context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+def require_context(f):
+ """Decorator used to indicate that the method requires either
+ an administrator or normal user context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]) and not is_user_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
###################
-#@require_admin_context
+@require_admin_context
def service_destroy(context, service_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
service_ref = service_get(context, service_id, session=session)
service_ref.delete(session=session)
-#@require_admin_context
+@require_admin_context
def service_get(context, service_id, session=None):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
if not session:
session = get_session()
@@ -100,11 +140,8 @@ def service_get(context, service_id, session=None):
return result
-#@require_admin_context
+@require_admin_context
def service_get_all_by_topic(context, topic):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.Service
).filter_by(deleted=False
@@ -113,11 +150,8 @@ def service_get_all_by_topic(context, topic):
).all()
-#@require_admin_context
+@require_admin_context
def _service_get_all_topic_subquery(context, session, topic, subq, label):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
sort_value = getattr(subq.c, label)
return session.query(models.Service, func.coalesce(sort_value, 0)
).filter_by(topic=topic
@@ -128,11 +162,8 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label):
).all()
-#@require_admin_context
+@require_admin_context
def service_get_all_compute_sorted(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# NOTE(vish): The intended query is below
@@ -156,11 +187,8 @@ def service_get_all_compute_sorted(context):
label)
-#@require_admin_context
+@require_admin_context
def service_get_all_network_sorted(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
topic = 'network'
@@ -177,11 +205,8 @@ def service_get_all_network_sorted(context):
label)
-#@require_admin_context
+@require_admin_context
def service_get_all_volume_sorted(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
topic = 'volume'
@@ -198,11 +223,8 @@ def service_get_all_volume_sorted(context):
label)
-#@require_admin_context
+@require_admin_context
def service_get_by_args(context, host, binary):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
result = session.query(models.Service
).filter_by(host=host
@@ -216,11 +238,8 @@ def service_get_by_args(context, host, binary):
return result
-#@require_admin_context
+@require_admin_context
def service_create(context, values):
- if not is_admin_context(context):
- return exception.NotAuthorized()
-
service_ref = models.Service()
for (key, value) in values.iteritems():
service_ref[key] = value
@@ -228,11 +247,8 @@ def service_create(context, values):
return service_ref
-#@require_admin_context
+@require_admin_context
def service_update(context, service_id, values):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
service_ref = session_get(context, service_id, session=session)
@@ -244,11 +260,9 @@ def service_update(context, service_id, values):
###################
-#@require_context
+@require_context
def floating_ip_allocate_address(context, host, project_id):
- if is_user_context(context):
- if context.project.id != project_id:
- raise exception.NotAuthorized()
+ authorize_project_context(context, project_id)
session = get_session()
with session.begin():
@@ -268,11 +282,8 @@ def floating_ip_allocate_address(context, host, project_id):
return floating_ip_ref['address']
-#@require_context
+@require_context
def floating_ip_create(context, values):
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
floating_ip_ref = models.FloatingIp()
for (key, value) in values.iteritems():
floating_ip_ref[key] = value
@@ -280,11 +291,9 @@ def floating_ip_create(context, values):
return floating_ip_ref['address']
-#@require_context
+@require_context
def floating_ip_count_by_project(context, project_id):
- if is_user_context(context):
- if context.project.id != project_id:
- raise exception.NotAuthorized()
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
@@ -293,11 +302,8 @@ def floating_ip_count_by_project(context, project_id):
).count()
-#@require_context
+@require_context
def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(devcamcar): How to ensure floating_id belongs to user?
@@ -311,11 +317,8 @@ def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
floating_ip_ref.save(session=session)
-#@require_context
+@require_context
def floating_ip_deallocate(context, address):
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(devcamcar): How to ensure floating id belongs to user?
@@ -326,11 +329,8 @@ def floating_ip_deallocate(context, address):
floating_ip_ref.save(session=session)
-#@require_context
+@require_context
def floating_ip_destroy(context, address):
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(devcamcar): Ensure address belongs to user.
@@ -340,11 +340,8 @@ def floating_ip_destroy(context, address):
floating_ip_ref.delete(session=session)
-#@require_context
+@require_context
def floating_ip_disassociate(context, address):
- if not is_user_context(context) and is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(devcamcar): Ensure address belongs to user.
@@ -362,11 +359,8 @@ def floating_ip_disassociate(context, address):
return fixed_ip_address
-#@require_admin_context
+@require_admin_context
def floating_ip_get_all(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -374,11 +368,8 @@ def floating_ip_get_all(context):
).all()
-#@require_admin_context
+@require_admin_context
def floating_ip_get_all_by_host(context, host):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -387,14 +378,9 @@ def floating_ip_get_all_by_host(context, host):
).all()
-#@require_context
+@require_context
def floating_ip_get_all_by_project(context, project_id):
- # TODO(devcamcar): Change to decorate and check project_id separately.
- if is_user_context(context):
- if context.project.id != project_id:
- raise exception.NotAuthorized()
- elif not is_admin_context(context):
- raise exception.NotAuthorized()
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
@@ -404,12 +390,9 @@ def floating_ip_get_all_by_project(context, project_id):
).all()
-#@require_context
+@require_context
def floating_ip_get_by_address(context, address, session=None):
# TODO(devcamcar): Ensure the address belongs to user.
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
if not session:
session = get_session()
@@ -426,11 +409,8 @@ def floating_ip_get_by_address(context, address, session=None):
###################
-#@require_context
+@require_context
def fixed_ip_associate(context, address, instance_id):
- if not is_user_context(context) and not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
fixed_ip_ref = session.query(models.FixedIp
@@ -449,11 +429,8 @@ def fixed_ip_associate(context, address, instance_id):
session.add(fixed_ip_ref)
-#@require_admin_context
+@require_admin_context
def fixed_ip_associate_pool(context, network_id, instance_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
network_or_none = or_(models.FixedIp.network_id == network_id,
@@ -480,7 +457,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id):
return fixed_ip_ref['address']
-#@require_context
+@require_context
def fixed_ip_create(_context, values):
fixed_ip_ref = models.FixedIp()
for (key, value) in values.iteritems():
@@ -489,7 +466,7 @@ def fixed_ip_create(_context, values):
return fixed_ip_ref['address']
-#@require_context
+@require_context
def fixed_ip_disassociate(context, address):
session = get_session()
with session.begin():
@@ -500,7 +477,7 @@ def fixed_ip_disassociate(context, address):
fixed_ip_ref.save(session=session)
-#@require_context
+@require_context
def fixed_ip_get_by_address(context, address, session=None):
# TODO(devcamcar): Ensure floating ip belongs to user.
# Only possible if it is associated with an instance.
@@ -520,19 +497,19 @@ def fixed_ip_get_by_address(context, address, session=None):
return result
-#@require_context
+@require_context
def fixed_ip_get_instance(context, address):
fixed_ip_ref = fixed_ip_get_by_address(context, address)
return fixed_ip_ref.instance
-#@require_admin_context
+@require_admin_context
def fixed_ip_get_network(context, address):
fixed_ip_ref = fixed_ip_get_by_address(context, address)
return fixed_ip_ref.network
-#@require_context
+@require_context
def fixed_ip_update(context, address, values):
session = get_session()
with session.begin():
@@ -547,7 +524,7 @@ def fixed_ip_update(context, address, values):
###################
-#@require_context
+@require_context
def instance_create(context, values):
instance_ref = models.Instance()
for (key, value) in values.iteritems():
@@ -563,7 +540,7 @@ def instance_create(context, values):
return instance_ref
-#@require_admin_context
+@require_admin_context
def instance_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Instance.id),
@@ -575,7 +552,7 @@ def instance_data_get_for_project(context, project_id):
return (result[0] or 0, result[1] or 0)
-#@require_context
+@require_context
def instance_destroy(context, instance_id):
session = get_session()
with session.begin():
@@ -583,7 +560,7 @@ def instance_destroy(context, instance_id):
instance_ref.delete(session=session)
-#@require_context
+@require_context
def instance_get(context, instance_id, session=None):
if not session:
session = get_session()
@@ -606,11 +583,8 @@ def instance_get(context, instance_id, session=None):
return result
-#@require_admin_context
+@require_admin_context
def instance_get_all(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
@@ -618,11 +592,8 @@ def instance_get_all(context):
).all()
-#@require_admin_context
+@require_admin_context
def instance_get_all_by_user(context, user_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
@@ -631,11 +602,9 @@ def instance_get_all_by_user(context, user_id):
).all()
-#@require_context
+@require_context
def instance_get_all_by_project(context, project_id):
- if is_user_context(context):
- if context.project.id != project_id:
- raise exception.NotAuthorized()
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.Instance
@@ -645,7 +614,7 @@ def instance_get_all_by_project(context, project_id):
).all()
-#@require_context
+@require_context
def instance_get_all_by_reservation(context, reservation_id):
session = get_session()
@@ -664,7 +633,7 @@ def instance_get_all_by_reservation(context, reservation_id):
).all()
-#@require_context
+@require_context
def instance_get_by_ec2_id(context, ec2_id):
session = get_session()
@@ -685,14 +654,14 @@ def instance_get_by_ec2_id(context, ec2_id):
return result
-#@require_context
+@require_context
def instance_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
return session.query(exists().where(models.Instance.id==ec2_id)).one()[0]
-#@require_context
+@require_context
def instance_get_fixed_address(context, instance_id):
session = get_session()
with session.begin():
@@ -702,7 +671,7 @@ def instance_get_fixed_address(context, instance_id):
return instance_ref.fixed_ip['address']
-#@require_context
+@require_context
def instance_get_floating_address(context, instance_id):
session = get_session()
with session.begin():
@@ -715,20 +684,15 @@ def instance_get_floating_address(context, instance_id):
return instance_ref.fixed_ip.floating_ips[0]['address']
-#@require_admin_context
+@require_admin_context
def instance_is_vpn(context, instance_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
# TODO(vish): Move this into image code somewhere
instance_ref = instance_get(context, instance_id)
return instance_ref['image_id'] == FLAGS.vpn_image_id
-#@require_admin_context
+@require_admin_context
def instance_set_state(context, instance_id, state, description=None):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
# TODO(devcamcar): Move this out of models and into driver
from nova.compute import power_state
if not description:
@@ -739,7 +703,7 @@ def instance_set_state(context, instance_id, state, description=None):
'state_description': description})
-#@require_context
+@require_context
def instance_update(context, instance_id, values):
session = get_session()
with session.begin():
@@ -752,7 +716,7 @@ def instance_update(context, instance_id, values):
###################
-#@require_context
+@require_context
def key_pair_create(context, values):
key_pair_ref = models.KeyPair()
for (key, value) in values.iteritems():
@@ -761,11 +725,9 @@ def key_pair_create(context, values):
return key_pair_ref
-#@require_context
+@require_context
def key_pair_destroy(context, user_id, name):
- if is_user_context(context):
- if context.user.id != user_id:
- raise exception.NotAuthorized()
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
@@ -773,11 +735,9 @@ def key_pair_destroy(context, user_id, name):
key_pair_ref.delete(session=session)
-#@require_context
+@require_context
def key_pair_destroy_all_by_user(context, user_id):
- if is_user_context(context):
- if context.user.id != user_id:
- raise exception.NotAuthorized()
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
@@ -786,11 +746,9 @@ def key_pair_destroy_all_by_user(context, user_id):
{'id': user_id})
-#@require_context
+@require_context
def key_pair_get(context, user_id, name, session=None):
- if is_user_context(context):
- if context.user.id != user_id:
- raise exception.NotAuthorized()
+ authorize_user_context(context, user_id)
if not session:
session = get_session()
@@ -806,11 +764,9 @@ def key_pair_get(context, user_id, name, session=None):
return result
-#@require_context
+@require_context
def key_pair_get_all_by_user(context, user_id):
- if is_user_context(context):
- if context.user.id != user_id:
- raise exception.NotAuthorized()
+ authorize_user_context(context, user_id)
session = get_session()
return session.query(models.KeyPair
@@ -822,22 +778,16 @@ def key_pair_get_all_by_user(context, user_id):
###################
-#@require_admin_context
+@require_admin_context
def network_count(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.Network
).filter_by(deleted=_deleted(context)
).count()
-#@require_admin_context
+@require_admin_context
def network_count_allocated_ips(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -846,11 +796,8 @@ def network_count_allocated_ips(context, network_id):
).count()
-#@require_admin_context
+@require_admin_context
def network_count_available_ips(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -860,11 +807,8 @@ def network_count_available_ips(context, network_id):
).count()
-#@require_admin_context
+@require_admin_context
def network_count_reserved_ips(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -873,11 +817,8 @@ def network_count_reserved_ips(context, network_id):
).count()
-#@require_admin_context
+@require_admin_context
def network_create(context, values):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
network_ref = models.Network()
for (key, value) in values.iteritems():
network_ref[key] = value
@@ -885,11 +826,8 @@ def network_create(context, values):
return network_ref
-#@require_admin_context
+@require_admin_context
def network_destroy(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -907,7 +845,7 @@ def network_destroy(context, network_id):
{'id': network_id})
-#@require_context
+@require_context
def network_get(context, network_id, session=None):
if not session:
session = get_session()
@@ -933,11 +871,8 @@ def network_get(context, network_id, session=None):
# NOTE(vish): pylint complains because of the long method name, but
# it fits with the names of the rest of the methods
# pylint: disable-msg=C0103
-#@require_admin_context
+@require_admin_context
def network_get_associated_fixed_ips(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.FixedIp
).options(joinedload_all('instance')
@@ -947,11 +882,8 @@ def network_get_associated_fixed_ips(context, network_id):
).all()
-#@require_admin_context
+@require_admin_context
def network_get_by_bridge(context, bridge):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
result = session.query(models.Network
).filter_by(bridge=bridge
@@ -964,11 +896,8 @@ def network_get_by_bridge(context, bridge):
return result
-#@require_admin_context
+@require_admin_context
def network_get_index(context, network_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
network_index = session.query(models.NetworkIndex
@@ -988,22 +917,16 @@ def network_get_index(context, network_id):
return network_index['index']
-#@require_admin_context
+@require_admin_context
def network_index_count(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
return session.query(models.NetworkIndex
).filter_by(deleted=_deleted(context)
).count()
-#@require_admin_context
+@require_admin_context
def network_index_create_safe(context, values):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
network_index_ref = models.NetworkIndex()
for (key, value) in values.iteritems():
network_index_ref[key] = value
@@ -1013,11 +936,8 @@ def network_index_create_safe(context, values):
pass
-#@require_admin_context
+@require_admin_context
def network_set_host(context, network_id, host_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
network_ref = session.query(models.Network
@@ -1037,7 +957,7 @@ def network_set_host(context, network_id, host_id):
return network_ref['host']
-#@require_context
+@require_context
def network_update(context, network_id, values):
session = get_session()
with session.begin():
@@ -1050,11 +970,8 @@ def network_update(context, network_id, values):
###################
-#@require_context
+@require_context
def project_get_network(context, project_id):
- if not is_admin_context(context) and not is_user_context(context):
- raise error.NotAuthorized()
-
session = get_session()
result= session.query(models.Network
).filter_by(project_id=project_id
@@ -1078,22 +995,16 @@ def queue_get_for(_context, topic, physical_node_id):
###################
-#@require_admin_context
+@require_admin_context
def export_device_count(context):
- if not is_admin_context(context):
- raise exception.notauthorized()
-
session = get_session()
return session.query(models.ExportDevice
).filter_by(deleted=_deleted(context)
).count()
-#@require_admin_context
+@require_admin_context
def export_device_create(context, values):
- if not is_admin_context(context):
- raise exception.notauthorized()
-
export_device_ref = models.ExportDevice()
for (key, value) in values.iteritems():
export_device_ref[key] = value
@@ -1127,11 +1038,8 @@ def auth_create_token(_context, token):
###################
-#@require_admin_context
+@require_admin_context
def quota_get(context, project_id, session=None):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
if not session:
session = get_session()
@@ -1145,11 +1053,8 @@ def quota_get(context, project_id, session=None):
return result
-#@require_admin_context
+@require_admin_context
def quota_create(context, values):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
quota_ref = models.Quota()
for (key, value) in values.iteritems():
quota_ref[key] = value
@@ -1157,11 +1062,8 @@ def quota_create(context, values):
return quota_ref
-#@require_admin_context
+@require_admin_context
def quota_update(context, project_id, values):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
quota_ref = quota_get(context, project_id, session=session)
@@ -1170,11 +1072,8 @@ def quota_update(context, project_id, values):
quota_ref.save(session=session)
-#@require_admin_context
+@require_admin_context
def quota_destroy(context, project_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
quota_ref = quota_get(context, project_id, session=session)
@@ -1184,11 +1083,8 @@ def quota_destroy(context, project_id):
###################
-#@require_admin_context
+@require_admin_context
def volume_allocate_shelf_and_blade(context, volume_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
export_device = session.query(models.ExportDevice
@@ -1205,11 +1101,8 @@ def volume_allocate_shelf_and_blade(context, volume_id):
return (export_device.shelf_id, export_device.blade_id)
-#@require_admin_context
+@require_admin_context
def volume_attached(context, volume_id, instance_id, mountpoint):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
volume_ref = volume_get(context, volume_id, session=session)
@@ -1220,7 +1113,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint):
volume_ref.save(session=session)
-#@require_context
+@require_context
def volume_create(context, values):
volume_ref = models.Volume()
for (key, value) in values.iteritems():
@@ -1236,11 +1129,8 @@ def volume_create(context, values):
return volume_ref
-#@require_admin_context
+@require_admin_context
def volume_data_get_for_project(context, project_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
result = session.query(func.count(models.Volume.id),
func.sum(models.Volume.size)
@@ -1251,11 +1141,8 @@ def volume_data_get_for_project(context, project_id):
return (result[0] or 0, result[1] or 0)
-#@require_admin_context
+@require_admin_context
def volume_destroy(context, volume_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -1266,11 +1153,8 @@ def volume_destroy(context, volume_id):
{'id': volume_id})
-#@require_admin_context
+@require_admin_context
def volume_detached(context, volume_id):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
session = get_session()
with session.begin():
volume_ref = volume_get(context, volume_id, session=session)
@@ -1281,7 +1165,7 @@ def volume_detached(context, volume_id):
volume_ref.save(session=session)
-#@require_context
+@require_context
def volume_get(context, volume_id, session=None):
if not session:
session = get_session()
@@ -1304,22 +1188,15 @@ def volume_get(context, volume_id, session=None):
return result
-#@require_admin_context
+@require_admin_context
def volume_get_all(context):
- if not is_admin_context(context):
- raise exception.NotAuthorized()
-
return session.query(models.Volume
).filter_by(deleted=_deleted(context)
).all()
-#@require_context
+@require_context
def volume_get_all_by_project(context, project_id):
- if is_user_context(context):
- if context.project.id != project_id:
- raise exception.NotAuthorized()
- elif not is_admin_context(context):
- raise exception.NotAuthorized()
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.Volume
@@ -1328,7 +1205,7 @@ def volume_get_all_by_project(context, project_id):
).all()
-#@require_context
+@require_context
def volume_get_by_ec2_id(context, ec2_id):
session = get_session()
result = None
@@ -1353,18 +1230,17 @@ def volume_get_by_ec2_id(context, ec2_id):
return result
-#@require_context
+@require_context
def volume_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
- if is_admin_context(context) or is_user_context(context):
- return session.query(exists(
- ).where(models.Volume.id==ec2_id)
- ).one()[0]
+ return session.query(exists(
+ ).where(models.Volume.id==ec2_id)
+ ).one()[0]
-#@require_context
+@require_context
def volume_get_instance(context, volume_id):
session = get_session()
result = None
@@ -1390,7 +1266,7 @@ def volume_get_instance(context, volume_id):
return result.instance
-#@require_context
+@require_context
def volume_get_shelf_and_blade(context, volume_id):
session = get_session()
result = None
@@ -1412,7 +1288,7 @@ def volume_get_shelf_and_blade(context, volume_id):
return (result.shelf_id, result.blade_id)
-#@require_context
+@require_context
def volume_update(context, volume_id, values):
session = get_session()
with session.begin():