summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDevin Carlen <devin.carlen@gmail.com>2010-10-01 15:08:19 +0000
committerTarmac <>2010-10-01 15:08:19 +0000
commit4d13a8554459638387d772a23fffe6aaaab3348d (patch)
treef8299aef65d76d3e003fe0c42d94c9df66eb5ee9
parentc9cb22f87561fad4ba57865d8a614ca024393f13 (diff)
parent4b3d4eb51a5d927a8eecdca550e04fc699443d21 (diff)
downloadnova-4d13a8554459638387d772a23fffe6aaaab3348d.tar.gz
nova-4d13a8554459638387d772a23fffe6aaaab3348d.tar.xz
nova-4d13a8554459638387d772a23fffe6aaaab3348d.zip
Refactor sqlalchemy api to perform contextual authorization.
All database calls now examine the context object for information about what kind of user is accessing the data. If an administrator is accessing, full privileges are granted. If a normal user is accessing, then checks are made to ensure that the user does indeed have the rights to the data. Also refactored NovaBase and removed several methods since they would have to be changed when we move away from sqlalchemy models and begin using sqlalchemy table definitions.
-rw-r--r--nova/api/context.py (renamed from nova/api/ec2/context.py)13
-rw-r--r--nova/api/ec2/__init__.py8
-rw-r--r--nova/db/api.py5
-rw-r--r--nova/db/sqlalchemy/api.py721
-rw-r--r--nova/db/sqlalchemy/models.py146
-rw-r--r--nova/network/manager.py4
-rw-r--r--nova/tests/compute_unittest.py6
-rw-r--r--nova/tests/network_unittest.py50
8 files changed, 606 insertions, 347 deletions
diff --git a/nova/api/ec2/context.py b/nova/api/context.py
index c53ba98d9..b66cfe468 100644
--- a/nova/api/ec2/context.py
+++ b/nova/api/context.py
@@ -31,3 +31,16 @@ class APIRequestContext(object):
[random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
for x in xrange(20)]
)
+ if user:
+ self.is_admin = user.is_admin()
+ else:
+ self.is_admin = False
+ self.read_deleted = False
+
+
+def get_admin_context(user=None, read_deleted=False):
+ context_ref = APIRequestContext(user=user, project=None)
+ context_ref.is_admin = True
+ context_ref.read_deleted = read_deleted
+ return context_ref
+
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 7a958f841..6b538a7f1 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -27,8 +27,8 @@ import webob.exc
from nova import exception
from nova import flags
from nova import wsgi
+from nova.api import context
from nova.api.ec2 import apirequest
-from nova.api.ec2 import context
from nova.api.ec2 import admin
from nova.api.ec2 import cloud
from nova.auth import manager
@@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware):
return True
if 'none' in roles:
return False
- return any(context.project.has_role(context.user.id, role)
+ return any(context.project.has_role(context.user.id, role)
for role in roles)
-
+
class Executor(wsgi.Application):
"""Execute an EC2 API request.
- Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
+ Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
'ec2.action_args' (all variables in WSGI environ.) Returns an XML
response, or a 400 upon failure.
"""
diff --git a/nova/db/api.py b/nova/db/api.py
index 4aea0e6a4..5c935b561 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -175,11 +175,6 @@ def floating_ip_get_by_address(context, address):
return IMPL.floating_ip_get_by_address(context, address)
-def floating_ip_get_instance(context, address):
- """Get an instance for a floating ip by address."""
- return IMPL.floating_ip_get_instance(context, address)
-
-
####################
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 4aa3c693d..7f72f66b9 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -19,6 +19,8 @@
Implementation of SQLAlchemy backend
"""
+import warnings
+
from nova import db
from nova import exception
from nova import flags
@@ -27,39 +29,108 @@ from nova.db.sqlalchemy import models
from nova.db.sqlalchemy.session import get_session
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
-from sqlalchemy.orm import joinedload_all
-from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm import joinedload, joinedload_all
from sqlalchemy.sql import exists, func
FLAGS = flags.FLAGS
-# NOTE(vish): disabling docstring pylint because the docstrings are
-# in the interface definition
-# pylint: disable-msg=C0111
-def _deleted(context):
- """Calculates whether to include deleted objects based on context.
+def is_admin_context(context):
+ """Indicates if the request context is an administrator."""
+ if not context:
+ warnings.warn('Use of empty request context is deprecated',
+ DeprecationWarning)
+ return True
+ return context.is_admin
+
+
+def is_user_context(context):
+ """Indicates if the request context is a normal user."""
+ if not context:
+ return False
+ if not context.user or not context.project:
+ return False
+ return True
+
+
+def authorize_project_context(context, project_id):
+ """Ensures that the request context has permission to access the
+ given project.
+ """
+ if is_user_context(context):
+ if not context.project:
+ raise exception.NotAuthorized()
+ elif context.project.id != project_id:
+ raise exception.NotAuthorized()
- Currently just looks for a flag called deleted in the context dict.
+
+def authorize_user_context(context, user_id):
+ """Ensures that the request context has permission to access the
+ given user.
"""
- if not hasattr(context, 'get'):
+ if is_user_context(context):
+ if not context.user:
+ raise exception.NotAuthorized()
+ elif context.user.id != user_id:
+ raise exception.NotAuthorized()
+
+
+def can_read_deleted(context):
+ """Indicates if the context has access to deleted objects."""
+ if not context:
return False
- return context.get('deleted', False)
+ return context.read_deleted
-###################
+def require_admin_context(f):
+ """Decorator used to indicate that the method requires an
+ administrator context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+def require_context(f):
+ """Decorator used to indicate that the method requires either
+ an administrator or normal user context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]) and not is_user_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+###################
+@require_admin_context
def service_destroy(context, service_id):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = service_get(context, service_id, session=session)
service_ref.delete(session=session)
-def service_get(_context, service_id):
- return models.Service.find(service_id)
+
+@require_admin_context
+def service_get(context, service_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Service
+ ).filter_by(id=service_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No service for id %s' % service_id)
+
+ return result
+@require_admin_context
def service_get_all_by_topic(context, topic):
session = get_session()
return session.query(models.Service
@@ -69,7 +140,8 @@ def service_get_all_by_topic(context, topic):
).all()
-def _service_get_all_topic_subquery(_context, session, topic, subq, label):
+@require_admin_context
+def _service_get_all_topic_subquery(context, session, topic, subq, label):
sort_value = getattr(subq.c, label)
return session.query(models.Service, func.coalesce(sort_value, 0)
).filter_by(topic=topic
@@ -80,6 +152,7 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label):
).all()
+@require_admin_context
def service_get_all_compute_sorted(context):
session = get_session()
with session.begin():
@@ -104,6 +177,7 @@ def service_get_all_compute_sorted(context):
label)
+@require_admin_context
def service_get_all_network_sorted(context):
session = get_session()
with session.begin():
@@ -121,6 +195,7 @@ def service_get_all_network_sorted(context):
label)
+@require_admin_context
def service_get_all_volume_sorted(context):
session = get_session()
with session.begin():
@@ -138,11 +213,22 @@ def service_get_all_volume_sorted(context):
label)
-def service_get_by_args(_context, host, binary):
- return models.Service.find_by_args(host, binary)
+@require_admin_context
+def service_get_by_args(context, host, binary):
+ session = get_session()
+ result = session.query(models.Service
+ ).filter_by(host=host
+ ).filter_by(binary=binary
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No service for %s, %s' % (host, binary))
+
+ return result
-def service_create(_context, values):
+@require_admin_context
+def service_create(context, values):
service_ref = models.Service()
for (key, value) in values.iteritems():
service_ref[key] = value
@@ -150,10 +236,11 @@ def service_create(_context, values):
return service_ref
-def service_update(_context, service_id, values):
+@require_admin_context
+def service_update(context, service_id, values):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = session_get(context, service_id, session=session)
for (key, value) in values.iteritems():
service_ref[key] = value
service_ref.save(session=session)
@@ -162,7 +249,9 @@ def service_update(_context, service_id, values):
###################
-def floating_ip_allocate_address(_context, host, project_id):
+@require_context
+def floating_ip_allocate_address(context, host, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
with session.begin():
floating_ip_ref = session.query(models.FloatingIp
@@ -181,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id):
return floating_ip_ref['address']
-def floating_ip_create(_context, values):
+@require_context
+def floating_ip_create(context, values):
floating_ip_ref = models.FloatingIp()
for (key, value) in values.iteritems():
floating_ip_ref[key] = value
@@ -189,7 +279,9 @@ def floating_ip_create(_context, values):
return floating_ip_ref['address']
-def floating_ip_count_by_project(_context, project_id):
+@require_context
+def floating_ip_count_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
).filter_by(project_id=project_id
@@ -197,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id):
).count()
-def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address):
+@require_context
+def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(floating_address,
- session=session)
- fixed_ip_ref = models.FixedIp.find_by_str(fixed_address,
- session=session)
+ # TODO(devcamcar): How to ensure floating_id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ floating_address,
+ session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ fixed_address,
+ session=session)
floating_ip_ref.fixed_ip = fixed_ip_ref
floating_ip_ref.save(session=session)
-def floating_ip_deallocate(_context, address):
+@require_context
+def floating_ip_deallocate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): How to ensure floating id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
floating_ip_ref['project_id'] = None
floating_ip_ref.save(session=session)
-def floating_ip_destroy(_context, address):
+@require_context
+def floating_ip_destroy(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ floating_ip_ref = get_floating_ip_by_address(context,
+ address,
+ session=session)
floating_ip_ref.delete(session=session)
-def floating_ip_disassociate(_context, address):
+@require_context
+def floating_ip_disassociate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ # Does get_floating_ip_by_address handle this?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref = floating_ip_ref.fixed_ip
if fixed_ip_ref:
fixed_ip_address = fixed_ip_ref['address']
@@ -240,7 +346,8 @@ def floating_ip_disassociate(_context, address):
return fixed_ip_address
-def floating_ip_get_all(_context):
+@require_admin_context
+def floating_ip_get_all(context):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -248,7 +355,8 @@ def floating_ip_get_all(_context):
).all()
-def floating_ip_get_all_by_host(_context, host):
+@require_admin_context
+def floating_ip_get_all_by_host(context, host):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -256,7 +364,10 @@ def floating_ip_get_all_by_host(_context, host):
).filter_by(deleted=False
).all()
-def floating_ip_get_all_by_project(_context, project_id):
+
+@require_context
+def floating_ip_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -264,24 +375,31 @@ def floating_ip_get_all_by_project(_context, project_id):
).filter_by(deleted=False
).all()
-def floating_ip_get_by_address(_context, address):
- return models.FloatingIp.find_by_str(address)
+@require_context
+def floating_ip_get_by_address(context, address, session=None):
+ # TODO(devcamcar): Ensure the address belongs to user.
+ if not session:
+ session = get_session()
-def floating_ip_get_instance(_context, address):
- session = get_session()
- with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
- return floating_ip_ref.fixed_ip.instance
+ result = session.query(models.FloatingIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No fixed ip for address %s' % address)
+
+ return result
###################
-def fixed_ip_associate(_context, address, instance_id):
+@require_context
+def fixed_ip_associate(context, address, instance_id):
session = get_session()
with session.begin():
+ instance = instance_get(context, instance_id, session=session)
fixed_ip_ref = session.query(models.FixedIp
).filter_by(address=address
).filter_by(deleted=False
@@ -292,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id):
# then this has concurrency issues
if not fixed_ip_ref:
raise db.NoMoreAddresses()
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.instance = instance
session.add(fixed_ip_ref)
-def fixed_ip_associate_pool(_context, network_id, instance_id):
+@require_admin_context
+def fixed_ip_associate_pool(context, network_id, instance_id):
session = get_session()
with session.begin():
network_or_none = or_(models.FixedIp.network_id == network_id,
@@ -314,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id):
if not fixed_ip_ref:
raise db.NoMoreAddresses()
if not fixed_ip_ref.network:
- fixed_ip_ref.network = models.Network.find(network_id,
- session=session)
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.network = network_get(context,
+ network_id,
+ session=session)
+ fixed_ip_ref.instance = instance_get(context,
+ instance_id,
+ session=session)
session.add(fixed_ip_ref)
return fixed_ip_ref['address']
+@require_context
def fixed_ip_create(_context, values):
fixed_ip_ref = models.FixedIp()
for (key, value) in values.iteritems():
@@ -330,14 +451,18 @@ def fixed_ip_create(_context, values):
return fixed_ip_ref['address']
-def fixed_ip_disassociate(_context, address):
+@require_context
+def fixed_ip_disassociate(context, address):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref.instance = None
fixed_ip_ref.save(session=session)
+@require_admin_context
def fixed_ip_disassociate_all_by_timeout(_context, host, time):
session = get_session()
# NOTE(vish): The nested select is because sqlite doesn't support
@@ -354,34 +479,44 @@ def fixed_ip_disassociate_all_by_timeout(_context, host, time):
return result.rowcount
-def fixed_ip_get_by_address(_context, address):
- session = get_session()
+@require_context
+def fixed_ip_get_by_address(context, address, session=None):
+ if not session:
+ session = get_session()
result = session.query(models.FixedIp
- ).options(joinedload_all('instance')
).filter_by(address=address
- ).filter_by(deleted=False
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('network')
+ ).options(joinedload('instance')
).first()
if not result:
- raise exception.NotFound("No model for address %s" % address)
+ raise exception.NotFound('No floating ip for address %s' % address)
+
+ if is_user_context(context):
+ authorize_project_context(context, result.instance.project_id)
+
return result
-def fixed_ip_get_instance(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).instance
+@require_context
+def fixed_ip_get_instance(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.instance
-def fixed_ip_get_network(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).network
+@require_admin_context
+def fixed_ip_get_network(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.network
-def fixed_ip_update(_context, address, values):
+@require_context
+def fixed_ip_update(context, address, values):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
for (key, value) in values.iteritems():
fixed_ip_ref[key] = value
fixed_ip_ref.save(session=session)
@@ -390,7 +525,8 @@ def fixed_ip_update(_context, address, values):
###################
-def instance_create(_context, values):
+@require_context
+def instance_create(context, values):
instance_ref = models.Instance()
for (key, value) in values.iteritems():
instance_ref[key] = value
@@ -399,12 +535,14 @@ def instance_create(_context, values):
with session.begin():
while instance_ref.ec2_id == None:
ec2_id = utils.generate_uid(instance_ref.__prefix__)
- if not instance_ec2_id_exists(_context, ec2_id, session=session):
+ if not instance_ec2_id_exists(context, ec2_id, session=session):
instance_ref.ec2_id = ec2_id
instance_ref.save(session=session)
return instance_ref
-def instance_data_get_for_project(_context, project_id):
+
+@require_admin_context
+def instance_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Instance.id),
func.sum(models.Instance.vcpus)
@@ -415,81 +553,130 @@ def instance_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def instance_destroy(_context, instance_id):
+@require_context
+def instance_destroy(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
instance_ref.delete(session=session)
-def instance_get(context, instance_id):
- return models.Instance.find(instance_id, deleted=_deleted(context))
+@require_context
+def instance_get(context, instance_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No instance for id %s' % instance_id)
+
+ return result
+@require_admin_context
def instance_get_all(context):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
+
+@require_admin_context
def instance_get_all_by_user(context, user_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).filter_by(user_id=user_id
).all()
+
+@require_context
def instance_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
-def instance_get_all_by_reservation(_context, reservation_id):
+@require_context
+def instance_get_all_by_reservation(context, reservation_id):
session = get_session()
- return session.query(models.Instance
- ).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(reservation_id=reservation_id
- ).filter_by(deleted=False
- ).all()
+ if is_admin_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+ elif is_user_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=False
+ ).all()
+
+@require_context
def instance_get_by_ec2_id(context, ec2_id):
session = get_session()
- instance_ref = session.query(models.Instance
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
).filter_by(ec2_id=ec2_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=False
).first()
- if not instance_ref:
+ if not result:
raise exception.NotFound('Instance %s not found' % (ec2_id))
- return instance_ref
+ return result
+@require_context
def instance_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
return session.query(exists().where(models.Instance.id==ec2_id)).one()[0]
-def instance_get_fixed_address(_context, instance_id):
+@require_context
+def instance_get_fixed_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
return instance_ref.fixed_ip['address']
-def instance_get_floating_address(_context, instance_id):
+@require_context
+def instance_get_floating_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
if not instance_ref.fixed_ip.floating_ips:
@@ -498,12 +685,14 @@ def instance_get_floating_address(_context, instance_id):
return instance_ref.fixed_ip.floating_ips[0]['address']
+@require_admin_context
def instance_is_vpn(context, instance_id):
# TODO(vish): Move this into image code somewhere
instance_ref = instance_get(context, instance_id)
return instance_ref['image_id'] == FLAGS.vpn_image_id
+@require_admin_context
def instance_set_state(context, instance_id, state, description=None):
# TODO(devcamcar): Move this out of models and into driver
from nova.compute import power_state
@@ -515,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None):
'state_description': description})
-def instance_update(_context, instance_id, values):
+@require_context
+def instance_update(context, instance_id, values):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
for (key, value) in values.iteritems():
instance_ref[key] = value
instance_ref.save(session=session)
@@ -527,7 +717,8 @@ def instance_update(_context, instance_id, values):
###################
-def key_pair_create(_context, values):
+@require_context
+def key_pair_create(context, values):
key_pair_ref = models.KeyPair()
for (key, value) in values.iteritems():
key_pair_ref[key] = value
@@ -535,16 +726,18 @@ def key_pair_create(_context, values):
return key_pair_ref
-def key_pair_destroy(_context, user_id, name):
+@require_context
+def key_pair_destroy(context, user_id, name):
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
- key_pair_ref = models.KeyPair.find_by_args(user_id,
- name,
- session=session)
+ key_pair_ref = key_pair_get(context, user_id, name, session=session)
key_pair_ref.delete(session=session)
-def key_pair_destroy_all_by_user(_context, user_id):
+@require_context
+def key_pair_destroy_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -552,11 +745,27 @@ def key_pair_destroy_all_by_user(_context, user_id):
{'id': user_id})
-def key_pair_get(_context, user_id, name):
- return models.KeyPair.find_by_args(user_id, name)
+@require_context
+def key_pair_get(context, user_id, name, session=None):
+ authorize_user_context(context, user_id)
+
+ if not session:
+ session = get_session()
+
+ result = session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(name=name
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('no keypair for user %s, name %s' %
+ (user_id, name))
+ return result
-def key_pair_get_all_by_user(_context, user_id):
+@require_context
+def key_pair_get_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
session = get_session()
return session.query(models.KeyPair
).filter_by(user_id=user_id
@@ -567,11 +776,16 @@ def key_pair_get_all_by_user(_context, user_id):
###################
-def network_count(_context):
- return models.Network.count()
+@require_admin_context
+def network_count(context):
+ session = get_session()
+ return session.query(models.Network
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def network_count_allocated_ips(_context, network_id):
+@require_admin_context
+def network_count_allocated_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -580,7 +794,8 @@ def network_count_allocated_ips(_context, network_id):
).count()
-def network_count_available_ips(_context, network_id):
+@require_admin_context
+def network_count_available_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -590,7 +805,8 @@ def network_count_available_ips(_context, network_id):
).count()
-def network_count_reserved_ips(_context, network_id):
+@require_admin_context
+def network_count_reserved_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -599,7 +815,8 @@ def network_count_reserved_ips(_context, network_id):
).count()
-def network_create(_context, values):
+@require_admin_context
+def network_create(context, values):
network_ref = models.Network()
for (key, value) in values.iteritems():
network_ref[key] = value
@@ -607,7 +824,8 @@ def network_create(_context, values):
return network_ref
-def network_destroy(_context, network_id):
+@require_admin_context
+def network_destroy(context, network_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -625,14 +843,34 @@ def network_destroy(_context, network_id):
{'id': network_id})
-def network_get(_context, network_id):
- return models.Network.find(network_id)
+@require_context
+def network_get(context, network_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Network
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ return result
# NOTE(vish): pylint complains because of the long method name, but
# it fits with the names of the rest of the methods
# pylint: disable-msg=C0103
-def network_get_associated_fixed_ips(_context, network_id):
+@require_admin_context
+def network_get_associated_fixed_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).options(joinedload_all('instance')
@@ -642,18 +880,22 @@ def network_get_associated_fixed_ips(_context, network_id):
).all()
-def network_get_by_bridge(_context, bridge):
+@require_admin_context
+def network_get_by_bridge(context, bridge):
session = get_session()
- rv = session.query(models.Network
+ result = session.query(models.Network
).filter_by(bridge=bridge
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for bridge %s' % bridge)
- return rv
+ return result
-def network_get_index(_context, network_id):
+
+@require_admin_context
+def network_get_index(context, network_id):
session = get_session()
with session.begin():
network_index = session.query(models.NetworkIndex
@@ -661,19 +903,28 @@ def network_get_index(_context, network_id):
).filter_by(deleted=False
).with_lockmode('update'
).first()
+
if not network_index:
raise db.NoMoreNetworks()
- network_index['network'] = models.Network.find(network_id,
- session=session)
+
+ network_index['network'] = network_get(context,
+ network_id,
+ session=session)
session.add(network_index)
+
return network_index['index']
-def network_index_count(_context):
- return models.NetworkIndex.count()
+@require_admin_context
+def network_index_count(context):
+ session = get_session()
+ return session.query(models.NetworkIndex
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def network_index_create_safe(_context, values):
+@require_admin_context
+def network_index_create_safe(context, values):
network_index_ref = models.NetworkIndex()
for (key, value) in values.iteritems():
network_index_ref[key] = value
@@ -683,29 +934,32 @@ def network_index_create_safe(_context, values):
pass
-def network_set_host(_context, network_id, host_id):
+@require_admin_context
+def network_set_host(context, network_id, host_id):
session = get_session()
with session.begin():
- network = session.query(models.Network
- ).filter_by(id=network_id
- ).filter_by(deleted=False
- ).with_lockmode('update'
- ).first()
- if not network:
- raise exception.NotFound("Couldn't find network with %s" %
- network_id)
+ network_ref = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ if not network_ref:
+ raise exception.NotFound('No network for id %s' % network_id)
+
# NOTE(vish): if with_lockmode isn't supported, as in sqlite,
# then this has concurrency issues
- if not network['host']:
- network['host'] = host_id
- session.add(network)
- return network['host']
+ if not network_ref['host']:
+ network_ref['host'] = host_id
+ session.add(network_ref)
+ return network_ref['host']
-def network_update(_context, network_id, values):
+
+@require_context
+def network_update(context, network_id, values):
session = get_session()
with session.begin():
- network_ref = models.Network.find(network_id, session=session)
+ network_ref = network_get(context, network_id, session=session)
for (key, value) in values.iteritems():
network_ref[key] = value
network_ref.save(session=session)
@@ -714,15 +968,18 @@ def network_update(_context, network_id, values):
###################
-def project_get_network(_context, project_id):
+@require_context
+def project_get_network(context, project_id):
session = get_session()
- rv = session.query(models.Network
+ result= session.query(models.Network
).filter_by(project_id=project_id
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for project: %s' % project_id)
- return rv
+
+ return result
###################
@@ -732,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id):
# FIXME(ja): this should be servername?
return "%s.%s" % (topic, physical_node_id)
+
###################
-def export_device_count(_context):
- return models.ExportDevice.count()
+@require_admin_context
+def export_device_count(context):
+ session = get_session()
+ return session.query(models.ExportDevice
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def export_device_create(_context, values):
+@require_admin_context
+def export_device_create(context, values):
export_device_ref = models.ExportDevice()
for (key, value) in values.iteritems():
export_device_ref[key] = value
@@ -773,7 +1036,23 @@ def auth_create_token(_context, token):
###################
-def quota_create(_context, values):
+@require_admin_context
+def quota_get(context, project_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Quota
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No quota for project_id %s' % project_id)
+
+ return result
+
+
+@require_admin_context
+def quota_create(context, values):
quota_ref = models.Quota()
for (key, value) in values.iteritems():
quota_ref[key] = value
@@ -781,30 +1060,29 @@ def quota_create(_context, values):
return quota_ref
-def quota_get(_context, project_id):
- return models.Quota.find_by_str(project_id)
-
-
-def quota_update(_context, project_id, values):
+@require_admin_context
+def quota_update(context, project_id, values):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
for (key, value) in values.iteritems():
quota_ref[key] = value
quota_ref.save(session=session)
-def quota_destroy(_context, project_id):
+@require_admin_context
+def quota_destroy(context, project_id):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
quota_ref.delete(session=session)
###################
-def volume_allocate_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_allocate_shelf_and_blade(context, volume_id):
session = get_session()
with session.begin():
export_device = session.query(models.ExportDevice
@@ -821,19 +1099,20 @@ def volume_allocate_shelf_and_blade(_context, volume_id):
return (export_device.shelf_id, export_device.blade_id)
-def volume_attached(_context, volume_id, instance_id, mountpoint):
+@require_admin_context
+def volume_attached(context, volume_id, instance_id, mountpoint):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'in-use'
volume_ref['mountpoint'] = mountpoint
volume_ref['attach_status'] = 'attached'
- volume_ref.instance = models.Instance.find(instance_id,
- session=session)
+ volume_ref.instance = instance_get(context, instance_id, session=session)
volume_ref.save(session=session)
-def volume_create(_context, values):
+@require_context
+def volume_create(context, values):
volume_ref = models.Volume()
for (key, value) in values.iteritems():
volume_ref[key] = value
@@ -842,13 +1121,14 @@ def volume_create(_context, values):
with session.begin():
while volume_ref.ec2_id == None:
ec2_id = utils.generate_uid(volume_ref.__prefix__)
- if not volume_ec2_id_exists(_context, ec2_id, session=session):
+ if not volume_ec2_id_exists(context, ec2_id, session=session):
volume_ref.ec2_id = ec2_id
volume_ref.save(session=session)
return volume_ref
-def volume_data_get_for_project(_context, project_id):
+@require_admin_context
+def volume_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Volume.id),
func.sum(models.Volume.size)
@@ -859,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def volume_destroy(_context, volume_id):
+@require_admin_context
+def volume_destroy(context, volume_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -870,10 +1151,11 @@ def volume_destroy(_context, volume_id):
{'id': volume_id})
-def volume_detached(_context, volume_id):
+@require_admin_context
+def volume_detached(context, volume_id):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'available'
volume_ref['mountpoint'] = None
volume_ref['attach_status'] = 'detached'
@@ -881,60 +1163,113 @@ def volume_detached(_context, volume_id):
volume_ref.save(session=session)
-def volume_get(context, volume_id):
- return models.Volume.find(volume_id, deleted=_deleted(context))
+@require_context
+def volume_get(context, volume_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No volume for id %s' % volume_id)
-def volume_get_all(context):
- return models.Volume.all(deleted=_deleted(context))
+ return result
+@require_admin_context
+def volume_get_all(context):
+ return session.query(models.Volume
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+@require_context
def volume_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
session = get_session()
return session.query(models.Volume
).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
+@require_context
def volume_get_by_ec2_id(context, ec2_id):
session = get_session()
- volume_ref = session.query(models.Volume
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
).filter_by(ec2_id=ec2_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=False
).first()
- if not volume_ref:
- raise exception.NotFound('Volume %s not found' % (ec2_id))
+ else:
+ raise exception.NotAuthorized()
- return volume_ref
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result
+@require_context
def volume_ec2_id_exists(context, ec2_id, session=None):
if not session:
session = get_session()
- return session.query(exists().where(models.Volume.id==ec2_id)).one()[0]
+ return session.query(exists(
+ ).where(models.Volume.id==ec2_id)
+ ).one()[0]
-def volume_get_instance(_context, volume_id):
+
+@require_admin_context
+def volume_get_instance(context, volume_id):
session = get_session()
- with session.begin():
- return models.Volume.find(volume_id, session=session).instance
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result.instance
-def volume_get_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_get_shelf_and_blade(context, volume_id):
session = get_session()
- export_device = session.query(models.ExportDevice
- ).filter_by(volume_id=volume_id
- ).first()
- if not export_device:
- raise exception.NotFound()
- return (export_device.shelf_id, export_device.blade_id)
+ result = session.query(models.ExportDevice
+ ).filter_by(volume_id=volume_id
+ ).first()
+ if not result:
+ raise exception.NotFound('No export device found for volume %s' %
+ volume_id)
+
+ return (result.shelf_id, result.blade_id)
-def volume_update(_context, volume_id, values):
+@require_context
+def volume_update(context, volume_id, values):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
for (key, value) in values.iteritems():
volume_ref[key] = value
volume_ref.save(session=session)
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index 6cb377476..1837a7584 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -50,44 +50,6 @@ class NovaBase(object):
deleted_at = Column(DateTime)
deleted = Column(Boolean, default=False)
- @classmethod
- def all(cls, session=None, deleted=False):
- """Get all objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).all()
-
- @classmethod
- def count(cls, session=None, deleted=False):
- """Count objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).count()
-
- @classmethod
- def find(cls, obj_id, session=None, deleted=False):
- """Find object by id"""
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(id=obj_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for id %s" % obj_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- """Find object by str_id"""
- int_id = int(str_id.rpartition('-')[2])
- return cls.find(int_id, session=session, deleted=deleted)
-
@property
def str_id(self):
"""Get string id of object (generally prefix + '-' + id)"""
@@ -176,21 +138,6 @@ class Service(BASE, NovaBase):
report_count = Column(Integer, nullable=False, default=0)
disabled = Column(Boolean, default=False)
- @classmethod
- def find_by_args(cls, host, binary, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(host=host
- ).filter_by(binary=binary
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for %s, %s" % (host,
- binary))
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class Instance(BASE, NovaBase):
"""Represents a guest vm"""
@@ -284,7 +231,11 @@ class Volume(BASE, NovaBase):
size = Column(Integer)
availability_zone = Column(String(255)) # TODO(vish): foreign key?
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('volumes'))
+ instance = relationship(Instance,
+ backref=backref('volumes'),
+ foreign_keys=instance_id,
+ primaryjoin='and_(Volume.instance_id==Instance.id,'
+ 'Volume.deleted==False)')
mountpoint = Column(String(255))
attach_time = Column(String(255)) # TODO(vish): datetime
status = Column(String(255)) # TODO(vish): enum?
@@ -315,18 +266,6 @@ class Quota(BASE, NovaBase):
def str_id(self):
return self.project_id
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(project_id=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for project_id %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
class ExportDevice(BASE, NovaBase):
"""Represates a shelf and blade that a volume can be exported on"""
@@ -335,8 +274,11 @@ class ExportDevice(BASE, NovaBase):
shelf_id = Column(Integer)
blade_id = Column(Integer)
volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True)
- volume = relationship(Volume, backref=backref('export_device',
- uselist=False))
+ volume = relationship(Volume,
+ backref=backref('export_device', uselist=False),
+ foreign_keys=volume_id,
+ primaryjoin='and_(ExportDevice.volume_id==Volume.id,'
+ 'ExportDevice.deleted==False)')
class KeyPair(BASE, NovaBase):
@@ -354,26 +296,6 @@ class KeyPair(BASE, NovaBase):
def str_id(self):
return '%s.%s' % (self.user_id, self.name)
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- user_id, _sep, name = str_id.partition('.')
- return cls.find_by_str(user_id, name, session, deleted)
-
- @classmethod
- def find_by_args(cls, user_id, name, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(user_id=user_id
- ).filter_by(name=name
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for user %s, name %s" %
- (user_id, name))
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class Network(BASE, NovaBase):
"""Represents a network"""
@@ -409,8 +331,12 @@ class NetworkIndex(BASE, NovaBase):
id = Column(Integer, primary_key=True)
index = Column(Integer, unique=True)
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
- network = relationship(Network, backref=backref('network_index',
- uselist=False))
+ network = relationship(Network,
+ backref=backref('network_index', uselist=False),
+ foreign_keys=network_id,
+ primaryjoin='and_(NetworkIndex.network_id==Network.id,'
+ 'NetworkIndex.deleted==False)')
+
class AuthToken(BASE, NovaBase):
"""Represents an authorization token for all API transactions. Fields
@@ -434,8 +360,11 @@ class FixedIp(BASE, NovaBase):
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
network = relationship(Network, backref=backref('fixed_ips'))
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('fixed_ip',
- uselist=False))
+ instance = relationship(Instance,
+ backref=backref('fixed_ip', uselist=False),
+ foreign_keys=instance_id,
+ primaryjoin='and_(FixedIp.instance_id==Instance.id,'
+ 'FixedIp.deleted==False)')
allocated = Column(Boolean, default=False)
leased = Column(Boolean, default=False)
reserved = Column(Boolean, default=False)
@@ -444,19 +373,6 @@ class FixedIp(BASE, NovaBase):
def str_id(self):
return self.address
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
class FloatingIp(BASE, NovaBase):
"""Represents a floating ip that dynamically forwards to a fixed ip"""
@@ -464,24 +380,14 @@ class FloatingIp(BASE, NovaBase):
id = Column(Integer, primary_key=True)
address = Column(String(255))
fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True)
- fixed_ip = relationship(FixedIp, backref=backref('floating_ips'))
-
+ fixed_ip = relationship(FixedIp,
+ backref=backref('floating_ips'),
+ foreign_keys=fixed_ip_id,
+ primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,'
+ 'FloatingIp.deleted==False)')
project_id = Column(String(255))
host = Column(String(255)) # , ForeignKey('hosts.id'))
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
def register_models():
"""Register Models and create metadata"""
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 1325c300b..ef1d01138 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -92,7 +92,7 @@ class NetworkManager(manager.Manager):
# TODO(vish): can we minimize db access by just getting the
# id here instead of the ref?
network_id = network_ref['id']
- host = self.db.network_set_host(context,
+ host = self.db.network_set_host(None,
network_id,
self.host)
self._on_set_network_host(context, network_id)
@@ -249,7 +249,7 @@ class VlanManager(NetworkManager):
address = network_ref['vpn_private_address']
self.db.fixed_ip_associate(context, address, instance_id)
else:
- address = self.db.fixed_ip_associate_pool(context,
+ address = self.db.fixed_ip_associate_pool(None,
network_ref['id'],
instance_id)
self.db.fixed_ip_update(context, address, {'allocated': True})
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index f5c0f1c09..1e2bb113b 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -30,7 +30,7 @@ from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
-
+from nova.api import context
FLAGS = flags.FLAGS
@@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase):
self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow()
yield self.compute.terminate_instance(self.context, instance_id)
- instance_ref = db.instance_get({'deleted': True}, instance_id)
+ self.context = context.get_admin_context(user=self.user,
+ read_deleted=True)
+ instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] < terminate)
self.assert_(instance_ref['deleted_at'] > terminate)
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index da65b50a2..5370966d2 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -56,12 +56,12 @@ class NetworkTestCase(test.TrialTestCase):
'netuser',
name))
# create the necessary network data for the project
- self.network.set_network_host(self.context, self.projects[i].id)
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ user_context = context.APIRequestContext(project=self.projects[i],
+ user=self.user)
+ self.network.set_network_host(user_context, self.projects[i].id)
+ instance_ref = self._create_instance(0)
self.instance_id = instance_ref['id']
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ instance_ref = self._create_instance(1)
self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
@@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase):
self.manager.delete_project(project)
self.manager.delete_user(self.user)
+ def _create_instance(self, project_num, mac=None):
+ if not mac:
+ mac = utils.generate_mac()
+ project = self.projects[project_num]
+ self.context.project = project
+ return db.instance_create(self.context,
+ {'project_id': project.id,
+ 'mac_address': mac})
+
def _create_address(self, project_num, instance_id=None):
"""Create an address in given project num"""
if instance_id is None:
@@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase):
self.context.project = self.projects[project_num]
return self.network.allocate_fixed_ip(self.context, instance_id)
+ def _deallocate_address(self, project_num, address):
+ self.context.project = self.projects[project_num]
+ self.network.deallocate_fixed_ip(self.context, address)
+
+
def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
# TODO(vish): better way of adding floating ips
+ self.context.project = self.projects[0]
pubnet = IPy.IP(flags.FLAGS.public_range)
address = str(pubnet[0])
try:
@@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase):
address = self._create_address(0)
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
lease_ip(address)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
@@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(address)
lease_ip(address2)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
- self.network.deallocate_fixed_ip(self.context, address2)
+ self._deallocate_address(1, address2)
release_ip(address2)
self.assertFalse(is_allocated_in_project(address2,
self.projects[1].id))
@@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(first)
instance_ids = []
for i in range(1, 5):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address2 = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address3 = self._create_address(i, instance_ref['id'])
lease_ip(address)
lease_ip(address2)
lease_ip(address3)
+ self.context.project = self.projects[i]
self.assertFalse(is_allocated_in_project(address,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address2,
@@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase):
for instance_id in instance_ids:
db.instance_destroy(None, instance_id)
release_ip(first)
- self.network.deallocate_fixed_ip(self.context, first)
+ self._deallocate_address(0, first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
@@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase):
addresses = []
instance_ids = []
for i in range(num_available_ips):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(0)
instance_ids.append(instance_ref['id'])
address = self._create_address(0, instance_ref['id'])
addresses.append(address)