summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/api/context.py (renamed from nova/api/ec2/context.py)13
-rw-r--r--nova/api/ec2/__init__.py8
-rw-r--r--nova/db/api.py5
-rw-r--r--nova/db/sqlalchemy/api.py721
-rw-r--r--nova/db/sqlalchemy/models.py146
-rw-r--r--nova/network/manager.py4
-rw-r--r--nova/tests/compute_unittest.py6
-rw-r--r--nova/tests/network_unittest.py50
8 files changed, 606 insertions, 347 deletions
diff --git a/nova/api/ec2/context.py b/nova/api/context.py
index c53ba98d9..b66cfe468 100644
--- a/nova/api/ec2/context.py
+++ b/nova/api/context.py
@@ -31,3 +31,16 @@ class APIRequestContext(object):
[random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
for x in xrange(20)]
)
+ if user:
+ self.is_admin = user.is_admin()
+ else:
+ self.is_admin = False
+ self.read_deleted = False
+
+
+def get_admin_context(user=None, read_deleted=False):
+ context_ref = APIRequestContext(user=user, project=None)
+ context_ref.is_admin = True
+ context_ref.read_deleted = read_deleted
+ return context_ref
+
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 7a958f841..6b538a7f1 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -27,8 +27,8 @@ import webob.exc
from nova import exception
from nova import flags
from nova import wsgi
+from nova.api import context
from nova.api.ec2 import apirequest
-from nova.api.ec2 import context
from nova.api.ec2 import admin
from nova.api.ec2 import cloud
from nova.auth import manager
@@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware):
return True
if 'none' in roles:
return False
- return any(context.project.has_role(context.user.id, role)
+ return any(context.project.has_role(context.user.id, role)
for role in roles)
-
+
class Executor(wsgi.Application):
"""Execute an EC2 API request.
- Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
+ Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
'ec2.action_args' (all variables in WSGI environ.) Returns an XML
response, or a 400 upon failure.
"""
diff --git a/nova/db/api.py b/nova/db/api.py
index 4aea0e6a4..5c935b561 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -175,11 +175,6 @@ def floating_ip_get_by_address(context, address):
return IMPL.floating_ip_get_by_address(context, address)
-def floating_ip_get_instance(context, address):
- """Get an instance for a floating ip by address."""
- return IMPL.floating_ip_get_instance(context, address)
-
-
####################
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 4aa3c693d..7f72f66b9 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -19,6 +19,8 @@
Implementation of SQLAlchemy backend
"""
+import warnings
+
from nova import db
from nova import exception
from nova import flags
@@ -27,39 +29,108 @@ from nova.db.sqlalchemy import models
from nova.db.sqlalchemy.session import get_session
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm import joinedload_all
-from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm import joinedload, joinedload_all
from sqlalchemy.sql import exists, func
FLAGS = flags.FLAGS
-# NOTE(vish): disabling docstring pylint because the docstrings are
-# in the interface definition
-# pylint: disable-msg=C0111
-def _deleted(context):
- """Calculates whether to include deleted objects based on context.
+def is_admin_context(context):
+ """Indicates if the request context is an administrator."""
+ if not context:
+ warnings.warn('Use of empty request context is deprecated',
+ DeprecationWarning)
+ return True
+ return context.is_admin
+
+
+def is_user_context(context):
+ """Indicates if the request context is a normal user."""
+ if not context:
+ return False
+ if not context.user or not context.project:
+ return False
+ return True
+
+
+def authorize_project_context(context, project_id):
+ """Ensures that the request context has permission to access the
+ given project.
+ """
+ if is_user_context(context):
+ if not context.project:
+ raise exception.NotAuthorized()
+ elif context.project.id != project_id:
+ raise exception.NotAuthorized()
- Currently just looks for a flag called deleted in the context dict.
+
+def authorize_user_context(context, user_id):
+ """Ensures that the request context has permission to access the
+ given user.
"""
- if not hasattr(context, 'get'):
+ if is_user_context(context):
+ if not context.user:
+ raise exception.NotAuthorized()
+ elif context.user.id != user_id:
+ raise exception.NotAuthorized()
+
+
+def can_read_deleted(context):
+ """Indicates if the context has access to deleted objects."""
+ if not context:
return False
- return context.get('deleted', False)
+ return context.read_deleted
-###################
+def require_admin_context(f):
+ """Decorator used to indicate that the method requires an
+ administrator context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+def require_context(f):
+ """Decorator used to indicate that the method requires either
+ an administrator or normal user context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]) and not is_user_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+###################
+@require_admin_context
def service_destroy(context, service_id):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = service_get(context, service_id, session=session)
service_ref.delete(session=session)
-def service_get(_context, service_id):
- return models.Service.find(service_id)
+
+@require_admin_context
+def service_get(context, service_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Service
+ ).filter_by(id=service_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No service for id %s' % service_id)
+
+ return result
+@require_admin_context
def service_get_all_by_topic(context, topic):
session = get_session()
return session.query(models.Service
@@ -69,7 +140,8 @@ def service_get_all_by_topic(context, topic):
).all()
-def _service_get_all_topic_subquery(_context, session, topic, subq, label):
+@require_admin_context
+def _service_get_all_topic_subquery(context, session, topic, subq, label):
sort_value = getattr(subq.c, label)
return session.query(models.Service, func.coalesce(sort_value, 0)
).filter_by(topic=topic
@@ -80,6 +152,7 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label):
).all()
+@require_admin_context
def service_get_all_compute_sorted(context):
session = get_session()
with session.begin():
@@ -104,6 +177,7 @@ def service_get_all_compute_sorted(context):
label)
+@require_admin_context
def service_get_all_network_sorted(context):
session = get_session()
with session.begin():
@@ -121,6 +195,7 @@ def service_get_all_network_sorted(context):
label)
+@require_admin_context
def service_get_all_volume_sorted(context):
session = get_session()
with session.begin():
@@ -138,11 +213,22 @@ def service_get_all_volume_sorted(context):
label)
-def service_get_by_args(_context, host, binary):
- return models.Service.find_by_args(host, binary)
+@require_admin_context
+def service_get_by_args(context, host, binary):
+ session = get_session()
+ result = session.query(models.Service
+ ).filter_by(host=host
+ ).filter_by(binary=binary
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No service for %s, %s' % (host, binary))
+
+ return result
-def service_create(_context, values):
+@require_admin_context
+def service_create(context, values):
service_ref = models.Service()
for (key, value) in values.iteritems():
service_ref[key] = value
@@ -150,10 +236,11 @@ def service_create(_context, values):
return service_ref
-def service_update(_context, service_id, values):
+@require_admin_context
+def service_update(context, service_id, values):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = session_get(context, service_id, session=session)
for (key, value) in values.iteritems():
service_ref[key] = value
service_ref.save(session=session)
@@ -162,7 +249,9 @@ def service_update(_context, service_id, values):
###################
-def floating_ip_allocate_address(_context, host, project_id):
+@require_context
+def floating_ip_allocate_address(context, host, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
with session.begin():
floating_ip_ref = session.query(models.FloatingIp
@@ -181,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id):
return floating_ip_ref['address']
-def floating_ip_create(_context, values):
+@require_context
+def floating_ip_create(context, values):
floating_ip_ref = models.FloatingIp()
for (key, value) in values.iteritems():
floating_ip_ref[key] = value
@@ -189,7 +279,9 @@ def floating_ip_create(_context, values):
return floating_ip_ref['address']
-def floating_ip_count_by_project(_context, project_id):
+@require_context
+def floating_ip_count_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
).filter_by(project_id=project_id
@@ -197,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id):
).count()
-def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address):
+@require_context
+def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(floating_address,
- session=session)
- fixed_ip_ref = models.FixedIp.find_by_str(fixed_address,
- session=session)
+ # TODO(devcamcar): How to ensure floating_id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ floating_address,
+ session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ fixed_address,
+ session=session)
floating_ip_ref.fixed_ip = fixed_ip_ref
floating_ip_ref.save(session=session)
-def floating_ip_deallocate(_context, address):
+@require_context
+def floating_ip_deallocate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): How to ensure floating id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
floating_ip_ref['project_id'] = None
floating_ip_ref.save(session=session)
-def floating_ip_destroy(_context, address):
+@require_context
+def floating_ip_destroy(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ floating_ip_ref = get_floating_ip_by_address(context,
+ address,
+ session=session)
floating_ip_ref.delete(session=session)
-def floating_ip_disassociate(_context, address):
+@require_context
+def floating_ip_disassociate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ # Does get_floating_ip_by_address handle this?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref = floating_ip_ref.fixed_ip
if fixed_ip_ref:
fixed_ip_address = fixed_ip_ref['address']
@@ -240,7 +346,8 @@ def floating_ip_disassociate(_context, address):
return fixed_ip_address
-def floating_ip_get_all(_context):
+@require_admin_context
+def floating_ip_get_all(context):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -248,7 +355,8 @@ def floating_ip_get_all(_context):
).all()
-def floating_ip_get_all_by_host(_context, host):
+@require_admin_context
+def floating_ip_get_all_by_host(context, host):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -256,7 +364,10 @@ def floating_ip_get_all_by_host(_context, host):
).filter_by(deleted=False
).all()
-def floating_ip_get_all_by_project(_context, project_id):
+
+@require_context
+def floating_ip_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -264,24 +375,31 @@ def floating_ip_get_all_by_project(_context, project_id):
).filter_by(deleted=False
).all()
-def floating_ip_get_by_address(_context, address):
- return models.FloatingIp.find_by_str(address)
+@require_context
+def floating_ip_get_by_address(context, address, session=None):
+ # TODO(devcamcar): Ensure the address belongs to user.
+ if not session:
+ session = get_session()
-def floating_ip_get_instance(_context, address):
- session = get_session()
- with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
- return floating_ip_ref.fixed_ip.instance
+ result = session.query(models.FloatingIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No fixed ip for address %s' % address)
+
+ return result
###################
-def fixed_ip_associate(_context, address, instance_id):
+@require_context
+def fixed_ip_associate(context, address, instance_id):
session = get_session()
with session.begin():
+ instance = instance_get(context, instance_id, session=session)
fixed_ip_ref = session.query(models.FixedIp
).filter_by(address=address
).filter_by(deleted=False
@@ -292,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id):
# then this has concurrency issues
if not fixed_ip_ref:
raise db.NoMoreAddresses()
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.instance = instance
session.add(fixed_ip_ref)
-def fixed_ip_associate_pool(_context, network_id, instance_id):
+@require_admin_context
+def fixed_ip_associate_pool(context, network_id, instance_id):
session = get_session()
with session.begin():
network_or_none = or_(models.FixedIp.network_id == network_id,
@@ -314,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id):
if not fixed_ip_ref:
raise db.NoMoreAddresses()
if not fixed_ip_ref.network:
- fixed_ip_ref.network = models.Network.find(network_id,
- session=session)
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.network = network_get(context,
+ network_id,
+ session=session)
+ fixed_ip_ref.instance = instance_get(context,
+ instance_id,
+ session=session)
session.add(fixed_ip_ref)
return fixed_ip_ref['address']
+@require_context
def fixed_ip_create(_context, values):
fixed_ip_ref = models.FixedIp()
for (key, value) in values.iteritems():
@@ -330,14 +451,18 @@ def fixed_ip_create(_context, values):
return fixed_ip_ref['address']
-def fixed_ip_disassociate(_context, address):
+@require_context
+def fixed_ip_disassociate(context, address):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref.instance = None
fixed_ip_ref.save(session=session)
+@require_admin_context
def fixed_ip_disassociate_all_by_timeout(_context, host, time):
session = get_session()
# NOTE(vish): The nested select is because sqlite doesn't support
@@ -354,34 +479,44 @@ def fixed_ip_disassociate_all_by_timeout(_context, host, time):
return result.rowcount
-def fixed_ip_get_by_address(_context, address):
- session = get_session()
+@require_context
+def fixed_ip_get_by_address(context, address, session=None):
+ if not session:
+ session = get_session()
result = session.query(models.FixedIp
- ).options(joinedload_all('instance')
).filter_by(address=address
- ).filter_by(deleted=False
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('network')
+ ).options(joinedload('instance')
).first()
if not result:
- raise exception.NotFound("No model for address %s" % address)
+ raise exception.NotFound('No floating ip for address %s' % address)
+
+ if is_user_context(context):
+ authorize_project_context(context, result.instance.project_id)
+
return result
-def fixed_ip_get_instance(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).instance
+@require_context
+def fixed_ip_get_instance(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.instance
-def fixed_ip_get_network(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).network
+@require_admin_context
+def fixed_ip_get_network(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.network
-def fixed_ip_update(_context, address, values):
+@require_context
+def fixed_ip_update(context, address, values):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
for (key, value) in values.iteritems():
fixed_ip_ref[key] = value
fixed_ip_ref.save(session=session)
@@ -390,7 +525,8 @@ def fixed_ip_update(_context, address, values):
###################
-def instance_create(_context, values):
+@require_context
+def instance_create(context, values):
instance_ref = models.Instance()
for (key, value) in values.iteritems():
instance_ref[key] = value
@@ -399,12 +535,14 @@ def instance_create(_context, values):
with session.begin():
while instance_ref.ec2_id == None:
ec2_id = utils.generate_uid(instance_ref.__prefix__)
- if not instance_ec2_id_exists(_context, ec2_id, session=session):
+ if not instance_ec2_id_exists(context, ec2_id, session=session):
instance_ref.ec2_id = ec2_id
instance_ref.save(session=session)
return instance_ref
-def instance_data_get_for_project(_context, project_id):
+
+@require_admin_context
+def instance_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Instance.id),
func.sum(models.Instance.vcpus)
@@ -415,81 +553,130 @@ def instance_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def instance_destroy(_context, instance_id):
+@require_context
+def instance_destroy(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
instance_ref.delete(session=session)
-def instance_get(context, instance_id):
- return models.Instance.find(instance_id, deleted=_deleted(context))
+@require_context
+def instance_get(context, instance_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No instance for id %s' % instance_id)
+
+ return result
+@require_admin_context
def instance_get_all(context):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
+
+@require_admin_context
def instance_get_all_by_user(context, user_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).filter_by(user_id=user_id
).all()
+
+@require_context
def instance_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
-def instance_get_all_by_reservation(_context, reservation_id):
+@require_context
+def instance_get_all_by_reservation(context, reservation_id):
session = get_session()
- return session.query(models.Instance
- ).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(reservation_id=reservation_id
- ).filter_by(deleted=False
- ).all()
+ if is_admin_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+ elif is_user_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=False
+ ).all()
+
+@require_context
def instance_get_by_ec2_id(context, ec2_id):
session = get_session()
- instance_ref = session.query(models.Instance
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
).filter_by(ec2_id=ec2_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=False
).first()
- if not instance_ref:
+ if not result:
raise exception.NotFound('Instance %s not found' % (ec2_id))
- return instance_ref
+ return result
+@require_context
def instance_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
return session.query(exists().where(models.Instance.id==ec2_id)).one()[0]
-def instance_get_fixed_address(_context, instance_id):
+@require_context
+def instance_get_fixed_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
return instance_ref.fixed_ip['address']
-def instance_get_floating_address(_context, instance_id):
+@require_context
+def instance_get_floating_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
if not instance_ref.fixed_ip.floating_ips:
@@ -498,12 +685,14 @@ def instance_get_floating_address(_context, instance_id):
return instance_ref.fixed_ip.floating_ips[0]['address']
+@require_admin_context
def instance_is_vpn(context, instance_id):
# TODO(vish): Move this into image code somewhere
instance_ref = instance_get(context, instance_id)
return instance_ref['image_id'] == FLAGS.vpn_image_id
+@require_admin_context
def instance_set_state(context, instance_id, state, description=None):
# TODO(devcamcar): Move this out of models and into driver
from nova.compute import power_state
@@ -515,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None):
'state_description': description})
-def instance_update(_context, instance_id, values):
+@require_context
+def instance_update(context, instance_id, values):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
for (key, value) in values.iteritems():
instance_ref[key] = value
instance_ref.save(session=session)
@@ -527,7 +717,8 @@ def instance_update(_context, instance_id, values):
###################
-def key_pair_create(_context, values):
+@require_context
+def key_pair_create(context, values):
key_pair_ref = models.KeyPair()
for (key, value) in values.iteritems():
key_pair_ref[key] = value
@@ -535,16 +726,18 @@ def key_pair_create(_context, values):
return key_pair_ref
-def key_pair_destroy(_context, user_id, name):
+@require_context
+def key_pair_destroy(context, user_id, name):
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
- key_pair_ref = models.KeyPair.find_by_args(user_id,
- name,
- session=session)
+ key_pair_ref = key_pair_get(context, user_id, name, session=session)
key_pair_ref.delete(session=session)
-def key_pair_destroy_all_by_user(_context, user_id):
+@require_context
+def key_pair_destroy_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -552,11 +745,27 @@ def key_pair_destroy_all_by_user(_context, user_id):
{'id': user_id})
-def key_pair_get(_context, user_id, name):
- return models.KeyPair.find_by_args(user_id, name)
+@require_context
+def key_pair_get(context, user_id, name, session=None):
+ authorize_user_context(context, user_id)
+
+ if not session:
+ session = get_session()
+
+ result = session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(name=name
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('no keypair for user %s, name %s' %
+ (user_id, name))
+ return result
-def key_pair_get_all_by_user(_context, user_id):
+@require_context
+def key_pair_get_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
session = get_session()
return session.query(models.KeyPair
).filter_by(user_id=user_id
@@ -567,11 +776,16 @@ def key_pair_get_all_by_user(_context, user_id):
###################
-def network_count(_context):
- return models.Network.count()
+@require_admin_context
+def network_count(context):
+ session = get_session()
+ return session.query(models.Network
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def network_count_allocated_ips(_context, network_id):
+@require_admin_context
+def network_count_allocated_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -580,7 +794,8 @@ def network_count_allocated_ips(_context, network_id):
).count()
-def network_count_available_ips(_context, network_id):
+@require_admin_context
+def network_count_available_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -590,7 +805,8 @@ def network_count_available_ips(_context, network_id):
).count()
-def network_count_reserved_ips(_context, network_id):
+@require_admin_context
+def network_count_reserved_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -599,7 +815,8 @@ def network_count_reserved_ips(_context, network_id):
).count()
-def network_create(_context, values):
+@require_admin_context
+def network_create(context, values):
network_ref = models.Network()
for (key, value) in values.iteritems():
network_ref[key] = value
@@ -607,7 +824,8 @@ def network_create(_context, values):
return network_ref
-def network_destroy(_context, network_id):
+@require_admin_context
+def network_destroy(context, network_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -625,14 +843,34 @@ def network_destroy(_context, network_id):
{'id': network_id})
-def network_get(_context, network_id):
- return models.Network.find(network_id)
+@require_context
+def network_get(context, network_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Network
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ return result
# NOTE(vish): pylint complains because of the long method name, but
# it fits with the names of the rest of the methods
# pylint: disable-msg=C0103
-def network_get_associated_fixed_ips(_context, network_id):
+@require_admin_context
+def network_get_associated_fixed_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).options(joinedload_all('instance')
@@ -642,18 +880,22 @@ def network_get_associated_fixed_ips(_context, network_id):
).all()
-def network_get_by_bridge(_context, bridge):
+@require_admin_context
+def network_get_by_bridge(context, bridge):
session = get_session()
- rv = session.query(models.Network
+ result = session.query(models.Network
).filter_by(bridge=bridge
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for bridge %s' % bridge)
- return rv
+ return result
-def network_get_index(_context, network_id):
+
+@require_admin_context
+def network_get_index(context, network_id):
session = get_session()
with session.begin():
network_index = session.query(models.NetworkIndex
@@ -661,19 +903,28 @@ def network_get_index(_context, network_id):
).filter_by(deleted=False
).with_lockmode('update'
).first()
+
if not network_index:
raise db.NoMoreNetworks()
- network_index['network'] = models.Network.find(network_id,
- session=session)
+
+ network_index['network'] = network_get(context,
+ network_id,
+ session=session)
session.add(network_index)
+
return network_index['index']
-def network_index_count(_context):
- return models.NetworkIndex.count()
+@require_admin_context
+def network_index_count(context):
+ session = get_session()
+ return session.query(models.NetworkIndex
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def network_index_create_safe(_context, values):
+@require_admin_context
+def network_index_create_safe(context, values):
network_index_ref = models.NetworkIndex()
for (key, value) in values.iteritems():
network_index_ref[key] = value
@@ -683,29 +934,32 @@ def network_index_create_safe(_context, values):
pass
-def network_set_host(_context, network_id, host_id):
+@require_admin_context
+def network_set_host(context, network_id, host_id):
session = get_session()
with session.begin():
- network = session.query(models.Network
- ).filter_by(id=network_id
- ).filter_by(deleted=False
- ).with_lockmode('update'
- ).first()
- if not network:
- raise exception.NotFound("Couldn't find network with %s" %
- network_id)
+ network_ref = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ if not network_ref:
+ raise exception.NotFound('No network for id %s' % network_id)
+
# NOTE(vish): if with_lockmode isn't supported, as in sqlite,
# then this has concurrency issues
- if not network['host']:
- network['host'] = host_id
- session.add(network)
- return network['host']
+ if not network_ref['host']:
+ network_ref['host'] = host_id
+ session.add(network_ref)
+ return network_ref['host']
-def network_update(_context, network_id, values):
+
+@require_context
+def network_update(context, network_id, values):
session = get_session()
with session.begin():
- network_ref = models.Network.find(network_id, session=session)
+ network_ref = network_get(context, network_id, session=session)
for (key, value) in values.iteritems():
network_ref[key] = value
network_ref.save(session=session)
@@ -714,15 +968,18 @@ def network_update(_context, network_id, values):
###################
-def project_get_network(_context, project_id):
+@require_context
+def project_get_network(context, project_id):
session = get_session()
- rv = session.query(models.Network
+ result= session.query(models.Network
).filter_by(project_id=project_id
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for project: %s' % project_id)
- return rv
+
+ return result
###################
@@ -732,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id):
# FIXME(ja): this should be servername?
return "%s.%s" % (topic, physical_node_id)
+
###################
-def export_device_count(_context):
- return models.ExportDevice.count()
+@require_admin_context
+def export_device_count(context):
+ session = get_session()
+ return session.query(models.ExportDevice
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def export_device_create(_context, values):
+@require_admin_context
+def export_device_create(context, values):
export_device_ref = models.ExportDevice()
for (key, value) in values.iteritems():
export_device_ref[key] = value
@@ -773,7 +1036,23 @@ def auth_create_token(_context, token):
###################
-def quota_create(_context, values):
+@require_admin_context
+def quota_get(context, project_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Quota
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No quota for project_id %s' % project_id)
+
+ return result
+
+
+@require_admin_context
+def quota_create(context, values):
quota_ref = models.Quota()
for (key, value) in values.iteritems():
quota_ref[key] = value
@@ -781,30 +1060,29 @@ def quota_create(_context, values):
return quota_ref
-def quota_get(_context, project_id):
- return models.Quota.find_by_str(project_id)
-
-
-def quota_update(_context, project_id, values):
+@require_admin_context
+def quota_update(context, project_id, values):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
for (key, value) in values.iteritems():
quota_ref[key] = value
quota_ref.save(session=session)
-def quota_destroy(_context, project_id):
+@require_admin_context
+def quota_destroy(context, project_id):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
quota_ref.delete(session=session)
###################
-def volume_allocate_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_allocate_shelf_and_blade(context, volume_id):
session = get_session()
with session.begin():
export_device = session.query(models.ExportDevice
@@ -821,19 +1099,20 @@ def volume_allocate_shelf_and_blade(_context, volume_id):
return (export_device.shelf_id, export_device.blade_id)
-def volume_attached(_context, volume_id, instance_id, mountpoint):
+@require_admin_context
+def volume_attached(context, volume_id, instance_id, mountpoint):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'in-use'
volume_ref['mountpoint'] = mountpoint
volume_ref['attach_status'] = 'attached'
- volume_ref.instance = models.Instance.find(instance_id,
- session=session)
+ volume_ref.instance = instance_get(context, instance_id, session=session)
volume_ref.save(session=session)
-def volume_create(_context, values):
+@require_context
+def volume_create(context, values):
volume_ref = models.Volume()
for (key, value) in values.iteritems():
volume_ref[key] = value
@@ -842,13 +1121,14 @@ def volume_create(_context, values):
with session.begin():
while volume_ref.ec2_id == None:
ec2_id = utils.generate_uid(volume_ref.__prefix__)
- if not volume_ec2_id_exists(_context, ec2_id, session=session):
+ if not volume_ec2_id_exists(context, ec2_id, session=session):
volume_ref.ec2_id = ec2_id
volume_ref.save(session=session)
return volume_ref
-def volume_data_get_for_project(_context, project_id):
+@require_admin_context
+def volume_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Volume.id),
func.sum(models.Volume.size)
@@ -859,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def volume_destroy(_context, volume_id):
+@require_admin_context
+def volume_destroy(context, volume_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -870,10 +1151,11 @@ def volume_destroy(_context, volume_id):
{'id': volume_id})
-def volume_detached(_context, volume_id):
+@require_admin_context
+def volume_detached(context, volume_id):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'available'
volume_ref['mountpoint'] = None
volume_ref['attach_status'] = 'detached'
@@ -881,60 +1163,113 @@ def volume_detached(_context, volume_id):
volume_ref.save(session=session)
-def volume_get(context, volume_id):
- return models.Volume.find(volume_id, deleted=_deleted(context))
+@require_context
+def volume_get(context, volume_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No volume for id %s' % volume_id)
-def volume_get_all(context):
- return models.Volume.all(deleted=_deleted(context))
+ return result
+@require_admin_context
+def volume_get_all(context):
+ return session.query(models.Volume
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+@require_context
def volume_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
session = get_session()
return session.query(models.Volume
).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
+@require_context
def volume_get_by_ec2_id(context, ec2_id):
session = get_session()
- volume_ref = session.query(models.Volume
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
).filter_by(ec2_id=ec2_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=False
).first()
- if not volume_ref:
- raise exception.NotFound('Volume %s not found' % (ec2_id))
+ else:
+ raise exception.NotAuthorized()
- return volume_ref
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result
+@require_context
def volume_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
- return session.query(exists().where(models.Volume.id==ec2_id)).one()[0]
+ return session.query(exists(
+ ).where(models.Volume.id==ec2_id)
+ ).one()[0]
-def volume_get_instance(_context, volume_id):
+
+@require_admin_context
+def volume_get_instance(context, volume_id):
session = get_session()
- with session.begin():
- return models.Volume.find(volume_id, session=session).instance
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result.instance
-def volume_get_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_get_shelf_and_blade(context, volume_id):
session = get_session()
- export_device = session.query(models.ExportDevice
- ).filter_by(volume_id=volume_id
- ).first()
- if not export_device:
- raise exception.NotFound()
- return (export_device.shelf_id, export_device.blade_id)
+ result = session.query(models.ExportDevice
+ ).filter_by(volume_id=volume_id
+ ).first()
+ if not result:
+ raise exception.NotFound('No export device found for volume %s' %
+ volume_id)
+
+ return (result.shelf_id, result.blade_id)
-def volume_update(_context, volume_id, values):
+@require_context
+def volume_update(context, volume_id, values):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
for (key, value) in values.iteritems():
volume_ref[key] = value
volume_ref.save(session=session)
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index 6cb377476..1837a7584 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -50,44 +50,6 @@ class NovaBase(object):
deleted_at = Column(DateTime)
deleted = Column(Boolean, default=False)
- @classmethod
- def all(cls, session=None, deleted=False):
- """Get all objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).all()
-
- @classmethod
- def count(cls, session=None, deleted=False):
- """Count objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).count()
-
- @classmethod
- def find(cls, obj_id, session=None, deleted=False):
- """Find object by id"""
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(id=obj_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for id %s" % obj_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- """Find object by str_id"""
- int_id = int(str_id.rpartition('-')[2])
- return cls.find(int_id, session=session, deleted=deleted)
-
@property
def str_id(self):
"""Get string id of object (generally prefix + '-' + id)"""
@@ -176,21 +138,6 @@ class Service(BASE, NovaBase):
report_count = Column(Integer, nullable=False, default=0)
disabled = Column(Boolean, default=False)
- @classmethod
- def find_by_args(cls, host, binary, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(host=host
- ).filter_by(binary=binary
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for %s, %s" % (host,
- binary))
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class Instance(BASE, NovaBase):
"""Represents a guest vm"""
@@ -284,7 +231,11 @@ class Volume(BASE, NovaBase):
size = Column(Integer)
availability_zone = Column(String(255)) # TODO(vish): foreign key?
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('volumes'))
+ instance = relationship(Instance,
+ backref=backref('volumes'),
+ foreign_keys=instance_id,
+ primaryjoin='and_(Volume.instance_id==Instance.id,'
+ 'Volume.deleted==False)')
mountpoint = Column(String(255))
attach_time = Column(String(255)) # TODO(vish): datetime
status = Column(String(255)) # TODO(vish): enum?
@@ -315,18 +266,6 @@ class Quota(BASE, NovaBase):
def str_id(self):
return self.project_id
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(project_id=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for project_id %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
class ExportDevice(BASE, NovaBase):
"""Represates a shelf and blade that a volume can be exported on"""
@@ -335,8 +274,11 @@ class ExportDevice(BASE, NovaBase):
shelf_id = Column(Integer)
blade_id = Column(Integer)
volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True)
- volume = relationship(Volume, backref=backref('export_device',
- uselist=False))
+ volume = relationship(Volume,
+ backref=backref('export_device', uselist=False),
+ foreign_keys=volume_id,
+ primaryjoin='and_(ExportDevice.volume_id==Volume.id,'
+ 'ExportDevice.deleted==False)')
class KeyPair(BASE, NovaBase):
@@ -354,26 +296,6 @@ class KeyPair(BASE, NovaBase):
def str_id(self):
return '%s.%s' % (self.user_id, self.name)
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- user_id, _sep, name = str_id.partition('.')
- return cls.find_by_str(user_id, name, session, deleted)
-
- @classmethod
- def find_by_args(cls, user_id, name, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(user_id=user_id
- ).filter_by(name=name
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for user %s, name %s" %
- (user_id, name))
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class Network(BASE, NovaBase):
"""Represents a network"""
@@ -409,8 +331,12 @@ class NetworkIndex(BASE, NovaBase):
id = Column(Integer, primary_key=True)
index = Column(Integer, unique=True)
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
- network = relationship(Network, backref=backref('network_index',
- uselist=False))
+ network = relationship(Network,
+ backref=backref('network_index', uselist=False),
+ foreign_keys=network_id,
+ primaryjoin='and_(NetworkIndex.network_id==Network.id,'
+ 'NetworkIndex.deleted==False)')
+
class AuthToken(BASE, NovaBase):
"""Represents an authorization token for all API transactions. Fields
@@ -434,8 +360,11 @@ class FixedIp(BASE, NovaBase):
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
network = relationship(Network, backref=backref('fixed_ips'))
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('fixed_ip',
- uselist=False))
+ instance = relationship(Instance,
+ backref=backref('fixed_ip', uselist=False),
+ foreign_keys=instance_id,
+ primaryjoin='and_(FixedIp.instance_id==Instance.id,'
+ 'FixedIp.deleted==False)')
allocated = Column(Boolean, default=False)
leased = Column(Boolean, default=False)
reserved = Column(Boolean, default=False)
@@ -444,19 +373,6 @@ class FixedIp(BASE, NovaBase):
def str_id(self):
return self.address
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class FloatingIp(BASE, NovaBase):
"""Represents a floating ip that dynamically forwards to a fixed ip"""
@@ -464,24 +380,14 @@ class FloatingIp(BASE, NovaBase):
id = Column(Integer, primary_key=True)
address = Column(String(255))
fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True)
- fixed_ip = relationship(FixedIp, backref=backref('floating_ips'))
-
+ fixed_ip = relationship(FixedIp,
+ backref=backref('floating_ips'),
+ foreign_keys=fixed_ip_id,
+ primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,'
+ 'FloatingIp.deleted==False)')
project_id = Column(String(255))
host = Column(String(255)) # , ForeignKey('hosts.id'))
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
def register_models():
"""Register Models and create metadata"""
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 1325c300b..ef1d01138 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -92,7 +92,7 @@ class NetworkManager(manager.Manager):
# TODO(vish): can we minimize db access by just getting the
# id here instead of the ref?
network_id = network_ref['id']
- host = self.db.network_set_host(context,
+ host = self.db.network_set_host(None,
network_id,
self.host)
self._on_set_network_host(context, network_id)
@@ -249,7 +249,7 @@ class VlanManager(NetworkManager):
address = network_ref['vpn_private_address']
self.db.fixed_ip_associate(context, address, instance_id)
else:
- address = self.db.fixed_ip_associate_pool(context,
+ address = self.db.fixed_ip_associate_pool(None,
network_ref['id'],
instance_id)
self.db.fixed_ip_update(context, address, {'allocated': True})
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index f5c0f1c09..1e2bb113b 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -30,7 +30,7 @@ from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
-
+from nova.api import context
FLAGS = flags.FLAGS
@@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase):
self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow()
yield self.compute.terminate_instance(self.context, instance_id)
- instance_ref = db.instance_get({'deleted': True}, instance_id)
+ self.context = context.get_admin_context(user=self.user,
+ read_deleted=True)
+ instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] < terminate)
self.assert_(instance_ref['deleted_at'] > terminate)
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index da65b50a2..5370966d2 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -56,12 +56,12 @@ class NetworkTestCase(test.TrialTestCase):
'netuser',
name))
# create the necessary network data for the project
- self.network.set_network_host(self.context, self.projects[i].id)
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ user_context = context.APIRequestContext(project=self.projects[i],
+ user=self.user)
+ self.network.set_network_host(user_context, self.projects[i].id)
+ instance_ref = self._create_instance(0)
self.instance_id = instance_ref['id']
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ instance_ref = self._create_instance(1)
self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
@@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase):
self.manager.delete_project(project)
self.manager.delete_user(self.user)
+ def _create_instance(self, project_num, mac=None):
+ if not mac:
+ mac = utils.generate_mac()
+ project = self.projects[project_num]
+ self.context.project = project
+ return db.instance_create(self.context,
+ {'project_id': project.id,
+ 'mac_address': mac})
+
def _create_address(self, project_num, instance_id=None):
"""Create an address in given project num"""
if instance_id is None:
@@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase):
self.context.project = self.projects[project_num]
return self.network.allocate_fixed_ip(self.context, instance_id)
+ def _deallocate_address(self, project_num, address):
+ self.context.project = self.projects[project_num]
+ self.network.deallocate_fixed_ip(self.context, address)
+
+
def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
# TODO(vish): better way of adding floating ips
+ self.context.project = self.projects[0]
pubnet = IPy.IP(flags.FLAGS.public_range)
address = str(pubnet[0])
try:
@@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase):
address = self._create_address(0)
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
lease_ip(address)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
@@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(address)
lease_ip(address2)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
- self.network.deallocate_fixed_ip(self.context, address2)
+ self._deallocate_address(1, address2)
release_ip(address2)
self.assertFalse(is_allocated_in_project(address2,
self.projects[1].id))
@@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(first)
instance_ids = []
for i in range(1, 5):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address2 = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address3 = self._create_address(i, instance_ref['id'])
lease_ip(address)
lease_ip(address2)
lease_ip(address3)
+ self.context.project = self.projects[i]
self.assertFalse(is_allocated_in_project(address,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address2,
@@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase):
for instance_id in instance_ids:
db.instance_destroy(None, instance_id)
release_ip(first)
- self.network.deallocate_fixed_ip(self.context, first)
+ self._deallocate_address(0, first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
@@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase):
addresses = []
instance_ids = []
for i in range(num_available_ips):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(0)
instance_ids.append(instance_ref['id'])
address = self._create_address(0, instance_ref['id'])
addresses.append(address)