diff options
-rw-r--r-- | nova/api/context.py (renamed from nova/api/ec2/context.py) | 13 | ||||
-rw-r--r-- | nova/api/ec2/__init__.py | 8 | ||||
-rw-r--r-- | nova/db/api.py | 5 | ||||
-rw-r--r-- | nova/db/sqlalchemy/api.py | 721 | ||||
-rw-r--r-- | nova/db/sqlalchemy/models.py | 146 | ||||
-rw-r--r-- | nova/network/manager.py | 4 | ||||
-rw-r--r-- | nova/tests/compute_unittest.py | 6 | ||||
-rw-r--r-- | nova/tests/network_unittest.py | 50 |
8 files changed, 606 insertions, 347 deletions
diff --git a/nova/api/ec2/context.py b/nova/api/context.py index c53ba98d9..b66cfe468 100644 --- a/nova/api/ec2/context.py +++ b/nova/api/context.py @@ -31,3 +31,16 @@ class APIRequestContext(object): [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') for x in xrange(20)] ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -27,8 +27,8 @@ import webob.exc from nova import exception from nova import flags from nova import wsgi +from nova.api import context from nova.api.ec2 import apirequest -from nova.api.ec2 import context from nova.api.ec2 import admin from nova.api.ec2 import cloud from nova.auth import manager @@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user.id, role) for role in roles) - + class Executor(wsgi.Application): """Execute an EC2 API request. - Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and 'ec2.action_args' (all variables in WSGI environ.) Returns an XML response, or a 400 upon failure. """ diff --git a/nova/db/api.py b/nova/db/api.py index 4aea0e6a4..5c935b561 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,11 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 4aa3c693d..7f72f66b9 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,6 +19,8 @@ Implementation of SQLAlchemy backend """ +import warnings + from nova import db from nova import exception from nova import flags @@ -27,39 +29,108 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all -from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. +def is_admin_context(context): + """Indicates if the request context is an administrator.""" + if not context: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) + return True + return context.is_admin + + +def is_user_context(context): + """Indicates if the request context is a normal user.""" + if not context: + return False + if not context.user or not context.project: + return False + return True + + +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. + """ + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() - Currently just looks for a flag called deleted in the context dict. + +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. """ - if not hasattr(context, 'get'): + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def can_read_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: return False - return context.get('deleted', False) + return context.read_deleted -################### +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + +################### +@require_admin_context def service_destroy(context, service_id): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) + +@require_admin_context +def service_get(context, service_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result +@require_admin_context def service_get_all_by_topic(context, topic): session = get_session() return session.query(models.Service @@ -69,7 +140,8 @@ def service_get_all_by_topic(context, topic): ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +@require_admin_context +def _service_get_all_topic_subquery(context, session, topic, subq, label): sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -80,6 +152,7 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label): ).all() +@require_admin_context def service_get_all_compute_sorted(context): session = get_session() with session.begin(): @@ -104,6 +177,7 @@ def service_get_all_compute_sorted(context): label) +@require_admin_context def service_get_all_network_sorted(context): session = get_session() with session.begin(): @@ -121,6 +195,7 @@ def service_get_all_network_sorted(context): label) +@require_admin_context def service_get_all_volume_sorted(context): session = get_session() with session.begin(): @@ -138,11 +213,22 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +@require_admin_context +def service_get_by_args(context, host, binary): + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result -def service_create(_context, values): +@require_admin_context +def service_create(context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -150,10 +236,11 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +@require_admin_context +def service_update(context, service_id, values): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = session_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -162,7 +249,9 @@ def service_update(_context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +@require_context +def floating_ip_allocate_address(context, host, project_id): + authorize_project_context(context, project_id) session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -181,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +@require_context +def floating_ip_create(context, values): floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -189,7 +279,9 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +@require_context +def floating_ip_count_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -197,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +@require_context +def floating_ip_deallocate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) -def floating_ip_destroy(_context, address): +@require_context +def floating_ip_destroy(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +@require_context +def floating_ip_disassociate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -240,7 +346,8 @@ def floating_ip_disassociate(_context, address): return fixed_ip_address -def floating_ip_get_all(_context): +@require_admin_context +def floating_ip_get_all(context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -248,7 +355,8 @@ def floating_ip_get_all(_context): ).all() -def floating_ip_get_all_by_host(_context, host): +@require_admin_context +def floating_ip_get_all_by_host(context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -256,7 +364,10 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_all_by_project(_context, project_id): + +@require_context +def floating_ip_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -264,24 +375,31 @@ def floating_ip_get_all_by_project(_context, project_id): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) +@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not session: + session = get_session() -def floating_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) + + return result ################### -def fixed_ip_associate(_context, address, instance_id): +@require_context +def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -292,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -314,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -330,14 +451,18 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -def fixed_ip_disassociate(_context, address): +@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) +@require_admin_context def fixed_ip_disassociate_all_by_timeout(_context, host, time): session = get_session() # NOTE(vish): The nested select is because sqlite doesn't support @@ -354,34 +479,44 @@ def fixed_ip_disassociate_all_by_timeout(_context, host, time): return result.rowcount -def fixed_ip_get_by_address(_context, address): - session = get_session() +@require_context +def fixed_ip_get_by_address(context, address, session=None): + if not session: + session = get_session() result = session.query(models.FixedIp - ).options(joinedload_all('instance') ).filter_by(address=address - ).filter_by(deleted=False + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') ).first() if not result: - raise exception.NotFound("No model for address %s" % address) + raise exception.NotFound('No floating ip for address %s' % address) + + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) + return result -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance +@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): +@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -390,7 +525,8 @@ def fixed_ip_update(_context, address, values): ################### -def instance_create(_context, values): +@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -399,12 +535,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): + +@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -415,81 +553,130 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): +@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +@require_context +def instance_get(context, instance_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +@require_admin_context def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() + +@require_admin_context def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() + +@require_context def instance_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() -def instance_get_all_by_reservation(_context, reservation_id): +@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=can_read_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() + +@require_context def instance_get_by_ec2_id(context, ec2_id): session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=False ).first() - if not instance_ref: + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): +@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): +@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -498,12 +685,14 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +@require_admin_context def instance_is_vpn(context, instance_id): # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +@require_admin_context def instance_set_state(context, instance_id, state, description=None): # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state @@ -515,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -527,7 +717,8 @@ def instance_update(_context, instance_id, values): ################### -def key_pair_create(_context, values): +@require_context +def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): key_pair_ref[key] = value @@ -535,16 +726,18 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +@require_context +def key_pair_destroy(context, user_id, name): + authorize_user_context(context, user_id) session = get_session() with session.begin(): - key_pair_ref = models.KeyPair.find_by_args(user_id, - name, - session=session) + key_pair_ref = key_pair_get(context, user_id, name, session=session) key_pair_ref.delete(session=session) -def key_pair_destroy_all_by_user(_context, user_id): +@require_context +def key_pair_destroy_all_by_user(context, user_id): + authorize_user_context(context, user_id) session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -552,11 +745,27 @@ def key_pair_destroy_all_by_user(_context, user_id): {'id': user_id}) -def key_pair_get(_context, user_id, name): - return models.KeyPair.find_by_args(user_id, name) +@require_context +def key_pair_get(context, user_id, name, session=None): + authorize_user_context(context, user_id) + + if not session: + session = get_session() + + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result -def key_pair_get_all_by_user(_context, user_id): +@require_context +def key_pair_get_all_by_user(context, user_id): + authorize_user_context(context, user_id) session = get_session() return session.query(models.KeyPair ).filter_by(user_id=user_id @@ -567,11 +776,16 @@ def key_pair_get_all_by_user(_context, user_id): ################### -def network_count(_context): - return models.Network.count() +@require_admin_context +def network_count(context): + session = get_session() + return session.query(models.Network + ).filter_by(deleted=can_read_deleted(context) + ).count() -def network_count_allocated_ips(_context, network_id): +@require_admin_context +def network_count_allocated_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -580,7 +794,8 @@ def network_count_allocated_ips(_context, network_id): ).count() -def network_count_available_ips(_context, network_id): +@require_admin_context +def network_count_available_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -590,7 +805,8 @@ def network_count_available_ips(_context, network_id): ).count() -def network_count_reserved_ips(_context, network_id): +@require_admin_context +def network_count_reserved_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -599,7 +815,8 @@ def network_count_reserved_ips(_context, network_id): ).count() -def network_create(_context, values): +@require_admin_context +def network_create(context, values): network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -607,7 +824,8 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +@require_admin_context +def network_destroy(context, network_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -625,14 +843,34 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +@require_admin_context +def network_get_associated_fixed_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -642,18 +880,22 @@ def network_get_associated_fixed_ips(_context, network_id): ).all() -def network_get_by_bridge(_context, bridge): +@require_admin_context +def network_get_by_bridge(context, bridge): session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + return result -def network_get_index(_context, network_id): + +@require_admin_context +def network_get_index(context, network_id): session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -661,19 +903,28 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +@require_admin_context +def network_index_count(context): + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=can_read_deleted(context) + ).count() -def network_index_create_safe(_context, values): +@require_admin_context +def network_index_create_safe(context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -683,29 +934,32 @@ def network_index_create_safe(_context, values): pass -def network_set_host(_context, network_id, host_id): +@require_admin_context +def network_set_host(context, network_id, host_id): session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + return network_ref['host'] -def network_update(_context, network_id, values): + +@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -714,15 +968,18 @@ def network_update(_context, network_id, values): ################### -def project_get_network(_context, project_id): +@require_context +def project_get_network(context, project_id): session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -732,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +@require_admin_context +def export_device_count(context): + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=can_read_deleted(context) + ).count() -def export_device_create(_context, values): +@require_admin_context +def export_device_create(context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -773,7 +1036,23 @@ def auth_create_token(_context, token): ################### -def quota_create(_context, values): +@require_admin_context +def quota_get(context, project_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +@require_admin_context +def quota_create(context, values): quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -781,30 +1060,29 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - - -def quota_update(_context, project_id, values): +@require_admin_context +def quota_update(context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +@require_admin_context +def quota_destroy(context, project_id): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### -def volume_allocate_shelf_and_blade(_context, volume_id): +@require_admin_context +def volume_allocate_shelf_and_blade(context, volume_id): session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -821,19 +1099,20 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +@require_admin_context +def volume_attached(context, volume_id, instance_id, mountpoint): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) -def volume_create(_context, values): +@require_context +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value @@ -842,13 +1121,14 @@ def volume_create(_context, values): with session.begin(): while volume_ref.ec2_id == None: ec2_id = utils.generate_uid(volume_ref.__prefix__) - if not volume_ec2_id_exists(_context, ec2_id, session=session): + if not volume_ec2_id_exists(context, ec2_id, session=session): volume_ref.ec2_id = ec2_id volume_ref.save(session=session) return volume_ref -def volume_data_get_for_project(_context, project_id): +@require_admin_context +def volume_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -859,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +@require_admin_context +def volume_destroy(context, volume_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -870,10 +1151,11 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +@require_admin_context +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -881,60 +1163,113 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) +@require_context +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) -def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) + return result +@require_admin_context +def volume_get_all(context): + return session.query(models.Volume + ).filter_by(deleted=can_read_deleted(context) + ).all() + +@require_context def volume_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) + session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() +@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() - volume_ref = session.query(models.Volume + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=False ).first() - if not volume_ref: - raise exception.NotFound('Volume %s not found' % (ec2_id)) + else: + raise exception.NotAuthorized() - return volume_ref + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result +@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Volume.id==ec2_id)).one()[0] + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -def volume_get_instance(_context, volume_id): + +@require_admin_context +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): +@require_admin_context +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: - raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + if not result: + raise exception.NotFound('No export device found for volume %s' % + volume_id) + + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): +@require_context +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 6cb377476..1837a7584 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -50,44 +50,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" @@ -176,21 +138,6 @@ class Service(BASE, NovaBase): report_count = Column(Integer, nullable=False, default=0) disabled = Column(Boolean, default=False) - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Instance(BASE, NovaBase): """Represents a guest vm""" @@ -284,7 +231,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -315,18 +266,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" @@ -335,8 +274,11 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') class KeyPair(BASE, NovaBase): @@ -354,26 +296,6 @@ class KeyPair(BASE, NovaBase): def str_id(self): return '%s.%s' % (self.user_id, self.name) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - user_id, _sep, name = str_id.partition('.') - return cls.find_by_str(user_id, name, session, deleted) - - @classmethod - def find_by_args(cls, user_id, name, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for user %s, name %s" % - (user_id, name)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Network(BASE, NovaBase): """Represents a network""" @@ -409,8 +331,12 @@ class NetworkIndex(BASE, NovaBase): id = Column(Integer, primary_key=True) index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + class AuthToken(BASE, NovaBase): """Represents an authorization token for all API transactions. Fields @@ -434,8 +360,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -444,19 +373,6 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" @@ -464,24 +380,14 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" diff --git a/nova/network/manager.py b/nova/network/manager.py index 1325c300b..ef1d01138 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -92,7 +92,7 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) @@ -249,7 +249,7 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index da65b50a2..5370966d2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,12 +56,12 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: @@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase): address = self._create_address(0) self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) lease_ip(address) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) # Doesn't go away until it's dhcp released self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) @@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) |