| author | jaypipes@gmail.com | 2010-10-15 16:24:19 -0400 |
|---|---|---|
| committer | jaypipes@gmail.com | 2010-10-15 16:24:19 -0400 |
| commit | 3fdced0a19315732fec0ead200604e396f06823c (patch) | |
| tree | cf4675fdf71e8c126888432007e15d1bbf9c2360 | |
| parent | ff60af51cc2990c7b60ca97cc899f0719560bc6f (diff) | |
| parent | b70742cd442e8477d15c82a825641d934529bedf (diff) | |
| download | nova-3fdced0a19315732fec0ead200604e396f06823c.tar.gz, nova-3fdced0a19315732fec0ead200604e396f06823c.tar.xz, nova-3fdced0a19315732fec0ead200604e396f06823c.zip | |
Merge trunk
48 files changed, 823 insertions(+), 656 deletions(-)
diff --git a/bin/nova-dhcpbridge b/bin/nova-dhcpbridge
index a127ed03c..2b7a083d2 100755
--- a/bin/nova-dhcpbridge
+++ b/bin/nova-dhcpbridge
@@ -33,6 +33,7 @@ possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
 if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
     sys.path.insert(0, possible_topdir)
 
+from nova import context
 from nova import db
 from nova import flags
 from nova import rpc
@@ -52,12 +53,14 @@ def add_lease(mac, ip_address, _hostname, _interface):
     if FLAGS.fake_rabbit:
         logging.debug("leasing ip")
         network_manager = utils.import_object(FLAGS.network_manager)
-        network_manager.lease_fixed_ip(None, mac, ip_address)
+        network_manager.lease_fixed_ip(context.get_admin_context(),
+                                       mac,
+                                       ip_address)
     else:
-        rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.host),
+        rpc.cast(context.get_admin_context(),
+                 "%s.%s" % (FLAGS.network_topic, FLAGS.host),
                  {"method": "lease_fixed_ip",
-                  "args": {"context": None,
-                           "mac": mac,
+                  "args": {"mac": mac,
                            "address": ip_address}})
@@ -71,19 +74,22 @@ def del_lease(mac, ip_address, _hostname, _interface):
     if FLAGS.fake_rabbit:
         logging.debug("releasing ip")
         network_manager = utils.import_object(FLAGS.network_manager)
-        network_manager.release_fixed_ip(None, mac, ip_address)
+        network_manager.release_fixed_ip(context.get_admin_context(),
+                                         mac,
+                                         ip_address)
     else:
-        rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.host),
+        rpc.cast(context.get_admin_context(),
+                 "%s.%s" % (FLAGS.network_topic, FLAGS.host),
                  {"method": "release_fixed_ip",
-                  "args": {"context": None,
-                           "mac": mac,
+                  "args": {"mac": mac,
                            "address": ip_address}})
 
 
 def init_leases(interface):
     """Get the list of hosts for an interface."""
-    network_ref = db.network_get_by_bridge(None, interface)
-    return linux_net.get_dhcp_hosts(None, network_ref['id'])
+    ctxt = context.get_admin_context()
+    network_ref = db.network_get_by_bridge(ctxt, interface)
+    return linux_net.get_dhcp_hosts(ctxt, network_ref['id'])
 
 
 def main():
diff --git a/bin/nova-manage b/bin/nova-manage
index d36b0f53a..1c5700190 100755
--- a/bin/nova-manage
+++ b/bin/nova-manage
@@ -67,17 +67,22 @@ possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
 if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
     sys.path.insert(0, possible_topdir)
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
 from nova import quota
 from nova import utils
 from nova.auth import manager
-from nova.network import manager as network_manager
 from nova.cloudpipe import pipelib
 
 
 FLAGS = flags.FLAGS
+flags.DECLARE('fixed_range', 'nova.network.manager')
+flags.DECLARE('num_networks', 'nova.network.manager')
+flags.DECLARE('network_size', 'nova.network.manager')
+flags.DECLARE('vlan_start', 'nova.network.manager')
+flags.DECLARE('vpn_start', 'nova.network.manager')
 
 
 class VpnCommands(object):
@@ -121,7 +126,7 @@ class VpnCommands(object):
 
     def _vpn_for(self, project_id):
         """Get the VPN instance for a project ID."""
-        for instance in db.instance_get_all(None):
+        for instance in db.instance_get_all(context.get_admin_context()):
             if (instance['image_id'] == FLAGS.vpn_image_id
                 and not instance['state_description'] in
                     ['shutting_down', 'shutdown']
@@ -323,13 +328,14 @@ class ProjectCommands(object):
     def quota(self, project_id, key=None, value=None):
         """Set or display quotas for project
         arguments: project_id [key] [value]"""
+        ctxt = context.get_admin_context()
         if key:
             quo = {'project_id': project_id, key: value}
             try:
-                db.quota_update(None, project_id, quo)
+                db.quota_update(ctxt, project_id, quo)
             except exception.NotFound:
-                db.quota_create(None, quo)
-        project_quota = quota.get_quota(None, project_id)
+                db.quota_create(ctxt, quo)
+        project_quota = quota.get_quota(ctxt, project_id)
         for key, value in project_quota.iteritems():
             print '%s: %s' % (key, value)
@@ -353,23 +359,26 @@ class FloatingIpCommands(object):
         """Creates floating ips for host by range
         arguments: host ip_range"""
         for address in IPy.IP(range):
-            db.floating_ip_create(None, {'address': str(address),
-                                         'host': host})
+            db.floating_ip_create(context.get_admin_context(),
+                                  {'address': str(address),
+                                   'host': host})
 
     def delete(self, ip_range):
         """Deletes floating ips by range
         arguments: range"""
         for address in IPy.IP(ip_range):
-            db.floating_ip_destroy(None, str(address))
+            db.floating_ip_destroy(context.get_admin_context(),
+                                   str(address))
 
     def list(self, host=None):
         """Lists all floating ips (optionally by host)
         arguments: [host]"""
+        ctxt = context.get_admin_context()
         if host == None:
-            floating_ips = db.floating_ip_get_all(None)
+            floating_ips = db.floating_ip_get_all(ctxt)
         else:
-            floating_ips = db.floating_ip_get_all_by_host(None, host)
+            floating_ips = db.floating_ip_get_all_by_host(ctxt, host)
         for floating_ip in floating_ips:
             instance = None
             if floating_ip['fixed_ip']:
@@ -451,7 +460,7 @@ def main():
 
     if FLAGS.verbose:
         logging.getLogger().setLevel(logging.DEBUG)
-    
+
     script_name = argv.pop(0)
     if len(argv) < 1:
         print script_name + " category action [<args>]"
diff --git a/nova/api/cloud.py b/nova/api/cloud.py
index 57e94a17a..aa84075dc 100644
--- a/nova/api/cloud.py
+++ b/nova/api/cloud.py
@@ -29,14 +29,10 @@ FLAGS = flags.FLAGS
 
 
 def reboot(instance_id, context=None):
-    """Reboot the given instance.
-
-    #TODO(gundlach) not actually sure what context is used for by ec2 here
-    -- I think we can just remove it and use None all the time.
-    """
-    instance_ref = db.instance_get_by_internal_id(None, instance_id)
+    """Reboot the given instance."""
+    instance_ref = db.instance_get_by_internal_id(context, instance_id)
     host = instance_ref['host']
-    rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+    rpc.cast(context,
+             db.queue_get_for(context, FLAGS.compute_topic, host),
              {"method": "reboot_instance",
-              "args": {"context": None,
-                       "instance_id": instance_ref['id']}})
+              "args": {"instance_id": instance_ref['id']}})
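Two mechanical transformations repeat through nearly every hunk of this merge: a real admin context from the new `nova.context` module replaces the bare `None` (or `{}`) placeholder, and the context moves out of each message's `"args"` dict into a leading positional argument on `rpc.cast`/`rpc.call`. A minimal before/after sketch of the pattern (the `mac`/`ip_address` values are placeholders for illustration; `rpc`, `flags`, and `context` are nova's own modules as used in the diff):

```python
from nova import context
from nova import flags
from nova import rpc

FLAGS = flags.FLAGS

mac = '02:16:3e:00:00:01'  # placeholder values for illustration
ip_address = '10.0.0.5'

# Old style: no usable context, so a None rode along inside "args":
#   rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.host),
#            {"method": "lease_fixed_ip",
#             "args": {"context": None, "mac": mac, "address": ip_address}})

# New style: the context is the first argument, and the rpc layer is
# responsible for serializing it into the message envelope.
ctxt = context.get_admin_context()
rpc.cast(ctxt,
         "%s.%s" % (FLAGS.network_topic, FLAGS.host),
         {"method": "lease_fixed_ip",
          "args": {"mac": mac, "address": ip_address}})
```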
- -""" -APIRequestContext -""" - -import random - - -class APIRequestContext(object): - def __init__(self, user, project): - self.user = user - self.project = project - self.request_id = ''.join( - [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') - for x in xrange(20)] - ) - if user: - self.is_admin = user.is_admin() - else: - self.is_admin = False - self.read_deleted = False - - -def get_admin_context(user=None, read_deleted=False): - context_ref = APIRequestContext(user=user, project=None) - context_ref.is_admin = True - context_ref.read_deleted = read_deleted - return context_ref - diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 6e771f064..5735eb956 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -25,9 +25,9 @@ import webob.dec import webob.exc from nova import exception +from nova import context from nova import flags from nova import wsgi -from nova.api import context from nova.api.ec2 import apirequest from nova.api.ec2 import admin from nova.api.ec2 import cloud @@ -78,7 +78,10 @@ class Authenticate(wsgi.Middleware): raise webob.exc.HTTPForbidden() # Authenticated! - req.environ['ec2.context'] = context.APIRequestContext(user, project) + ctxt = context.RequestContext(user=user, + project=project, + remote_address=req.remote_addr) + req.environ['ec2.context'] = ctxt return self.application diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index a7693cadd..6d4f58499 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -28,6 +28,7 @@ import logging import os import time +from nova import context import IPy from nova import crypto @@ -117,9 +118,9 @@ class CloudController(object): utils.runthis("Generating root CA: %s", "sh genrootca.sh") os.chdir(start) - def _get_mpi_data(self, project_id): + def _get_mpi_data(self, context, project_id): result = {} - for instance in db.instance_get_all_by_project(None, project_id): + for instance in db.instance_get_all_by_project(context, project_id): if instance['fixed_ip']: line = '%s slots=%d' % (instance['fixed_ip']['address'], INSTANCE_TYPES[instance['instance_type']]['vcpus']) @@ -130,20 +131,21 @@ class CloudController(object): result[key] = [line] return result - def _trigger_refresh_security_group(self, security_group): + def _trigger_refresh_security_group(self, context, security_group): nodes = set([instance['host'] for instance in security_group.instances if instance['host'] is not None]) for node in nodes: - rpc.call('%s.%s' % (FLAGS.compute_topic, node), + rpc.cast(context, + '%s.%s' % (FLAGS.compute_topic, node), { "method": "refresh_security_group", - "args": { "context": None, - "security_group_id": security_group.id}}) + "args": {"security_group_id": security_group.id}}) def get_metadata(self, address): - instance_ref = db.fixed_ip_get_instance(None, address) + ctxt = context.get_admin_context() + instance_ref = db.fixed_ip_get_instance(ctxt, address) if instance_ref is None: return None - mpi = self._get_mpi_data(instance_ref['project_id']) + mpi = self._get_mpi_data(ctxt, instance_ref['project_id']) if instance_ref['key_name']: keys = { '0': { @@ -154,7 +156,7 @@ class CloudController(object): else: keys = '' hostname = instance_ref['hostname'] - floating_ip = db.instance_get_floating_address(None, + floating_ip = db.instance_get_floating_address(ctxt, instance_ref['id']) data = { 'user-data': base64.b64decode(instance_ref['user_data']), @@ -162,7 +164,7 @@ class CloudController(object): 'ami-id': instance_ref['image_id'], 'ami-launch-index': 
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py
index a7693cadd..6d4f58499 100644
--- a/nova/api/ec2/cloud.py
+++ b/nova/api/ec2/cloud.py
@@ -28,6 +28,7 @@ import logging
 import os
 import time
 
+from nova import context
 import IPy
 
 from nova import crypto
@@ -117,9 +118,9 @@ class CloudController(object):
             utils.runthis("Generating root CA: %s", "sh genrootca.sh")
             os.chdir(start)
 
-    def _get_mpi_data(self, project_id):
+    def _get_mpi_data(self, context, project_id):
         result = {}
-        for instance in db.instance_get_all_by_project(None, project_id):
+        for instance in db.instance_get_all_by_project(context, project_id):
             if instance['fixed_ip']:
                 line = '%s slots=%d' % (instance['fixed_ip']['address'],
                     INSTANCE_TYPES[instance['instance_type']]['vcpus'])
@@ -130,20 +131,21 @@ class CloudController(object):
                     result[key] = [line]
         return result
 
-    def _trigger_refresh_security_group(self, security_group):
+    def _trigger_refresh_security_group(self, context, security_group):
         nodes = set([instance['host'] for instance in security_group.instances
                      if instance['host'] is not None])
         for node in nodes:
-            rpc.call('%s.%s' % (FLAGS.compute_topic, node),
+            rpc.cast(context,
+                     '%s.%s' % (FLAGS.compute_topic, node),
                      { "method": "refresh_security_group",
-                       "args": { "context": None,
-                                 "security_group_id": security_group.id}})
+                       "args": {"security_group_id": security_group.id}})
 
     def get_metadata(self, address):
-        instance_ref = db.fixed_ip_get_instance(None, address)
+        ctxt = context.get_admin_context()
+        instance_ref = db.fixed_ip_get_instance(ctxt, address)
         if instance_ref is None:
             return None
-        mpi = self._get_mpi_data(instance_ref['project_id'])
+        mpi = self._get_mpi_data(ctxt, instance_ref['project_id'])
         if instance_ref['key_name']:
             keys = {
                 '0': {
@@ -154,7 +156,7 @@ class CloudController(object):
         else:
             keys = ''
         hostname = instance_ref['hostname']
-        floating_ip = db.instance_get_floating_address(None,
+        floating_ip = db.instance_get_floating_address(ctxt,
                                                        instance_ref['id'])
         data = {
             'user-data': base64.b64decode(instance_ref['user_data']),
@@ -162,7 +164,7 @@ class CloudController(object):
                 'ami-id': instance_ref['image_id'],
                 'ami-launch-index': instance_ref['launch_index'],
                 'ami-manifest-path': 'FIXME',
-                'block-device-mapping': { # TODO(vish): replace with real data
+                'block-device-mapping': {  # TODO(vish): replace with real data
                     'ami': 'sda1',
                     'ephemeral0': 'sda2',
                     'root': '/dev/sda1',
@@ -244,7 +246,7 @@ class CloudController(object):
         return {'keypairsSet': result}
 
     def create_key_pair(self, context, key_name, **kwargs):
-        data = _gen_key(None, context.user.id, key_name)
+        data = _gen_key(context, context.user.id, key_name)
         return {'keyName': key_name,
                 'keyFingerprint': data['fingerprint'],
                 'keyMaterial': data['private_key']}
@@ -264,7 +266,7 @@ class CloudController(object):
             groups = db.security_group_get_all(context)
         else:
             groups = db.security_group_get_by_project(context,
-                                                      context.project.id)
+                                                      context.project_id)
         groups = [self._format_security_group(context, g) for g in groups]
         if not group_name is None:
             groups = [g for g in groups if g.name in group_name]
@@ -308,7 +310,7 @@ class CloudController(object):
                     source_security_group_owner_id)
 
             source_security_group = \
-                    db.security_group_get_by_name(context,
+                    db.security_group_get_by_name(context.elevated(),
                                                   source_project_id,
                                                   source_security_group_name)
             values['group_id'] = source_security_group['id']
@@ -364,7 +366,7 @@ class CloudController(object):
     def revoke_security_group_ingress(self, context, group_name, **kwargs):
         self._ensure_default_security_group(context)
         security_group = db.security_group_get_by_name(context,
-                                                       context.project.id,
+                                                       context.project_id,
                                                        group_name)
 
         criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
@@ -378,7 +380,7 @@ class CloudController(object):
                     match = False
             if match:
                 db.security_group_rule_destroy(context, rule['id'])
-                self._trigger_refresh_security_group(security_group)
+                self._trigger_refresh_security_group(context, security_group)
                 return True
         raise exception.ApiError("No rule for the specified parameters.")
 
@@ -389,7 +391,7 @@ class CloudController(object):
    def authorize_security_group_ingress(self, context, group_name, **kwargs):
        self._ensure_default_security_group(context)
        security_group = db.security_group_get_by_name(context,
-                                                      context.project.id,
+                                                      context.project_id,
                                                       group_name)
 
        values = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
@@ -401,7 +403,7 @@ class CloudController(object):
 
        security_group_rule = db.security_group_rule_create(context, values)
 
-       self._trigger_refresh_security_group(security_group)
+       self._trigger_refresh_security_group(context, security_group)
 
        return True
 
@@ -419,18 +421,18 @@ class CloudController(object):
            else:
                source_project_id = source_parts[0]
        else:
-           source_project_id = context.project.id
+           source_project_id = context.project_id
 
        return source_project_id
 
    def create_security_group(self, context, group_name, group_description):
        self._ensure_default_security_group(context)
-       if db.security_group_exists(context, context.project.id, group_name):
+       if db.security_group_exists(context, context.project_id, group_name):
            raise exception.ApiError('group %s already exists' % group_name)
 
        group = {'user_id' : context.user.id,
-                'project_id': context.project.id,
+                'project_id': context.project_id,
                 'name': group_name,
                 'description': group_description}
        group_ref = db.security_group_create(context, group)
@@ -441,7 +443,7 @@ class CloudController(object):
 
    def delete_security_group(self, context, group_name, **kwargs):
        security_group = db.security_group_get_by_name(context,
-                                                      context.project.id,
+                                                      context.project_id,
                                                       group_name)
        db.security_group_destroy(context, security_group.id)
        return True
@@ -452,11 +454,11 @@ class CloudController(object):
        ec2_id = instance_id[0]
        internal_id = ec2_id_to_internal_id(ec2_id)
        instance_ref = db.instance_get_by_internal_id(context, internal_id)
-       output = rpc.call('%s.%s' % (FLAGS.compute_topic,
-                                    instance_ref['host']),
-                         { "method" : "get_console_output",
-                           "args" : { "context": None,
-                                      "instance_id": instance_ref['id']}})
+       output = rpc.call(context,
+                         '%s.%s' % (FLAGS.compute_topic,
+                                    instance_ref['host']),
+                         {"method" : "get_console_output",
+                          "args" : {"instance_id": instance_ref['id']}})
        now = datetime.datetime.utcnow()
        return { "InstanceId" : ec2_id,
@@ -467,7 +469,7 @@ class CloudController(object):
        if context.user.is_admin():
            volumes = db.volume_get_all(context)
        else:
-           volumes = db.volume_get_all_by_project(context, context.project.id)
+           volumes = db.volume_get_all_by_project(context, context.project_id)
 
        volumes = [self._format_volume(context, v) for v in volumes]
@@ -505,14 +507,14 @@ class CloudController(object):
        # check quota
        if quota.allowed_volumes(context, 1, size) < 1:
            logging.warn("Quota exceeeded for %s, tried to create %sG volume",
-                        context.project.id, size)
+                        context.project_id, size)
            raise QuotaError("Volume quota exceeded. You cannot "
                             "create a volume of size %s" % size)
        vol = {}
        vol['size'] = size
        vol['user_id'] = context.user.id
-       vol['project_id'] = context.project.id
+       vol['project_id'] = context.project_id
        vol['availability_zone'] = FLAGS.storage_availability_zone
        vol['status'] = "creating"
        vol['attach_status'] = "detached"
@@ -520,10 +522,10 @@ class CloudController(object):
        vol['display_description'] = kwargs.get('display_description')
        volume_ref = db.volume_create(context, vol)
 
-       rpc.cast(FLAGS.scheduler_topic,
+       rpc.cast(context,
+                FLAGS.scheduler_topic,
                 {"method": "create_volume",
-                 "args": {"context": None,
-                          "topic": FLAGS.volume_topic,
+                 "args": {"topic": FLAGS.volume_topic,
                           "volume_id": volume_ref['id']}})
 
        return {'volumeSet': [self._format_volume(context, volume_ref)]}
@@ -539,12 +541,12 @@ class CloudController(object):
        internal_id = ec2_id_to_internal_id(instance_id)
        instance_ref = db.instance_get_by_internal_id(context, internal_id)
        host = instance_ref['host']
-       rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
-                {"method": "attach_volume",
-                 "args": {"context": None,
-                          "volume_id": volume_ref['id'],
-                          "instance_id": instance_ref['id'],
-                          "mountpoint": device}})
+       rpc.cast(context,
+                db.queue_get_for(context, FLAGS.compute_topic, host),
+                {"method": "attach_volume",
+                 "args": {"volume_id": volume_ref['id'],
+                          "instance_id": instance_ref['id'],
+                          "mountpoint": device}})
        return {'attachTime': volume_ref['attach_time'],
                'device': volume_ref['mountpoint'],
                'instanceId': instance_ref['id'],
@@ -554,7 +556,8 @@ class CloudController(object):
 
    def detach_volume(self, context, volume_id, **kwargs):
        volume_ref = db.volume_get_by_ec2_id(context, volume_id)
-       instance_ref = db.volume_get_instance(context, volume_ref['id'])
+       instance_ref = db.volume_get_instance(context.elevated(),
+                                             volume_ref['id'])
        if not instance_ref:
            raise exception.ApiError("Volume isn't attached to anything!")
        # TODO(vish): abstract status checking?
@@ -562,11 +565,11 @@ class CloudController(object):
            raise exception.ApiError("Volume is already detached")
        try:
            host = instance_ref['host']
-           rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
-                    {"method": "detach_volume",
-                     "args": {"context": None,
-                              "instance_id": instance_ref['id'],
-                              "volume_id": volume_ref['id']}})
+           rpc.cast(context,
+                    db.queue_get_for(context, FLAGS.compute_topic, host),
+                    {"method": "detach_volume",
+                     "args": {"instance_id": instance_ref['id'],
+                              "volume_id": volume_ref['id']}})
        except exception.NotFound:
            # If the instance doesn't exist anymore,
            # then we need to call detach blind
@@ -601,7 +604,7 @@ class CloudController(object):
        return self._format_describe_instances(context)
 
    def _format_describe_instances(self, context):
-       return { 'reservationSet': self._format_instances(context) }
+       return {'reservationSet': self._format_instances(context)}
 
    def _format_run_instances(self, context, reservation_id):
        i = self._format_instances(context, reservation_id)
@@ -618,7 +621,7 @@ class CloudController(object):
            instances = db.instance_get_all(context)
        else:
            instances = db.instance_get_all_by_project(context,
-                                                      context.project.id)
+                                                      context.project_id)
        for instance in instances:
            if not context.user.is_admin():
                if instance['image_id'] == FLAGS.vpn_image_id:
@@ -673,7 +676,7 @@ class CloudController(object):
            iterator = db.floating_ip_get_all(context)
        else:
            iterator = db.floating_ip_get_all_by_project(context,
-                                                        context.project.id)
+                                                        context.project_id)
        for floating_ip_ref in iterator:
            address = floating_ip_ref['address']
            instance_id = None
@@ -694,24 +697,24 @@ class CloudController(object):
        # check quota
        if quota.allowed_floating_ips(context, 1) < 1:
            logging.warn("Quota exceeeded for %s, tried to allocate address",
-                        context.project.id)
+                        context.project_id)
            raise QuotaError("Address quota exceeded. You cannot "
                             "allocate any more addresses")
        network_topic = self._get_network_topic(context)
-       public_ip = rpc.call(network_topic,
-                            {"method": "allocate_floating_ip",
-                             "args": {"context": None,
-                                      "project_id": context.project.id}})
+       public_ip = rpc.call(context,
+                            network_topic,
+                            {"method": "allocate_floating_ip",
+                             "args": {"project_id": context.project_id}})
        return {'addressSet': [{'publicIp': public_ip}]}
 
    def release_address(self, context, public_ip, **kwargs):
        # NOTE(vish): Should we make sure this works?
        floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
        network_topic = self._get_network_topic(context)
-       rpc.cast(network_topic,
+       rpc.cast(context,
+                network_topic,
                 {"method": "deallocate_floating_ip",
-                 "args": {"context": None,
-                          "floating_address": floating_ip_ref['address']}})
+                 "args": {"floating_address": floating_ip_ref['address']}})
        return {'releaseResponse': ["Address released."]}
 
    def associate_address(self, context, ec2_id, public_ip, **kwargs):
@@ -721,20 +724,20 @@ class CloudController(object):
                                                      instance_ref['id'])
        floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
        network_topic = self._get_network_topic(context)
-       rpc.cast(network_topic,
+       rpc.cast(context,
+                network_topic,
                 {"method": "associate_floating_ip",
-                 "args": {"context": None,
-                          "floating_address": floating_ip_ref['address'],
+                 "args": {"floating_address": floating_ip_ref['address'],
                           "fixed_address": fixed_address}})
        return {'associateResponse': ["Address associated."]}
 
    def disassociate_address(self, context, public_ip, **kwargs):
        floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
        network_topic = self._get_network_topic(context)
-       rpc.cast(network_topic,
+       rpc.cast(context,
+                network_topic,
                 {"method": "disassociate_floating_ip",
-                 "args": {"context": None,
-                          "floating_address": floating_ip_ref['address']}})
+                 "args": {"floating_address": floating_ip_ref['address']}})
        return {'disassociateResponse': ["Address disassociated."]}
 
    def _get_network_topic(self, context):
@@ -742,22 +745,22 @@ class CloudController(object):
        network_ref = self.network_manager.get_network(context)
        host = network_ref['host']
        if not host:
-           host = rpc.call(FLAGS.network_topic,
-                           {"method": "set_network_host",
-                            "args": {"context": None,
-                                     "network_id": network_ref['id']}})
+           host = rpc.call(context,
+                           FLAGS.network_topic,
+                           {"method": "set_network_host",
+                            "args": {"network_id": network_ref['id']}})
        return db.queue_get_for(context, FLAGS.network_topic, host)
 
    def _ensure_default_security_group(self, context):
        try:
            db.security_group_get_by_name(context,
-                                         context.project.id,
+                                         context.project_id,
                                          'default')
        except exception.NotFound:
            values = { 'name' : 'default',
                       'description' : 'default',
                       'user_id' : context.user.id,
-                      'project_id' : context.project.id }
+                      'project_id' : context.project_id }
            group = db.security_group_create(context, values)
 
    def run_instances(self, context, **kwargs):
@@ -773,7 +776,7 @@ class CloudController(object):
                                               instance_type)
        if num_instances < min_instances:
            logging.warn("Quota exceeeded for %s, tried to run %s instances",
-                        context.project.id, min_instances)
+                        context.project_id, min_instances)
            raise QuotaError("Instance quota exceeded. You can only "
                             "run %s more instances of this type."
                             % num_instances, "InstanceLimitExceeded")
@@ -815,7 +818,7 @@ class CloudController(object):
        self._ensure_default_security_group(context)
        for security_group_name in security_group_arg:
            group = db.security_group_get_by_name(context,
-                                                 context.project.id,
+                                                 context.project_id,
                                                  security_group_name)
            security_groups.append(group['id'])
@@ -829,7 +832,7 @@ class CloudController(object):
        base_options['key_data'] = key_data
        base_options['key_name'] = kwargs.get('key_name', None)
        base_options['user_id'] = context.user.id
-       base_options['project_id'] = context.project.id
+       base_options['project_id'] = context.project_id
        base_options['user_data'] = kwargs.get('user_data', '')
 
        base_options['display_name'] = kwargs.get('display_name')
@@ -840,13 +843,15 @@ class CloudController(object):
        base_options['memory_mb'] = type_data['memory_mb']
        base_options['vcpus'] = type_data['vcpus']
        base_options['local_gb'] = type_data['local_gb']
+       elevated = context.elevated()
 
        for num in range(num_instances):
            instance_ref = db.instance_create(context, base_options)
            inst_id = instance_ref['id']
 
            for security_group_id in security_groups:
-               db.instance_add_security_group(context, inst_id,
+               db.instance_add_security_group(elevated,
+                                              inst_id,
                                               security_group_id)
 
            inst = {}
@@ -864,15 +869,15 @@ class CloudController(object):
                                                          inst_id,
                                                          vpn)
            network_topic = self._get_network_topic(context)
-           rpc.call(network_topic,
+           rpc.cast(elevated,
+                    network_topic,
                     {"method": "setup_fixed_ip",
-                     "args": {"context": None,
-                              "address": address}})
+                     "args": {"address": address}})
 
-           rpc.cast(FLAGS.scheduler_topic,
+           rpc.cast(context,
+                    FLAGS.scheduler_topic,
                     {"method": "run_instance",
-                     "args": {"context": None,
-                              "topic": FLAGS.compute_topic,
+                     "args": {"topic": FLAGS.compute_topic,
                               "instance_id": inst_id}})
            logging.debug("Casting to scheduler for %s/%s's instance %s" %
                          (context.project.name, context.user.name, inst_id))
@@ -890,17 +895,23 @@ class CloudController(object):
            internal_id = ec2_id_to_internal_id(id_str)
            logging.debug("Going to try and terminate %s" % id_str)
            try:
-               instance_ref = db.instance_get_by_internal_id(context, 
+               instance_ref = db.instance_get_by_internal_id(context,
                                                              internal_id)
            except exception.NotFound:
-               logging.warning("Instance %s was not found during terminate"
-                               % id_str)
+               logging.warning("Instance %s was not found during terminate",
+                               id_str)
                continue
 
+           if (instance_ref['state_description'] == 'terminating'):
+               logging.warning("Instance %s is already being terminated",
+                               id_str)
+               continue
            now = datetime.datetime.utcnow()
            db.instance_update(context,
                               instance_ref['id'],
-                              {'terminated_at': now})
+                              {'state_description': 'terminating',
+                               'state': 0,
+                               'terminated_at': now})
            # FIXME(ja): where should network deallocate occur?
            address = db.instance_get_floating_address(context,
                                                       instance_ref['id'])
@@ -910,10 +921,10 @@ class CloudController(object):
                #             disassociated.  We may need to worry about
                #             checking this later.  Perhaps in the scheduler?
                network_topic = self._get_network_topic(context)
-               rpc.cast(network_topic,
+               rpc.cast(context,
+                        network_topic,
                         {"method": "disassociate_floating_ip",
-                         "args": {"context": None,
-                                  "floating_address": address}})
+                         "args": {"floating_address": address}})
 
            address = db.instance_get_fixed_address(context,
                                                    instance_ref['id'])
@@ -922,14 +933,15 @@ class CloudController(object):
                # NOTE(vish): Currently, nothing needs to be done on the
                #             network node until release. If this changes,
                #             we will need to cast here.
-               self.network_manager.deallocate_fixed_ip(context, address)
+               self.network_manager.deallocate_fixed_ip(context.elevated(),
+                                                        address)
 
            host = instance_ref['host']
            if host:
-               rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+               rpc.cast(context,
+                        db.queue_get_for(context, FLAGS.compute_topic, host),
                         {"method": "terminate_instance",
-                         "args": {"context": None,
-                                  "instance_id": instance_ref['id']}})
+                         "args": {"instance_id": instance_ref['id']}})
            else:
                db.instance_destroy(context, instance_ref['id'])
        return True
@@ -947,10 +959,9 @@ class CloudController(object):
            if field in kwargs:
                changes[field] = kwargs[field]
        if changes:
-           db_context = {}
            internal_id = ec2_id_to_internal_id(ec2_id)
-           inst = db.instance_get_by_internal_id(db_context, internal_id)
-           db.instance_update(db_context, inst['id'], kwargs)
+           inst = db.instance_get_by_internal_id(context, internal_id)
+           db.instance_update(context, inst['id'], kwargs)
        return True
 
    def delete_volume(self, context, volume_id, **kwargs):
@@ -959,12 +970,13 @@ class CloudController(object):
        if volume_ref['status'] != "available":
            raise exception.ApiError("Volume status must be available")
        now = datetime.datetime.utcnow()
-       db.volume_update(context, volume_ref['id'], {'terminated_at': now})
+       db.volume_update(context, volume_ref['id'], {'status': 'deleting',
+                                                    'terminated_at': now})
        host = volume_ref['host']
-       rpc.cast(db.queue_get_for(context, FLAGS.volume_topic, host),
+       rpc.cast(context,
+                db.queue_get_for(context, FLAGS.volume_topic, host),
                 {"method": "delete_volume",
-                 "args": {"context": None,
-                          "volume_id": volume_ref['id']}})
+                 "args": {"volume_id": volume_ref['id']}})
        return True
 
    def describe_images(self, context, image_id=None, **kwargs):
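A recurring detail in the hunks above: `context.project.id` becomes `context.project_id`. On the new `RequestContext` (added as `nova/context.py` later in this diff), `project_id` is a plain attribute, while `project` lazy-loads the full object through `AuthManager`, so the rename avoids an auth lookup on hot paths; `context.elevated()` grants a temporary admin view for cross-project reads such as the `detach_volume` lookup. A small sketch against that code (string IDs are placeholders; `is_admin=False` is passed explicitly so the lazy `user` property is never consulted):

```python
from nova import context

ctxt = context.RequestContext('someuser', 'someproject', is_admin=False)

print ctxt.project_id  # 'someproject' -- plain attribute, no auth lookup

admin_ctxt = ctxt.elevated()   # same request identity, admin flag set
assert admin_ctxt.is_admin
assert ctxt.is_admin is False  # the original context is unchanged
```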
""" token = self.db.auth_get_token(self.context, token_hash) if token: - delta = datetime.datetime.now() - token['created_at'] + delta = datetime.datetime.now() - token.created_at if delta.days >= 2: self.db.auth_destroy_token(self.context, token) else: - user = self.auth.get_user(token['user_id']) - return { 'id':user['uid'] } + #TODO(gundlach): Why not just return dict(id=token.user_id)? + user = self.auth.get_user(token.user_id) + return {'id': user.id} return None def _authorize_user(self, username, key): """ Generates a new token and assigns it to a user """ user = self.auth.get_user_from_access_key(key) - if user and user['name'] == username: + if user and user.name == username: token_hash = hashlib.sha1('%s%s%f' % (username, key, time.time())).hexdigest() - token = {} - token['token_hash'] = token_hash - token['cdn_management_url'] = '' - token['server_management_url'] = self._get_server_mgmt_url() - token['storage_url'] = '' - token['user_id'] = user['uid'] - self.db.auth_create_token(self.context, token) + token_dict = {} + token_dict['token_hash'] = token_hash + token_dict['cdn_management_url'] = '' + token_dict['server_management_url'] = self._get_server_mgmt_url() + token_dict['storage_url'] = '' + token_dict['user_id'] = user.id + token = self.db.auth_create_token(self.context, token_dict) return token, user return None, None diff --git a/nova/api/openstack/context.py b/nova/api/openstack/context.py deleted file mode 100644 index 77394615b..000000000 --- a/nova/api/openstack/context.py +++ /dev/null @@ -1,33 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -""" -APIRequestContext -""" - -import random - -class Project(object): - def __init__(self, user_id): - self.id = user_id - -class APIRequestContext(object): - """ This is an adapter class to get around all of the assumptions made in - the FlatNetworking """ - def __init__(self, user_id): - self.user_id = user_id - self.project = Project(user_id) diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py index cb5132635..a73591ccc 100644 --- a/nova/api/openstack/servers.py +++ b/nova/api/openstack/servers.py @@ -24,8 +24,8 @@ from nova import flags from nova import rpc from nova import utils from nova import wsgi +from nova import context from nova.api import cloud -from nova.api.openstack import context from nova.api.openstack import faults from nova.compute import instance_types from nova.compute import power_state @@ -93,6 +93,7 @@ class Controller(wsgi.Controller): if not db_driver: db_driver = FLAGS.db_driver self.db_driver = utils.import_object(db_driver) + self.network_manager = utils.import_object(FLAGS.network_manager) super(Controller, self).__init__() def index(self, req): @@ -109,7 +110,8 @@ class Controller(wsgi.Controller): entity_maker - either _entity_detail or _entity_inst """ user_id = req.environ['nova.context']['user']['id'] - instance_list = self.db_driver.instance_get_all_by_user(None, user_id) + ctxt = context.RequestContext(user_id, user_id) + instance_list = self.db_driver.instance_get_all_by_user(ctxt, user_id) limited_list = nova.api.openstack.limited(instance_list, req) res = [entity_maker(inst)['server'] for inst in limited_list] return _entity_list(res) @@ -117,7 +119,8 @@ class Controller(wsgi.Controller): def show(self, req, id): """ Returns server details by server id """ user_id = req.environ['nova.context']['user']['id'] - inst = self.db_driver.instance_get_by_internal_id(None, int(id)) + ctxt = context.RequestContext(user_id, user_id) + inst = self.db_driver.instance_get_by_internal_id(ctxt, int(id)) if inst: if inst.user_id == user_id: return _entity_detail(inst) @@ -126,9 +129,10 @@ class Controller(wsgi.Controller): def delete(self, req, id): """ Destroys a server """ user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get_by_internal_id(None, int(id)) + ctxt = context.RequestContext(user_id, user_id) + instance = self.db_driver.instance_get_by_internal_id(ctxt, int(id)) if instance and instance['user_id'] == user_id: - self.db_driver.instance_destroy(None, id) + self.db_driver.instance_destroy(ctxt, id) return faults.Fault(exc.HTTPAccepted()) return faults.Fault(exc.HTTPNotFound()) @@ -144,39 +148,43 @@ class Controller(wsgi.Controller): #except Exception, e: # return faults.Fault(exc.HTTPUnprocessableEntity()) - rpc.cast( - FLAGS.compute_topic, { - "method": "run_instance", - "args": {"instance_id": inst['id']}}) + user_id = req.environ['nova.context']['user']['id'] + rpc.cast(context.RequestContext(user_id, user_id), + FLAGS.compute_topic, + {"method": "run_instance", + "args": {"instance_id": inst['id']}}) return _entity_inst(inst) def update(self, req, id): """ Updates the server name or password """ user_id = req.environ['nova.context']['user']['id'] + ctxt = context.RequestContext(user_id, user_id) inst_dict = self._deserialize(req.body, req) if not inst_dict: return faults.Fault(exc.HTTPUnprocessableEntity()) - instance = self.db_driver.instance_get_by_internal_id(None, int(id)) + instance = self.db_driver.instance_get_by_internal_id(ctxt, int(id)) if not instance or instance.user_id != 
diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py
index cb5132635..a73591ccc 100644
--- a/nova/api/openstack/servers.py
+++ b/nova/api/openstack/servers.py
@@ -24,8 +24,8 @@ from nova import flags
 from nova import rpc
 from nova import utils
 from nova import wsgi
+from nova import context
 from nova.api import cloud
-from nova.api.openstack import context
 from nova.api.openstack import faults
 from nova.compute import instance_types
 from nova.compute import power_state
@@ -93,6 +93,7 @@ class Controller(wsgi.Controller):
         if not db_driver:
             db_driver = FLAGS.db_driver
         self.db_driver = utils.import_object(db_driver)
+        self.network_manager = utils.import_object(FLAGS.network_manager)
         super(Controller, self).__init__()
 
     def index(self, req):
@@ -109,7 +110,8 @@ class Controller(wsgi.Controller):
         entity_maker - either _entity_detail or _entity_inst
         """
         user_id = req.environ['nova.context']['user']['id']
-        instance_list = self.db_driver.instance_get_all_by_user(None, user_id)
+        ctxt = context.RequestContext(user_id, user_id)
+        instance_list = self.db_driver.instance_get_all_by_user(ctxt, user_id)
         limited_list = nova.api.openstack.limited(instance_list, req)
         res = [entity_maker(inst)['server'] for inst in limited_list]
         return _entity_list(res)
@@ -117,7 +119,8 @@ class Controller(wsgi.Controller):
     def show(self, req, id):
         """ Returns server details by server id """
         user_id = req.environ['nova.context']['user']['id']
-        inst = self.db_driver.instance_get_by_internal_id(None, int(id))
+        ctxt = context.RequestContext(user_id, user_id)
+        inst = self.db_driver.instance_get_by_internal_id(ctxt, int(id))
         if inst:
             if inst.user_id == user_id:
                 return _entity_detail(inst)
@@ -126,9 +129,10 @@ class Controller(wsgi.Controller):
     def delete(self, req, id):
         """ Destroys a server """
         user_id = req.environ['nova.context']['user']['id']
-        instance = self.db_driver.instance_get_by_internal_id(None, int(id))
+        ctxt = context.RequestContext(user_id, user_id)
+        instance = self.db_driver.instance_get_by_internal_id(ctxt, int(id))
         if instance and instance['user_id'] == user_id:
-            self.db_driver.instance_destroy(None, id)
+            self.db_driver.instance_destroy(ctxt, id)
             return faults.Fault(exc.HTTPAccepted())
         return faults.Fault(exc.HTTPNotFound())
@@ -144,39 +148,43 @@ class Controller(wsgi.Controller):
         #except Exception, e:
         #    return faults.Fault(exc.HTTPUnprocessableEntity())
 
-        rpc.cast(
-            FLAGS.compute_topic, {
-                "method": "run_instance",
-                "args": {"instance_id": inst['id']}})
+        user_id = req.environ['nova.context']['user']['id']
+        rpc.cast(context.RequestContext(user_id, user_id),
+                 FLAGS.compute_topic,
+                 {"method": "run_instance",
+                  "args": {"instance_id": inst['id']}})
         return _entity_inst(inst)
 
     def update(self, req, id):
         """ Updates the server name or password """
         user_id = req.environ['nova.context']['user']['id']
+        ctxt = context.RequestContext(user_id, user_id)
 
         inst_dict = self._deserialize(req.body, req)
 
         if not inst_dict:
             return faults.Fault(exc.HTTPUnprocessableEntity())
 
-        instance = self.db_driver.instance_get_by_internal_id(None, int(id))
+        instance = self.db_driver.instance_get_by_internal_id(ctxt, int(id))
         if not instance or instance.user_id != user_id:
             return faults.Fault(exc.HTTPNotFound())
 
-        self.db_driver.instance_update(None, int(id),
-                                       _filter_params(inst_dict['server']))
+        self.db_driver.instance_update(ctxt,
+                                       int(id),
+                                       _filter_params(inst_dict['server']))
         return faults.Fault(exc.HTTPNoContent())
 
     def action(self, req, id):
         """ multi-purpose method used to reboot, rebuild, and
         resize a server """
         user_id = req.environ['nova.context']['user']['id']
+        ctxt = context.RequestContext(user_id, user_id)
         input_dict = self._deserialize(req.body, req)
         try:
             reboot_type = input_dict['reboot']['type']
         except Exception:
             raise faults.Fault(webob.exc.HTTPNotImplemented())
-        inst_ref = self.db.instance_get_by_internal_id(None, int(id))
+        inst_ref = self.db.instance_get_by_internal_id(ctxt, int(id))
         if not inst_ref or (inst_ref and not inst_ref.user_id == user_id):
             return faults.Fault(exc.HTTPUnprocessableEntity())
         cloud.reboot(id)
@@ -187,6 +195,7 @@ class Controller(wsgi.Controller):
         inst = {}
 
         user_id = req.environ['nova.context']['user']['id']
+        ctxt = context.RequestContext(user_id, user_id)
 
         flavor_id = env['server']['flavorId']
@@ -233,12 +242,8 @@ class Controller(wsgi.Controller):
         inst['vcpus'] = flavor['vcpus']
         inst['local_gb'] = flavor['local_gb']
 
-        ref = self.db_driver.instance_create(None, inst)
+        ref = self.db_driver.instance_create(ctxt, inst)
         inst['id'] = ref.internal_id
-        # TODO(dietz): this isn't explicitly necessary, but the networking
-        # calls depend on an object with a project_id property, and therefore
-        # should be cleaned up later
-        api_context = context.APIRequestContext(user_id)
 
         inst['mac_address'] = utils.generate_mac()
@@ -246,19 +251,19 @@ class Controller(wsgi.Controller):
         inst['launch_index'] = 0
         inst['hostname'] = str(ref.internal_id)
 
-        self.db_driver.instance_update(api_context, inst['id'], inst)
+        self.db_driver.instance_update(ctxt, inst['id'], inst)
 
         network_manager = utils.import_object(FLAGS.network_manager)
-        address = network_manager.allocate_fixed_ip(api_context,
+        address = network_manager.allocate_fixed_ip(ctxt,
                                                     inst['id'])
 
         # TODO(vish): This probably should be done in the scheduler
         #             network is setup when host is assigned
-        network_topic = self._get_network_topic(api_context, network_manager)
-        rpc.call(network_topic,
+        network_topic = self._get_network_topic(ctxt, network_manager)
+        rpc.call(ctxt,
+                 network_topic,
                  {"method": "setup_fixed_ip",
-                  "args": {"context": api_context,
-                           "address": address}})
+                  "args": {"address": address}})
         return inst
 
     def _get_network_topic(self, context, network_manager):
@@ -266,8 +271,8 @@
         network_ref = network_manager.get_network(context)
         host = network_ref['host']
         if not host:
-            host = rpc.call(FLAGS.network_topic,
-                            {"method": "set_network_host",
-                             "args": {"context": context,
-                                      "network_id": network_ref['id']}})
-        return self.db_driver.queue_get_for(None, FLAGS.network_topic, host)
+            host = rpc.call(context,
+                            FLAGS.network_topic,
+                            {"method": "set_network_host",
+                             "args": {"network_id": network_ref['id']}})
+        return self.db_driver.queue_get_for(context, FLAGS.network_topic, host)
diff --git a/nova/auth/dbdriver.py b/nova/auth/dbdriver.py
index 09d15018b..648d6e828 100644
--- a/nova/auth/dbdriver.py
+++ b/nova/auth/dbdriver.py
@@ -23,6 +23,7 @@ Auth driver using the DB as its backend.
 import logging
 import sys
 
+from nova import context
 from nova import exception
 from nova import db
@@ -46,26 +47,26 @@ class DbDriver(object):
 
     def get_user(self, uid):
         """Retrieve user by id"""
-        return self._db_user_to_auth_user(db.user_get({}, uid))
+        return self._db_user_to_auth_user(db.user_get(context.get_admin_context(), uid))
 
     def get_user_from_access_key(self, access):
         """Retrieve user by access key"""
-        return self._db_user_to_auth_user(db.user_get_by_access_key({}, access))
+        return self._db_user_to_auth_user(db.user_get_by_access_key(context.get_admin_context(), access))
 
     def get_project(self, pid):
         """Retrieve project by id"""
-        return self._db_project_to_auth_projectuser(db.project_get({}, pid))
+        return self._db_project_to_auth_projectuser(db.project_get(context.get_admin_context(), pid))
 
     def get_users(self):
         """Retrieve list of users"""
-        return [self._db_user_to_auth_user(user) for user in db.user_get_all({})]
+        return [self._db_user_to_auth_user(user) for user in db.user_get_all(context.get_admin_context())]
 
     def get_projects(self, uid=None):
         """Retrieve list of projects"""
         if uid:
-            result = db.project_get_by_user({}, uid)
+            result = db.project_get_by_user(context.get_admin_context(), uid)
         else:
-            result = db.project_get_all({})
+            result = db.project_get_all(context.get_admin_context())
         return [self._db_project_to_auth_projectuser(proj) for proj in result]
 
     def create_user(self, name, access_key, secret_key, is_admin):
@@ -76,7 +77,7 @@ class DbDriver(object):
                   'is_admin' : is_admin }
         try:
-            user_ref = db.user_create({}, values)
+            user_ref = db.user_create(context.get_admin_context(), values)
             return self._db_user_to_auth_user(user_ref)
         except exception.Duplicate, e:
             raise exception.Duplicate('User %s already exists' % name)
@@ -98,7 +99,7 @@ class DbDriver(object):
     def create_project(self, name, manager_uid,
                        description=None, member_uids=None):
         """Create a project"""
-        manager = db.user_get({}, manager_uid)
+        manager = db.user_get(context.get_admin_context(), manager_uid)
         if not manager:
             raise exception.NotFound("Project can't be created because "
                                      "manager %s doesn't exist" % manager_uid)
@@ -113,7 +114,7 @@ class DbDriver(object):
         members = set([manager])
         if member_uids != None:
             for member_uid in member_uids:
-                member = db.user_get({}, member_uid)
+                member = db.user_get(context.get_admin_context(), member_uid)
                 if not member:
                     raise exception.NotFound("Project can't be created "
                                              "because user %s doesn't exist"
@@ -126,17 +127,20 @@ class DbDriver(object):
                   'description': description }
 
         try:
-            project = db.project_create({}, values)
+            project = db.project_create(context.get_admin_context(), values)
         except exception.Duplicate:
             raise exception.Duplicate("Project can't be created because "
                                       "project %s already exists" % name)
 
         for member in members:
-            db.project_add_member({}, project['id'], member['id'])
+            db.project_add_member(context.get_admin_context(),
+                                  project['id'],
+                                  member['id'])
 
         # This looks silly, but ensures that the members element has been
         # correctly populated
-        project_ref = db.project_get({}, project['id'])
+        project_ref = db.project_get(context.get_admin_context(),
+                                     project['id'])
         return self._db_project_to_auth_projectuser(project_ref)
 
     def modify_project(self, project_id, manager_uid=None, description=None):
@@ -145,7 +149,7 @@ class DbDriver(object):
             return
         values = {}
         if manager_uid:
-            manager = db.user_get({}, manager_uid)
+            manager = db.user_get(context.get_admin_context(), manager_uid)
             if not manager:
                 raise exception.NotFound("Project can't be modified because "
                                          "manager %s doesn't exist" %
@@ -154,17 +158,21 @@ class DbDriver(object):
         if description:
             values['description'] = description
 
-        db.project_update({}, project_id, values)
+        db.project_update(context.get_admin_context(), project_id, values)
 
     def add_to_project(self, uid, project_id):
         """Add user to project"""
         user, project = self._validate_user_and_project(uid, project_id)
-        db.project_add_member({}, project['id'], user['id'])
+        db.project_add_member(context.get_admin_context(),
+                              project['id'],
+                              user['id'])
 
     def remove_from_project(self, uid, project_id):
         """Remove user from project"""
         user, project = self._validate_user_and_project(uid, project_id)
-        db.project_remove_member({}, project['id'], user['id'])
+        db.project_remove_member(context.get_admin_context(),
+                                 project['id'],
+                                 user['id'])
 
     def is_in_project(self, uid, project_id):
         """Check if user is in project"""
@@ -183,34 +191,37 @@ class DbDriver(object):
     def add_role(self, uid, role, project_id=None):
         """Add role for user (or user and project)"""
         if not project_id:
-            db.user_add_role({}, uid, role)
+            db.user_add_role(context.get_admin_context(), uid, role)
             return
-        db.user_add_project_role({}, uid, project_id, role)
+        db.user_add_project_role(context.get_admin_context(),
+                                 uid, project_id, role)
 
     def remove_role(self, uid, role, project_id=None):
         """Remove role for user (or user and project)"""
         if not project_id:
-            db.user_remove_role({}, uid, role)
+            db.user_remove_role(context.get_admin_context(), uid, role)
             return
-        db.user_remove_project_role({}, uid, project_id, role)
+        db.user_remove_project_role(context.get_admin_context(),
+                                    uid, project_id, role)
 
     def get_user_roles(self, uid, project_id=None):
         """Retrieve list of roles for user (or user and project)"""
         if project_id is None:
-            roles = db.user_get_roles({}, uid)
+            roles = db.user_get_roles(context.get_admin_context(), uid)
             return roles
         else:
-            roles = db.user_get_roles_for_project({}, uid, project_id)
+            roles = db.user_get_roles_for_project(context.get_admin_context(),
+                                                  uid, project_id)
             return roles
 
     def delete_user(self, id):
         """Delete a user"""
-        user = db.user_get({}, id)
-        db.user_delete({}, user['id'])
+        user = db.user_get(context.get_admin_context(), id)
+        db.user_delete(context.get_admin_context(), user['id'])
 
     def delete_project(self, project_id):
         """Delete a project"""
-        db.project_delete({}, project_id)
+        db.project_delete(context.get_admin_context(), project_id)
 
     def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
         """Modify an existing user"""
@@ -223,13 +234,13 @@ class DbDriver(object):
             values['secret_key'] = secret_key
         if admin is not None:
             values['is_admin'] = admin
-        db.user_update({}, uid, values)
+        db.user_update(context.get_admin_context(), uid, values)
 
     def _validate_user_and_project(self, user_id, project_id):
-        user = db.user_get({}, user_id)
+        user = db.user_get(context.get_admin_context(), user_id)
         if not user:
             raise exception.NotFound('User "%s" not found' % user_id)
-        project = db.project_get({}, project_id)
+        project = db.project_get(context.get_admin_context(), project_id)
         if not project:
             raise exception.NotFound('Project "%s" not found' % project_id)
         return user, project
""" import json +import redis -from nova import datastore +from nova import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string('redis_host', '127.0.0.1', + 'Host that redis is running on.') +flags.DEFINE_integer('redis_port', 6379, + 'Port that redis is running on.') +flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away') + +class Redis(object): + def __init__(self): + if hasattr(self.__class__, '_instance'): + raise Exception('Attempted to instantiate singleton') + + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + inst = redis.Redis(host=FLAGS.redis_host, + port=FLAGS.redis_port, + db=FLAGS.redis_db) + cls._instance = inst + return cls._instance SCOPE_BASE = 0 @@ -164,11 +186,11 @@ class FakeLDAP(object): key = "%s%s" % (self.__redis_prefix, dn) value_dict = dict([(k, _to_json(v)) for k, v in attr]) - datastore.Redis.instance().hmset(key, value_dict) + Redis.instance().hmset(key, value_dict) def delete_s(self, dn): """Remove the ldap object at specified dn.""" - datastore.Redis.instance().delete("%s%s" % (self.__redis_prefix, dn)) + Redis.instance().delete("%s%s" % (self.__redis_prefix, dn)) def modify_s(self, dn, attrs): """Modify the object at dn using the attribute list. @@ -179,7 +201,7 @@ class FakeLDAP(object): ([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value) """ - redis = datastore.Redis.instance() + redis = Redis.instance() key = "%s%s" % (self.__redis_prefix, dn) for cmd, k, v in attrs: @@ -204,7 +226,7 @@ class FakeLDAP(object): """ if scope != SCOPE_BASE and scope != SCOPE_SUBTREE: raise NotImplementedError(str(scope)) - redis = datastore.Redis.instance() + redis = Redis.instance() if scope == SCOPE_BASE: keys = ["%s%s" % (self.__redis_prefix, dn)] else: @@ -232,3 +254,5 @@ class FakeLDAP(object): def __redis_prefix(self): # pylint: disable-msg=R0201 """Get the prefix to use for all redis keys.""" return 'ldap:' + + diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 9c499c98d..bf7ca8a95 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -28,6 +28,7 @@ import tempfile import uuid import zipfile +from nova import context from nova import crypto from nova import db from nova import exception @@ -201,7 +202,7 @@ class AuthManager(object): def __new__(cls, *args, **kwargs): """Returns the AuthManager singleton""" - if not cls._instance: + if not cls._instance or ('new' in kwargs and kwargs['new']): cls._instance = super(AuthManager, cls).__new__(cls) return cls._instance @@ -454,7 +455,7 @@ class AuthManager(object): return [Project(**project_dict) for project_dict in project_list] def create_project(self, name, manager_user, description=None, - member_users=None, context=None): + member_users=None): """Create a project @type name: str @@ -531,7 +532,7 @@ class AuthManager(object): Project.safe_id(project)) @staticmethod - def get_project_vpn_data(project, context=None): + def get_project_vpn_data(project): """Gets vpn ip and port for project @type project: Project or project_id @@ -542,7 +543,7 @@ class AuthManager(object): not been allocated for user. 
""" - network_ref = db.project_get_network(context, + network_ref = db.project_get_network(context.get_admin_context(), Project.safe_id(project)) if not network_ref['vpn_public_port']: @@ -550,7 +551,7 @@ class AuthManager(object): return (network_ref['vpn_public_address'], network_ref['vpn_public_port']) - def delete_project(self, project, context=None): + def delete_project(self, project): """Deletes a project""" with self.driver() as drv: drv.delete_project(Project.safe_id(project)) @@ -613,7 +614,8 @@ class AuthManager(object): Additionally deletes all users key_pairs""" uid = User.safe_id(user) - db.key_pair_destroy_all_by_user(None, uid) + db.key_pair_destroy_all_by_user(context.get_admin_context(), + uid) with self.driver() as drv: drv.delete_user(uid) diff --git a/nova/cloudpipe/pipelib.py b/nova/cloudpipe/pipelib.py index 706a175d9..4fc2c85cb 100644 --- a/nova/cloudpipe/pipelib.py +++ b/nova/cloudpipe/pipelib.py @@ -28,13 +28,13 @@ import os import tempfile import zipfile +from nova import context from nova import exception from nova import flags from nova import utils from nova.auth import manager # TODO(eday): Eventually changes these to something not ec2-specific from nova.api.ec2 import cloud -from nova.api.ec2 import context FLAGS = flags.FLAGS @@ -62,7 +62,7 @@ class CloudPipe(object): key_name = self.setup_key_pair(project.project_manager_id, project_id) zippy = open(zippath, "r") - context = context.APIRequestContext(user=project.project_manager, project=project) + context = context.RequestContext(user=project.project_manager, project=project) reservation = self.controller.run_instances(context, # run instances expects encoded userdata, it is decoded in the get_metadata_call diff --git a/nova/compute/manager.py b/nova/compute/manager.py index c602d013d..523bb8893 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -71,6 +71,7 @@ class ComputeManager(manager.Manager): @exception.wrap_exception def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" + context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) if instance_ref['name'] in self.driver.list_instances(): raise exception.Error("Instance has already been created") @@ -106,6 +107,7 @@ class ComputeManager(manager.Manager): @exception.wrap_exception def terminate_instance(self, context, instance_id): """Terminate an instance on this machine.""" + context = context.elevated() logging.debug("instance %s: terminating", instance_id) instance_ref = self.db.instance_get(context, instance_id) @@ -114,10 +116,6 @@ class ComputeManager(manager.Manager): raise exception.Error('trying to destroy already destroyed' ' instance: %s' % instance_id) - self.db.instance_set_state(context, - instance_id, - power_state.NOSTATE, - 'shutting_down') yield self.driver.destroy(instance_ref) # TODO(ja): should we keep it in a terminated state for a bit? 
diff --git a/nova/compute/manager.py b/nova/compute/manager.py
index c602d013d..523bb8893 100644
--- a/nova/compute/manager.py
+++ b/nova/compute/manager.py
@@ -71,6 +71,7 @@ class ComputeManager(manager.Manager):
     @exception.wrap_exception
     def run_instance(self, context, instance_id, **_kwargs):
         """Launch a new instance with specified options."""
+        context = context.elevated()
         instance_ref = self.db.instance_get(context, instance_id)
         if instance_ref['name'] in self.driver.list_instances():
             raise exception.Error("Instance has already been created")
@@ -106,6 +107,7 @@ class ComputeManager(manager.Manager):
     @exception.wrap_exception
     def terminate_instance(self, context, instance_id):
         """Terminate an instance on this machine."""
+        context = context.elevated()
         logging.debug("instance %s: terminating", instance_id)
 
         instance_ref = self.db.instance_get(context, instance_id)
@@ -114,10 +116,6 @@
             raise exception.Error('trying to destroy already destroyed'
                                   ' instance: %s' % instance_id)
 
-        self.db.instance_set_state(context,
-                                   instance_id,
-                                   power_state.NOSTATE,
-                                   'shutting_down')
         yield self.driver.destroy(instance_ref)
 
         # TODO(ja): should we keep it in a terminated state for a bit?
@@ -127,6 +125,7 @@
     @exception.wrap_exception
     def reboot_instance(self, context, instance_id):
         """Reboot an instance on this server."""
+        context = context.elevated()
         self._update_state(context, instance_id)
         instance_ref = self.db.instance_get(context, instance_id)
@@ -149,6 +148,7 @@
     @exception.wrap_exception
     def get_console_output(self, context, instance_id):
         """Send the console output for an instance."""
+        context = context.elevated()
         logging.debug("instance %s: getting console output", instance_id)
         instance_ref = self.db.instance_get(context, instance_id)
@@ -158,6 +158,7 @@
     @exception.wrap_exception
     def attach_volume(self, context, instance_id, volume_id, mountpoint):
         """Attach a volume to an instance."""
+        context = context.elevated()
         logging.debug("instance %s: attaching volume %s to %s", instance_id,
                       volume_id, mountpoint)
         instance_ref = self.db.instance_get(context, instance_id)
@@ -173,6 +174,7 @@
     @exception.wrap_exception
     def detach_volume(self, context, instance_id, volume_id):
         """Detach a volume from an instance."""
+        context = context.elevated()
         logging.debug("instance %s: detaching volume %s",
                       instance_id,
                       volume_id)
diff --git a/nova/context.py b/nova/context.py
new file mode 100644
index 000000000..f5d3fed08
--- /dev/null
+++ b/nova/context.py
@@ -0,0 +1,114 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+RequestContext: context for requests that persist through all of nova.
+"""
+
+import datetime
+import random
+
+from nova import exception
+from nova import utils
+
+
+class RequestContext(object):
+    def __init__(self, user, project, is_admin=None, read_deleted=False,
+                 remote_address=None, timestamp=None, request_id=None):
+        if hasattr(user, 'id'):
+            self._user = user
+            self.user_id = user.id
+        else:
+            self._user = None
+            self.user_id = user
+        if hasattr(project, 'id'):
+            self._project = project
+            self.project_id = project.id
+        else:
+            self._project = None
+            self.project_id = project
+        if is_admin is None:
+            if self.user_id and self.user:
+                self.is_admin = self.user.is_admin()
+            else:
+                self.is_admin = False
+        else:
+            self.is_admin = is_admin
+        self.read_deleted = read_deleted
+        self.remote_address = remote_address
+        if not timestamp:
+            timestamp = datetime.datetime.utcnow()
+        if isinstance(timestamp, str) or isinstance(timestamp, unicode):
+            timestamp = utils.parse_isotime(timestamp)
+        self.timestamp = timestamp
+        if not request_id:
+            request_id = ''.join(
+                [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
+                 for x in xrange(20)]
+                )
+        self.request_id = request_id
+
+    @property
+    def user(self):
+        # NOTE(vish): Delay import of manager, so that we can import this
+        #             file from manager.
+        from nova.auth import manager
+        if not self._user:
+            try:
+                self._user = manager.AuthManager().get_user(self.user_id)
+            except exception.NotFound:
+                pass
+        return self._user
+
+    @property
+    def project(self):
+        # NOTE(vish): Delay import of manager, so that we can import this
+        #             file from manager.
+        from nova.auth import manager
+        if not self._project:
+            try:
+                self._project = manager.AuthManager().get_project(self.project_id)
+            except exception.NotFound:
+                pass
+        return self._project
+
+    def to_dict(self):
+        return {'user': self.user_id,
+                'project': self.project_id,
+                'is_admin': self.is_admin,
+                'read_deleted': self.read_deleted,
+                'remote_address': self.remote_address,
+                'timestamp': utils.isotime(self.timestamp),
+                'request_id': self.request_id}
+
+    @classmethod
+    def from_dict(cls, values):
+        return cls(**values)
+
+    def elevated(self, read_deleted=False):
+        """Return a version of this context with admin flag set"""
+        return RequestContext(self.user_id,
+                              self.project_id,
+                              True,
+                              read_deleted,
+                              self.remote_address,
+                              self.timestamp,
+                              self.request_id)
+
+
+def get_admin_context(read_deleted=False):
+    return RequestContext(None, None, True, read_deleted)
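The new `nova/context.py` above is the backbone of the whole change: `to_dict()`/`from_dict()` are what let the rpc layer ship a context inside every message and rebuild it on the worker side. A usage sketch against the code as added (string IDs are placeholders; `is_admin=False` is passed so the lazy `user` property never consults `AuthManager`):

```python
from nova import context

ctxt = context.RequestContext('someuser', 'someproject', is_admin=False)

# Round-trip the context the way the rpc layer does when it builds and
# later unpacks a message envelope.
payload = ctxt.to_dict()
rebuilt = context.RequestContext.from_dict(payload)
assert rebuilt.user_id == 'someuser'
assert rebuilt.request_id == ctxt.request_id

# The admin helper used by the command-line tools in this merge:
admin = context.get_admin_context()
assert admin.is_admin and admin.user_id is None
```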
-""" - -import logging -import redis - -from nova import flags - -FLAGS = flags.FLAGS -flags.DEFINE_string('redis_host', '127.0.0.1', - 'Host that redis is running on.') -flags.DEFINE_integer('redis_port', 6379, - 'Port that redis is running on.') -flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away') - - -class Redis(object): - def __init__(self): - if hasattr(self.__class__, '_instance'): - raise Exception('Attempted to instantiate singleton') - - @classmethod - def instance(cls): - if not hasattr(cls, '_instance'): - inst = redis.Redis(host=FLAGS.redis_host, - port=FLAGS.redis_port, - db=FLAGS.redis_db) - cls._instance = inst - return cls._instance - - diff --git a/nova/db/api.py b/nova/db/api.py index a655e6a8a..6dbf3b809 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -258,7 +258,7 @@ def instance_get_all(context): def instance_get_all_by_user(context, user_id): """Get all instances.""" - return IMPL.instance_get_all(context, user_id) + return IMPL.instance_get_all_by_user(context, user_id) def instance_get_all_by_project(context, project_id): """Get all instance belonging to a project.""" @@ -466,13 +466,18 @@ def export_device_count(context): return IMPL.export_device_count(context) -def export_device_create(context, values): - """Create an export_device from the values dictionary.""" - return IMPL.export_device_create(context, values) +def export_device_create_safe(context, values): + """Create an export_device from the values dictionary. + + The device is not returned. If the create violates the unique + constraints because the shelf_id and blade_id already exist, + no exception is raised.""" + return IMPL.export_device_create_safe(context, values) ################### + def auth_destroy_token(context, token): """Destroy an auth token""" return IMPL.auth_destroy_token(context, token) @@ -483,7 +488,7 @@ def auth_get_token(context, token_hash): def auth_create_token(context, token): """Creates a new token""" - return IMPL.auth_create_token(context, token_hash, token) + return IMPL.auth_create_token(context, token) ################### diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 14714d4b1..209d6e51f 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -33,7 +33,6 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload_all from sqlalchemy.sql import exists from sqlalchemy.sql import func -from sqlalchemy.orm.exc import NoResultFound FLAGS = flags.FLAGS @@ -43,6 +42,7 @@ def is_admin_context(context): if not context: warnings.warn('Use of empty request context is deprecated', DeprecationWarning) + raise Exception('die') return True return context.is_admin @@ -51,7 +51,9 @@ def is_user_context(context): """Indicates if the request context is a normal user.""" if not context: return False - if not context.user or not context.project: + if context.is_admin: + return False + if not context.user_id or not context.project_id: return False return True @@ -63,7 +65,7 @@ def authorize_project_context(context, project_id): if is_user_context(context): if not context.project: raise exception.NotAuthorized() - elif context.project.id != project_id: + elif context.project_id != project_id: raise exception.NotAuthorized() @@ -74,7 +76,7 @@ def authorize_user_context(context, user_id): if is_user_context(context): if not context.user: raise exception.NotAuthorized() - elif context.user.id != user_id: + elif context.user_id != user_id: raise exception.NotAuthorized() @@ -324,7 +326,7 @@ def 
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 14714d4b1..209d6e51f 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -33,7 +33,6 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import joinedload_all
 from sqlalchemy.sql import exists
 from sqlalchemy.sql import func
-from sqlalchemy.orm.exc import NoResultFound
 
 FLAGS = flags.FLAGS
 
@@ -43,6 +42,7 @@ def is_admin_context(context):
     if not context:
         warnings.warn('Use of empty request context is deprecated',
                       DeprecationWarning)
+        raise Exception('die')
         return True
     return context.is_admin
 
@@ -51,7 +51,9 @@ def is_user_context(context):
     """Indicates if the request context is a normal user."""
     if not context:
         return False
-    if not context.user or not context.project:
+    if context.is_admin:
+        return False
+    if not context.user_id or not context.project_id:
         return False
     return True
 
@@ -63,7 +65,7 @@ def authorize_project_context(context, project_id):
     if is_user_context(context):
         if not context.project:
             raise exception.NotAuthorized()
-        elif context.project.id != project_id:
+        elif context.project_id != project_id:
             raise exception.NotAuthorized()
 
 
@@ -74,7 +76,7 @@ def authorize_user_context(context, user_id):
     if is_user_context(context):
         if not context.user:
             raise exception.NotAuthorized()
-        elif context.user.id != user_id:
+        elif context.user_id != user_id:
             raise exception.NotAuthorized()
 
 
@@ -324,7 +326,7 @@ def floating_ip_destroy(context, address):
     session = get_session()
     with session.begin():
         # TODO(devcamcar): Ensure address belongs to user.
-        floating_ip_ref = get_floating_ip_by_address(context,
+        floating_ip_ref = floating_ip_get_by_address(context,
                                                      address,
                                                      session=session)
         floating_ip_ref.delete(session=session)
@@ -539,7 +541,7 @@ def instance_create(context, values):
     with session.begin():
         while instance_ref.internal_id == None:
             internal_id = utils.generate_uid(instance_ref.__prefix__)
-            if not instance_internal_id_exists(context, internal_id, 
+            if not instance_internal_id_exists(context, internal_id,
                                                session=session):
                 instance_ref.internal_id = internal_id
         instance_ref.save(session=session)
@@ -581,7 +583,7 @@ def instance_get(context, instance_id, session=None):
     elif is_user_context(context):
         result = session.query(models.Instance
                        ).options(joinedload('security_groups')
-                       ).filter_by(project_id=context.project.id
+                       ).filter_by(project_id=context.project_id
                        ).filter_by(id=instance_id
                        ).filter_by(deleted=False
                        ).first()
@@ -640,7 +642,7 @@ def instance_get_all_by_reservation(context, reservation_id):
         return session.query(models.Instance
                      ).options(joinedload_all('fixed_ip.floating_ips')
                      ).options(joinedload('security_groups')
-                     ).filter_by(project_id=context.project.id
+                     ).filter_by(project_id=context.project_id
                      ).filter_by(reservation_id=reservation_id
                      ).filter_by(deleted=False
                      ).all()
@@ -659,7 +661,7 @@ def instance_get_by_internal_id(context, internal_id):
     elif is_user_context(context):
         result = session.query(models.Instance
                        ).options(joinedload('security_groups')
-                       ).filter_by(project_id=context.project.id
+                       ).filter_by(project_id=context.project_id
                        ).filter_by(internal_id=internal_id
                        ).filter_by(deleted=False
                        ).first()
@@ -897,7 +899,7 @@ def network_get(context, network_id, session=None):
                          ).first()
     elif is_user_context(context):
         result = session.query(models.Network
-                       ).filter_by(project_id=context.project.id
+                       ).filter_by(project_id=context.project_id
                        ).filter_by(id=network_id
                        ).filter_by(deleted=False
                        ).first()
@@ -1023,12 +1025,15 @@ def export_device_count(context):
 
 
 @require_admin_context
-def export_device_create(context, values):
+def export_device_create_safe(context, values):
     export_device_ref = models.ExportDevice()
     for (key, value) in values.iteritems():
         export_device_ref[key] = value
-    export_device_ref.save()
-    return export_device_ref
+    try:
+        export_device_ref.save()
+        return export_device_ref
+    except IntegrityError:
+        return None
 
 
 ###################
@@ -1041,7 +1046,8 @@ def auth_destroy_token(_context, token):
 def auth_get_token(_context, token_hash):
     session = get_session()
     tk = session.query(models.AuthToken
-                ).filter_by(token_hash=token_hash)
+                ).filter_by(token_hash=token_hash
+                ).first()
     if not tk:
         raise exception.NotFound('Token %s does not exist' % token_hash)
     return tk
@@ -1197,7 +1203,7 @@ def volume_get(context, volume_id, session=None):
                          ).first()
     elif is_user_context(context):
         result = session.query(models.Volume
-                       ).filter_by(project_id=context.project.id
+                       ).filter_by(project_id=context.project_id
                        ).filter_by(id=volume_id
                        ).filter_by(deleted=False
                        ).first()
@@ -1237,7 +1243,7 @@ def volume_get_by_ec2_id(context, ec2_id):
                          ).first()
     elif is_user_context(context):
         result = session.query(models.Volume
-                       ).filter_by(project_id=context.project.id
+                       ).filter_by(project_id=context.project_id
                        ).filter_by(ec2_id=ec2_id
                        ).filter_by(deleted=False
                        ).first()
@@ -1484,7 +1490,7 @@ def user_get_by_access_key(context, access_key, session=None):
                      ).first()
     if not result:
-        raise exception.NotFound('No user for id %s' % id)
+        raise exception.NotFound('No user for access key %s' % access_key)
     return result
 
 
@@ -1630,7 +1636,7 @@ def user_remove_project_role(context, user_id, project_id, role):
     with session.begin():
         session.execute('delete from user_project_role_association where ' + \
                         'user_id=:user_id and project_id=:project_id and ' + \
-                        'role=:role', { 'user_id'    : user_id, 
+                        'role=:role', { 'user_id'    : user_id,
                                         'project_id' : project_id,
                                         'role'       : role })
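Because the ids now live on the context itself, the access-check helpers can be exercised without touching the auth backend. A quick sketch of the new semantics, assuming an importable nova checkout (the ids are placeholders):

    from nova import context
    from nova.db.sqlalchemy import api

    user_ctxt = context.RequestContext('u1', 'p1', is_admin=False)
    admin_ctxt = context.get_admin_context()

    # is_admin is checked first, so an elevated context is never treated
    # as a user context even though it still carries user/project ids.
    assert api.is_user_context(user_ctxt)
    assert not api.is_user_context(user_ctxt.elevated())
    assert api.is_admin_context(admin_ctxt)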
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index eed8f0578..a63bca2b0 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -277,6 +277,7 @@ class Quota(BASE, NovaBase):
 class ExportDevice(BASE, NovaBase):
     """Represates a shelf and blade that a volume can be exported on"""
     __tablename__ = 'export_devices'
+    __table_args__ = (schema.UniqueConstraint("shelf_id", "blade_id"), {'mysql_engine': 'InnoDB'})
     id = Column(Integer, primary_key=True)
     shelf_id = Column(Integer)
     blade_id = Column(Integer)
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 2ea1c1aa0..c7080ccd8 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -27,6 +27,7 @@ import math
 import IPy
 from twisted.internet import defer
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
@@ -79,13 +80,14 @@ class NetworkManager(manager.Manager):
     def init_host(self):
         # Set up networking for the projects for which we're already
        # the designated network host.
-        for network in self.db.host_get_networks(None, self.host):
-            self._on_set_network_host(None, network['id'])
+        ctxt = context.get_admin_context()
+        for network in self.db.host_get_networks(ctxt, self.host):
+            self._on_set_network_host(ctxt, network['id'])
 
     def set_network_host(self, context, network_id):
         """Safely sets the host of the network"""
         logging.debug("setting network host")
-        host = self.db.network_set_host(None,
+        host = self.db.network_set_host(context,
                                         network_id,
                                         self.host)
         self._on_set_network_host(context, network_id)
@@ -227,9 +229,9 @@ class FlatManager(NetworkManager):
         #             with a network, or a cluster of computes with a network
         #             and use that network here with a method like
         #             network_get_by_compute_host
-        network_ref = self.db.network_get_by_bridge(None,
+        network_ref = self.db.network_get_by_bridge(context,
                                                     FLAGS.flat_network_bridge)
-        address = self.db.fixed_ip_associate_pool(None,
+        address = self.db.fixed_ip_associate_pool(context.elevated(),
                                                   network_ref['id'],
                                                   instance_id)
         self.db.fixed_ip_update(context, address, {'allocated': True})
@@ -238,7 +240,7 @@ class FlatManager(NetworkManager):
     def deallocate_fixed_ip(self, context, address, *args, **kwargs):
         """Returns a fixed ip to the pool"""
         self.db.fixed_ip_update(context, address, {'allocated': False})
-        self.db.fixed_ip_disassociate(None, address)
+        self.db.fixed_ip_disassociate(context.elevated(), address)
 
     def setup_compute_network(self, context, instance_id):
         """Network is created manually"""
@@ -338,12 +340,16 @@ class VlanManager(NetworkManager):
         # TODO(vish): This should probably be getting project_id from
         #             the instance, but it is another trip to the db.
         #             Perhaps this method should take an instance_ref.
-        network_ref = self.db.project_get_network(context, context.project.id)
+        ctxt = context.elevated()
+        network_ref = self.db.project_get_network(ctxt,
+                                                  context.project_id)
         if kwargs.get('vpn', None):
             address = network_ref['vpn_private_address']
-            self.db.fixed_ip_associate(None, address, instance_id)
+            self.db.fixed_ip_associate(ctxt,
+                                       address,
+                                       instance_id)
         else:
-            address = self.db.fixed_ip_associate_pool(None,
+            address = self.db.fixed_ip_associate_pool(ctxt,
                                                       network_ref['id'],
                                                       instance_id)
         self.db.fixed_ip_update(context, address, {'allocated': True})
@@ -402,7 +408,8 @@ class VlanManager(NetworkManager):
 
     def get_network(self, context):
         """Get the network for the current context"""
-        return self.db.project_get_network(context.elevated(),
+        return self.db.project_get_network(context.elevated(),
                                            context.project_id)
 
     def _on_set_network_host(self, context, network_id):
         """Called when this host becomes the host for a network"""
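The manager-side pattern is consistent throughout: elevate once, use the elevated context for the privileged associate calls, and keep the caller's context (whose project_id is preserved by elevated()) for everything user-visible. A condensed sketch of the shape of the VlanManager hunk above, with the body trimmed:

    def allocate_fixed_ip(self, context, instance_id, **kwargs):
        # Privileged lookup runs with admin rights but still filters on
        # the caller's project_id, which elevated() carries along.
        ctxt = context.elevated()
        network_ref = self.db.project_get_network(ctxt, context.project_id)
        address = self.db.fixed_ip_associate_pool(ctxt, network_ref['id'],
                                                  instance_id)
        # The user-visible update keeps the original, unelevated context.
        self.db.fixed_ip_update(context, address, {'allocated': True})
        return address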
""" diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py index c01b041bb..413b269b7 100644 --- a/nova/objectstore/image.py +++ b/nova/objectstore/image.py @@ -72,7 +72,7 @@ class Image(object): try: return (self.metadata['isPublic'] and readonly) or \ context.user.is_admin() or \ - self.metadata['imageOwnerId'] == context.project.id + self.metadata['imageOwnerId'] == context.project_id except: return False @@ -133,11 +133,11 @@ class Image(object): @type public: bool @param public: determine if this is a public image or private - + @rtype: str @return: a string with the image id """ - + image_type = 'machine' image_id = utils.generate_uid('ami') @@ -162,7 +162,7 @@ class Image(object): 'imageType': image_type, 'state': 'available' } - + if type(kernel) is str and len(kernel) > 0: info['kernelId'] = kernel @@ -203,7 +203,7 @@ class Image(object): info = { 'imageId': image_id, 'imageLocation': image_location, - 'imageOwnerId': context.project.id, + 'imageOwnerId': context.project_id, 'isPublic': False, # FIXME: grab public from manifest 'architecture': 'x86_64', # FIXME: grab architecture from manifest 'imageType' : image_type @@ -249,13 +249,13 @@ class Image(object): @staticmethod def decrypt_image(encrypted_filename, encrypted_key, encrypted_iv, cloud_private_key, decrypted_filename): key, err = utils.execute( - 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, + 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, process_input=encrypted_key, check_exit_code=False) if err: raise exception.Error("Failed to decrypt private key: %s" % err) iv, err = utils.execute( - 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, + 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, process_input=encrypted_iv, check_exit_code=False) if err: diff --git a/nova/quota.py b/nova/quota.py index edbb83111..045051207 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -54,7 +54,8 @@ def get_quota(context, project_id): def allowed_instances(context, num_instances, instance_type): """Check quota and return min(num_instances, allowed_instances)""" - project_id = context.project.id + project_id = context.project_id + context = context.elevated() used_instances, used_cores = db.instance_data_get_for_project(context, project_id) quota = get_quota(context, project_id) @@ -69,7 +70,8 @@ def allowed_instances(context, num_instances, instance_type): def allowed_volumes(context, num_volumes, size): """Check quota and return min(num_volumes, allowed_volumes)""" - project_id = context.project.id + project_id = context.project_id + context = context.elevated() used_volumes, used_gigabytes = db.volume_data_get_for_project(context, project_id) quota = get_quota(context, project_id) @@ -84,7 +86,8 @@ def allowed_volumes(context, num_volumes, size): def allowed_floating_ips(context, num_floating_ips): """Check quota and return min(num_floating_ips, allowed_floating_ips)""" - project_id = context.project.id + project_id = context.project_id + context = context.elevated() used_floating_ips = db.floating_ip_count_by_project(context, project_id) quota = get_quota(context, project_id) allowed_floating_ips = quota['floating_ips'] - used_floating_ips diff --git a/nova/rpc.py b/nova/rpc.py index 447ad3b93..965934205 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -35,7 +35,7 @@ from twisted.internet import task from nova import exception from nova import fakerabbit from nova import flags - +from nova import context FLAGS = flags.FLAGS @@ -161,6 +161,8 @@ class AdapterConsumer(TopicConsumer): 
         LOG.debug('received %s' % (message_data))
         msg_id = message_data.pop('_msg_id', None)
 
+        ctxt = _unpack_context(message_data)
+
         method = message_data.get('method')
         args = message_data.get('args', {})
         message.ack()
@@ -177,7 +179,7 @@ class AdapterConsumer(TopicConsumer):
         node_args = dict((str(k), v) for k, v in args.iteritems())
         # NOTE(vish): magic is fun!
         # pylint: disable-msg=W0142
-        d = defer.maybeDeferred(node_func, **node_args)
+        d = defer.maybeDeferred(node_func, context=ctxt, **node_args)
         if msg_id:
             d.addCallback(lambda rval: msg_reply(msg_id, rval, None))
             d.addErrback(lambda e: msg_reply(msg_id, None, e))
@@ -256,12 +258,35 @@ class RemoteError(exception.Error):
                                                       traceback))
 
 
-def call(topic, msg):
+def _unpack_context(msg):
+    """Unpack context from msg."""
+    context_dict = {}
+    for key in list(msg.keys()):
+        if key.startswith('_context_'):
+            value = msg.pop(key)
+            context_dict[key[9:]] = value
+    LOG.debug('unpacked context: %s', context_dict)
+    return context.RequestContext.from_dict(context_dict)
+
+
+def _pack_context(msg, context):
+    """Pack context into msg.
+
+    Values for message keys need to be less than 255 chars, so we pull
+    context out into a bunch of separate keys. If we want to support
+    more arguments in rabbit messages, we may want to do the same
+    for args at some point.
+    """
+    context = dict([('_context_%s' % key, value)
+                    for (key, value) in context.to_dict().iteritems()])
+    msg.update(context)
+
+
+def call(context, topic, msg):
     """Sends a message on a topic and wait for a response"""
     LOG.debug("Making asynchronous call...")
     msg_id = uuid.uuid4().hex
     msg.update({'_msg_id': msg_id})
     LOG.debug("MSG_ID is %s" % (msg_id))
+    _pack_context(msg, context)
 
     class WaitMessage(object):
 
@@ -291,12 +316,13 @@ def call(topic, msg):
     return wait_msg.result
 
 
-def call_twisted(topic, msg):
+def call_twisted(context, topic, msg):
     """Sends a message on a topic and wait for a response"""
     LOG.debug("Making asynchronous call...")
     msg_id = uuid.uuid4().hex
     msg.update({'_msg_id': msg_id})
     LOG.debug("MSG_ID is %s" % (msg_id))
+    _pack_context(msg, context)
 
     conn = Connection.instance()
     d = defer.Deferred()
@@ -322,9 +348,10 @@ def call_twisted(topic, msg):
     return d
 
 
-def cast(topic, msg):
+def cast(context, topic, msg):
     """Sends a message on a topic without waiting for a response"""
     LOG.debug("Making asynchronous cast...")
+    _pack_context(msg, context)
     conn = Connection.instance()
     publisher = TopicPublisher(connection=conn, topic=topic)
     publisher.send(msg)
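Packing and unpacking are symmetric, so a context survives the trip through a message dict unchanged. A quick sketch using only the helpers above — the method name and args are illustrative, and the underscore-prefixed helpers are module-private:

    from nova import context, rpc

    ctxt = context.RequestContext('fake_user', 'fake_project', is_admin=False)
    msg = {'method': 'reboot_instance', 'args': {'instance_id': 42}}

    # _pack_context flattens the context into '_context_*' keys so every
    # AMQP header value stays under the 255-character limit.
    rpc._pack_context(msg, ctxt)
    assert '_context_request_id' in msg

    # On the consumer side, _unpack_context pops those keys back out and
    # rebuilds an equivalent RequestContext for the manager method.
    rebuilt = rpc._unpack_context(msg)
    assert rebuilt.to_dict() == ctxt.to_dict()
    assert '_context_request_id' not in msg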
""" driver_method = 'schedule_%s' % method + elevated = context.elevated() try: - host = getattr(self.driver, driver_method)(context, *args, **kwargs) + host = getattr(self.driver, driver_method)(elevated, *args, **kwargs) except AttributeError: - host = self.driver.schedule(context, topic, *args, **kwargs) + host = self.driver.schedule(elevated, topic, *args, **kwargs) - kwargs.update({"context": None}) - rpc.cast(db.queue_get_for(context, topic, host), + rpc.cast(context, + db.queue_get_for(context, topic, host), {"method": method, "args": kwargs}) logging.debug("Casting to %s %s for %s", topic, host, method) diff --git a/nova/service.py b/nova/service.py index 115e0ff32..2d7961ab9 100644 --- a/nova/service.py +++ b/nova/service.py @@ -28,6 +28,7 @@ from twisted.internet import defer from twisted.internet import task from twisted.application import service +from nova import context from nova import db from nova import exception from nova import flags @@ -63,20 +64,22 @@ class Service(object, service.Service): **self.saved_kwargs) self.manager.init_host() self.model_disconnected = False + ctxt = context.get_admin_context() try: - service_ref = db.service_get_by_args(None, - self.host, - self.binary) + service_ref = db.service_get_by_args(ctxt, + self.host, + self.binary) self.service_id = service_ref['id'] except exception.NotFound: - self._create_service_ref() + self._create_service_ref(ctxt) - def _create_service_ref(self): - service_ref = db.service_create(None, {'host': self.host, - 'binary': self.binary, - 'topic': self.topic, - 'report_count': 0}) + def _create_service_ref(self, context): + service_ref = db.service_create(context, + {'host': self.host, + 'binary': self.binary, + 'topic': self.topic, + 'report_count': 0}) self.service_id = service_ref['id'] def __getattr__(self, key): @@ -142,31 +145,32 @@ class Service(object, service.Service): service_obj.setServiceParent(application) return application - def kill(self, context=None): + def kill(self): """Destroy the service object in the datastore""" try: - db.service_destroy(context, self.service_id) + db.service_destroy(context.get_admin_context(), self.service_id) except exception.NotFound: logging.warn("Service killed that has no database entry") @defer.inlineCallbacks - def periodic_tasks(self, context=None): + def periodic_tasks(self): """Tasks to be run at a periodic interval""" - yield self.manager.periodic_tasks(context) + yield self.manager.periodic_tasks(context.get_admin_context()) @defer.inlineCallbacks - def report_state(self, context=None): + def report_state(self): """Update the state of this service in the datastore.""" + ctxt = context.get_admin_context() try: try: - service_ref = db.service_get(context, self.service_id) + service_ref = db.service_get(ctxt, self.service_id) except exception.NotFound: logging.debug("The service database object disappeared, " "Recreating it.") - self._create_service_ref() - service_ref = db.service_get(context, self.service_id) + self._create_service_ref(ctxt) + service_ref = db.service_get(ctxt, self.service_id) - db.service_update(context, + db.service_update(ctxt, self.service_id, {'report_count': service_ref['report_count'] + 1}) diff --git a/nova/test.py b/nova/test.py index f6485377d..b9ea36e1d 100644 --- a/nova/test.py +++ b/nova/test.py @@ -32,6 +32,7 @@ from tornado import ioloop from twisted.internet import defer from twisted.trial import unittest +from nova import context from nova import db from nova import fakerabbit from nova import flags @@ -64,8 +65,9 @@ class 
diff --git a/nova/service.py b/nova/service.py
index 115e0ff32..2d7961ab9 100644
--- a/nova/service.py
+++ b/nova/service.py
@@ -28,6 +28,7 @@ from twisted.internet import defer
 from twisted.internet import task
 from twisted.application import service
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
@@ -63,20 +64,22 @@ class Service(object, service.Service):
                                          **self.saved_kwargs)
         self.manager.init_host()
         self.model_disconnected = False
+        ctxt = context.get_admin_context()
         try:
-            service_ref = db.service_get_by_args(None,
-                                                 self.host,
-                                                 self.binary)
+            service_ref = db.service_get_by_args(ctxt,
+                                                 self.host,
+                                                 self.binary)
             self.service_id = service_ref['id']
         except exception.NotFound:
-            self._create_service_ref()
+            self._create_service_ref(ctxt)
 
-    def _create_service_ref(self):
-        service_ref = db.service_create(None, {'host': self.host,
-                                               'binary': self.binary,
-                                               'topic': self.topic,
-                                               'report_count': 0})
+    def _create_service_ref(self, context):
+        service_ref = db.service_create(context,
+                                        {'host': self.host,
+                                         'binary': self.binary,
+                                         'topic': self.topic,
+                                         'report_count': 0})
         self.service_id = service_ref['id']
 
     def __getattr__(self, key):
@@ -142,31 +145,32 @@ class Service(object, service.Service):
         service_obj.setServiceParent(application)
         return application
 
-    def kill(self, context=None):
+    def kill(self):
         """Destroy the service object in the datastore"""
         try:
-            db.service_destroy(context, self.service_id)
+            db.service_destroy(context.get_admin_context(), self.service_id)
         except exception.NotFound:
             logging.warn("Service killed that has no database entry")
 
     @defer.inlineCallbacks
-    def periodic_tasks(self, context=None):
+    def periodic_tasks(self):
         """Tasks to be run at a periodic interval"""
-        yield self.manager.periodic_tasks(context)
+        yield self.manager.periodic_tasks(context.get_admin_context())
 
     @defer.inlineCallbacks
-    def report_state(self, context=None):
+    def report_state(self):
         """Update the state of this service in the datastore."""
+        ctxt = context.get_admin_context()
         try:
             try:
-                service_ref = db.service_get(context, self.service_id)
+                service_ref = db.service_get(ctxt, self.service_id)
             except exception.NotFound:
                 logging.debug("The service database object disappeared, "
                               "Recreating it.")
-                self._create_service_ref()
-                service_ref = db.service_get(context, self.service_id)
+                self._create_service_ref(ctxt)
+                service_ref = db.service_get(ctxt, self.service_id)
 
-            db.service_update(context,
+            db.service_update(ctxt,
                               self.service_id,
                               {'report_count': service_ref['report_count'] + 1})
diff --git a/nova/test.py b/nova/test.py
index f6485377d..b9ea36e1d 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -32,6 +32,7 @@ from tornado import ioloop
 from twisted.internet import defer
 from twisted.trial import unittest
 
+from nova import context
 from nova import db
 from nova import fakerabbit
 from nova import flags
@@ -64,8 +65,9 @@ class TrialTestCase(unittest.TestCase):
         # now that we have some required db setup for the system
         # to work properly.
         self.start = datetime.datetime.utcnow()
-        if db.network_count(None) != 5:
-            network_manager.VlanManager().create_networks(None,
+        ctxt = context.get_admin_context()
+        if db.network_count(ctxt) != 5:
+            network_manager.VlanManager().create_networks(ctxt,
                                                           FLAGS.fixed_range,
                                                           5, 16,
                                                           FLAGS.vlan_start,
@@ -87,8 +89,9 @@ class TrialTestCase(unittest.TestCase):
         self.stubs.SmartUnsetAll()
         self.mox.VerifyAll()
         # NOTE(vish): Clean up any ips associated during the test.
-        db.fixed_ip_disassociate_all_by_timeout(None, FLAGS.host, self.start)
-        db.network_disassociate_all(None)
+        ctxt = context.get_admin_context()
+        db.fixed_ip_disassociate_all_by_timeout(ctxt, FLAGS.host, self.start)
+        db.network_disassociate_all(ctxt)
         rpc.Consumer.attach_to_twisted = self.originalAttach
         for x in self.injected:
             try:
@@ -98,7 +101,7 @@ class TrialTestCase(unittest.TestCase):
         if FLAGS.fake_rabbit:
             fakerabbit.reset_all()
 
-        db.security_group_destroy_all(None)
+        db.security_group_destroy_all(ctxt)
 
         super(TrialTestCase, self).tearDown()
diff --git a/nova/tests/access_unittest.py b/nova/tests/access_unittest.py
index 4b40ffd0a..8167259c4 100644
--- a/nova/tests/access_unittest.py
+++ b/nova/tests/access_unittest.py
@@ -20,6 +20,7 @@ import unittest
 import logging
 import webob
 
+from nova import context
 from nova import exception
 from nova import flags
 from nova import test
@@ -35,44 +36,25 @@ class AccessTestCase(test.TrialTestCase):
     def setUp(self):
         super(AccessTestCase, self).setUp()
         um = manager.AuthManager()
+        self.context = context.get_admin_context()
         # Make test users
-        try:
-            self.testadmin = um.create_user('testadmin')
-        except Exception, err:
-            logging.error(str(err))
-        try:
-            self.testpmsys = um.create_user('testpmsys')
-        except: pass
-        try:
-            self.testnet = um.create_user('testnet')
-        except: pass
-        try:
-            self.testsys = um.create_user('testsys')
-        except: pass
+        self.testadmin = um.create_user('testadmin')
+        self.testpmsys = um.create_user('testpmsys')
+        self.testnet = um.create_user('testnet')
+        self.testsys = um.create_user('testsys')
         # Assign some rules
-        try:
-            um.add_role('testadmin', 'cloudadmin')
-        except: pass
-        try:
-            um.add_role('testpmsys', 'sysadmin')
-        except: pass
-        try:
-            um.add_role('testnet', 'netadmin')
-        except: pass
-        try:
-            um.add_role('testsys', 'sysadmin')
-        except: pass
+        um.add_role('testadmin', 'cloudadmin')
+        um.add_role('testpmsys', 'sysadmin')
+        um.add_role('testnet', 'netadmin')
+        um.add_role('testsys', 'sysadmin')
         # Make a test project
-        try:
-            self.project = um.create_project('testproj', 'testpmsys', 'a test project', ['testpmsys', 'testnet', 'testsys'])
-        except: pass
-        try:
-            self.project.add_role(self.testnet, 'netadmin')
-        except: pass
-        try:
-            self.project.add_role(self.testsys, 'sysadmin')
-        except: pass
+        self.project = um.create_project('testproj',
+                                         'testpmsys',
+                                         'a test project',
+                                         ['testpmsys', 'testnet', 'testsys'])
+        self.project.add_role(self.testnet, 'netadmin')
+        self.project.add_role(self.testsys, 'sysadmin')
         #user is set in each test
 
     def noopWSGIApp(environ, start_response):
         start_response('200 OK', [])
@@ -97,10 +79,8 @@ class AccessTestCase(test.TrialTestCase):
         super(AccessTestCase, self).tearDown()
 
     def response_status(self, user, methodName):
-        context = Context()
-        context.project = self.project
-        context.user = user
-        environ = {'ec2.context' : context,
+        ctxt = context.RequestContext(user, self.project)
+        environ = {'ec2.context' : ctxt,
                    'ec2.controller': 'some string',
                    'ec2.action': methodName}
         req = webob.Request.blank('/', environ)
diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py
index 58022bfde..041e2cf76 100644
--- a/nova/tests/api/openstack/fakes.py
+++ b/nova/tests/api/openstack/fakes.py
@@ -170,6 +170,10 @@ def stub_out_glance(stubs, initial_fixtures=[]):
     stubs.Set(nova.image.services.glance.GlanceImageService, 'delete_all',
               fake_parallax_client.fake_delete_all)
 
+class FakeToken(object):
+    def __init__(self, **kwargs):
+        for k,v in kwargs.iteritems():
+            setattr(self, k, v)
 
 class FakeAuthDatabase(object):
     data = {}
@@ -180,12 +184,13 @@ class FakeAuthDatabase(object):
 
     @staticmethod
     def auth_create_token(context, token):
-        token['created_at'] = datetime.datetime.now()
-        FakeAuthDatabase.data[token['token_hash']] = token
+        fake_token = FakeToken(created_at=datetime.datetime.now(), **token)
+        FakeAuthDatabase.data[fake_token.token_hash] = fake_token
+        return fake_token
 
     @staticmethod
     def auth_destroy_token(context, token):
-        if FakeAuthDatabase.data.has_key(token['token_hash']):
+        if token.token_hash in FakeAuthDatabase.data:
             del FakeAuthDatabase.data['token_hash']
 
@@ -197,7 +202,7 @@ class FakeAuthManager(object):
 
     def get_user(self, uid):
         for k, v in FakeAuthManager.auth_data.iteritems():
-            if v['uid'] == uid:
+            if v.id == uid:
                 return v
         return None
 
diff --git a/nova/tests/api/openstack/test_auth.py b/nova/tests/api/openstack/test_auth.py
index d2ba80243..bbfb0fcea 100644
--- a/nova/tests/api/openstack/test_auth.py
+++ b/nova/tests/api/openstack/test_auth.py
@@ -7,6 +7,7 @@ import webob.dec
 
 import nova.api
 import nova.api.openstack.auth
+import nova.auth.manager
 from nova import auth
 from nova.tests.api.openstack import fakes
 
@@ -26,7 +27,7 @@ class Test(unittest.TestCase):
 
     def test_authorize_user(self):
         f = fakes.FakeAuthManager()
-        f.add_user('derp', { 'uid': 1, 'name':'herp' } )
+        f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None))
 
         req = webob.Request.blank('/v1.0/')
         req.headers['X-Auth-User'] = 'herp'
@@ -40,7 +41,7 @@ class Test(unittest.TestCase):
 
     def test_authorize_token(self):
         f = fakes.FakeAuthManager()
-        f.add_user('derp', { 'uid': 1, 'name':'herp' } )
+        f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None))
 
         req = webob.Request.blank('/v1.0/')
         req.headers['X-Auth-User'] = 'herp'
@@ -71,8 +72,9 @@ class Test(unittest.TestCase):
             self.destroy_called = True
 
         def bad_token(meh, context, token_hash):
-            return { 'token_hash':token_hash,
-                     'created_at':datetime.datetime(1990, 1, 1) }
+            return fakes.FakeToken(
+                token_hash=token_hash,
+                created_at=datetime.datetime(1990, 1, 1))
 
         self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token',
                        destroy_token_mock)
diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py
index 7ab27e000..414db1e11 100644
--- a/nova/tests/api_unittest.py
+++ b/nova/tests/api_unittest.py
@@ -25,6 +25,7 @@ import random
 import StringIO
 import webob
 
+from nova import context
 from nova import flags
 from nova import test
 from nova import api
@@ -131,7 +132,7 @@ class ApiEc2TestCase(test.BaseTestCase):
         user = self.manager.create_user('fake', 'fake', 'fake')
         project = self.manager.create_project('fake', 'fake', 'fake')
         # NOTE(vish): create depends on pool, so call helper directly
-        cloud._gen_key(None, user.id, keyname)
+        cloud._gen_key(context.get_admin_context(), user.id, keyname)
 
         rv = self.ec2.get_all_key_pairs()
         results = [k for k in rv if k.name == keyname]
diff --git a/nova/tests/auth_unittest.py b/nova/tests/auth_unittest.py
index 99f7ab599..97d22d702 100644
--- a/nova/tests/auth_unittest.py
+++ b/nova/tests/auth_unittest.py
@@ -80,7 +80,7 @@ class AuthManagerTestCase(object):
         FLAGS.auth_driver = self.auth_driver
         super(AuthManagerTestCase, self).setUp()
         self.flags(connection_type='fake')
-        self.manager = manager.AuthManager()
+        self.manager = manager.AuthManager(new=True)
 
     def test_create_and_find_user(self):
         with user_generator(self.manager):
@@ -117,7 +117,7 @@ class AuthManagerTestCase(object):
         self.assert_(filter(lambda u: u.id == 'test1', users))
         self.assert_(filter(lambda u: u.id == 'test2', users))
         self.assert_(not filter(lambda u: u.id == 'test3', users))
-    
+
     def test_can_add_and_remove_user_role(self):
         with user_generator(self.manager):
             self.assertFalse(self.manager.has_role('test1', 'itsec'))
@@ -324,6 +324,19 @@ class AuthManagerTestCase(object):
 class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
     auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
 
+    def __init__(self, *args, **kwargs):
+        AuthManagerTestCase.__init__(self)
+        test.TrialTestCase.__init__(self, *args, **kwargs)
+        import nova.auth.fakeldap as fakeldap
+        FLAGS.redis_db = 8
+        if FLAGS.flush_db:
+            logging.info("Flushing redis datastore")
+            try:
+                r = fakeldap.Redis.instance()
+                r.flushdb()
+            except:
+                self.skip = True
+
 
 class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
     auth_driver = 'nova.auth.dbdriver.DbDriver'
diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py
index ff466135d..20099069c 100644
--- a/nova/tests/cloud_unittest.py
+++ b/nova/tests/cloud_unittest.py
@@ -30,6 +30,7 @@ from twisted.internet import defer
 import unittest
 from xml.etree import ElementTree
 
+from nova import context
 from nova import crypto
 from nova import db
 from nova import flags
@@ -38,7 +39,6 @@ from nova import test
 from nova import utils
 from nova.auth import manager
 from nova.compute import power_state
-from nova.api.ec2 import context
 from nova.api.ec2 import cloud
 from nova.objectstore import image
 
@@ -78,7 +78,7 @@ class CloudTestCase(test.TrialTestCase):
         self.manager = manager.AuthManager()
         self.user = self.manager.create_user('admin', 'admin', 'admin', True)
         self.project = self.manager.create_project('proj', 'admin', 'proj')
-        self.context = context.APIRequestContext(user=self.user,
+        self.context = context.RequestContext(user=self.user,
                                                  project=self.project)
 
     def tearDown(self):
@@ -243,34 +243,34 @@ class CloudTestCase(test.TrialTestCase):
         self.assertEqual('', img.metadata['description'])
 
     def test_update_of_instance_display_fields(self):
-        inst = db.instance_create({}, {})
+        inst = db.instance_create(self.context, {})
         ec2_id = cloud.internal_id_to_ec2_id(inst['internal_id'])
         self.cloud.update_instance(self.context, ec2_id,
                                    display_name='c00l 1m4g3')
-        inst = db.instance_get({}, inst['id'])
+        inst = db.instance_get(self.context, inst['id'])
         self.assertEqual('c00l 1m4g3', inst['display_name'])
-        db.instance_destroy({}, inst['id'])
+        db.instance_destroy(self.context, inst['id'])
 
     def test_update_of_instance_wont_update_private_fields(self):
-        inst = db.instance_create({}, {})
+        inst = db.instance_create(self.context, {})
         self.cloud.update_instance(self.context, inst['id'],
                                    mac_address='DE:AD:BE:EF')
-        inst = db.instance_get({}, inst['id'])
+        inst = db.instance_get(self.context, inst['id'])
         self.assertEqual(None, inst['mac_address'])
-        db.instance_destroy({}, inst['id'])
+        db.instance_destroy(self.context, inst['id'])
 
     def test_update_of_volume_display_fields(self):
-        vol = db.volume_create({}, {})
+        vol = db.volume_create(self.context, {})
         self.cloud.update_volume(self.context, vol['id'],
                                  display_name='c00l v0lum3')
-        vol = db.volume_get({}, vol['id'])
+        vol = db.volume_get(self.context, vol['id'])
         self.assertEqual('c00l v0lum3', vol['display_name'])
-        db.volume_destroy({}, vol['id'])
+        db.volume_destroy(self.context, vol['id'])
 
     def test_update_of_volume_wont_update_private_fields(self):
-        vol = db.volume_create({}, {})
+        vol = db.volume_create(self.context, {})
         self.cloud.update_volume(self.context, vol['id'],
                                  mountpoint='/not/here')
-        vol = db.volume_get({}, vol['id'])
+        vol = db.volume_get(self.context, vol['id'])
         self.assertEqual(None, vol['mountpoint'])
-        db.volume_destroy({}, vol['id'])
+        db.volume_destroy(self.context, vol['id'])
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index 5a7f170f3..01e1bcd30 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -24,13 +24,13 @@ import logging
 
 from twisted.internet import defer
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
 from nova import test
 from nova import utils
 from nova.auth import manager
-from nova.api import context
 
 FLAGS = flags.FLAGS
 
@@ -46,7 +46,7 @@ class ComputeTestCase(test.TrialTestCase):
         self.manager = manager.AuthManager()
         self.user = self.manager.create_user('fake', 'fake', 'fake')
         self.project = self.manager.create_project('fake', 'fake', 'fake')
-        self.context = None
+        self.context = context.get_admin_context()
 
     def tearDown(self):  # pylint: disable-msg=C0103
         self.manager.delete_user(self.user)
@@ -73,13 +73,13 @@ class ComputeTestCase(test.TrialTestCase):
 
         yield self.compute.run_instance(self.context, instance_id)
 
-        instances = db.instance_get_all(None)
+        instances = db.instance_get_all(context.get_admin_context())
         logging.info("Running instances: %s", instances)
         self.assertEqual(len(instances), 1)
 
         yield self.compute.terminate_instance(self.context, instance_id)
 
-        instances = db.instance_get_all(None)
+        instances = db.instance_get_all(context.get_admin_context())
         logging.info("After terminating instances: %s", instances)
         self.assertEqual(len(instances), 0)
 
@@ -97,8 +97,7 @@ class ComputeTestCase(test.TrialTestCase):
         self.assertEqual(instance_ref['deleted_at'], None)
         terminate = datetime.datetime.utcnow()
         yield self.compute.terminate_instance(self.context, instance_id)
-        self.context = context.get_admin_context(user=self.user,
-                                                 read_deleted=True)
+        self.context = self.context.elevated(True)
         instance_ref = db.instance_get(self.context, instance_id)
         self.assert_(instance_ref['launched_at'] < terminate)
         self.assert_(instance_ref['deleted_at'] > terminate)
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index 3afb4d19e..e8dd2624f 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -22,13 +22,13 @@ import IPy
 import os
 import logging
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
 from nova import test
 from nova import utils
 from nova.auth import manager
-from nova.api.ec2 import context
 
 FLAGS = flags.FLAGS
 
@@ -49,13 +49,13 @@ class NetworkTestCase(test.TrialTestCase):
         self.user = self.manager.create_user('netuser', 'netuser', 'netuser')
         self.projects = []
         self.network = utils.import_object(FLAGS.network_manager)
-        self.context = context.APIRequestContext(project=None, user=self.user)
+        self.context = context.RequestContext(project=None, user=self.user)
         for i in range(5):
             name = 'project%s' % i
             project = self.manager.create_project(name, 'netuser', name)
             self.projects.append(project)
             # create the necessary network data for the project
-            user_context = context.APIRequestContext(project=self.projects[i],
+            user_context = context.RequestContext(project=self.projects[i],
                                                      user=self.user)
             network_ref = self.network.get_network(user_context)
             self.network.set_network_host(context.get_admin_context(),
@@ -69,8 +69,8 @@ class NetworkTestCase(test.TrialTestCase):
         super(NetworkTestCase, self).tearDown()
         # TODO(termie): this should really be instantiating clean datastores
         #               in between runs, one failure kills all the tests
-        db.instance_destroy(None, self.instance_id)
-        db.instance_destroy(None, self.instance2_id)
+        db.instance_destroy(context.get_admin_context(), self.instance_id)
+        db.instance_destroy(context.get_admin_context(), self.instance2_id)
         for project in self.projects:
             self.manager.delete_project(project)
         self.manager.delete_user(self.user)
@@ -79,7 +79,8 @@ class NetworkTestCase(test.TrialTestCase):
         if not mac:
             mac = utils.generate_mac()
         project = self.projects[project_num]
-        self.context.project = project
+        self.context._project = project
+        self.context.project_id = project.id
         return db.instance_create(self.context,
                                   {'project_id': project.id,
                                    'mac_address': mac})
@@ -88,35 +89,39 @@ class NetworkTestCase(test.TrialTestCase):
         """Create an address in given project num"""
         if instance_id is None:
             instance_id = self.instance_id
-        self.context.project = self.projects[project_num]
+        self.context._project = self.projects[project_num]
+        self.context.project_id = self.projects[project_num].id
         return self.network.allocate_fixed_ip(self.context, instance_id)
 
     def _deallocate_address(self, project_num, address):
-        self.context.project = self.projects[project_num]
+        self.context._project = self.projects[project_num]
+        self.context.project_id = self.projects[project_num].id
         self.network.deallocate_fixed_ip(self.context, address)
 
     def test_public_network_association(self):
         """Makes sure that we can allocaate a public ip"""
         # TODO(vish): better way of adding floating ips
-        self.context.project = self.projects[0]
+        self.context._project = self.projects[0]
+        self.context.project_id = self.projects[0].id
         pubnet = IPy.IP(flags.FLAGS.floating_range)
         address = str(pubnet[0])
         try:
-            db.floating_ip_get_by_address(None, address)
+            db.floating_ip_get_by_address(context.get_admin_context(), address)
         except exception.NotFound:
-            db.floating_ip_create(None, {'address': address,
-                                         'host': FLAGS.host})
+            db.floating_ip_create(context.get_admin_context(),
+                                  {'address': address,
+                                   'host': FLAGS.host})
         float_addr = self.network.allocate_floating_ip(self.context,
                                                        self.projects[0].id)
         fix_addr = self._create_address(0)
         lease_ip(fix_addr)
         self.assertEqual(float_addr, str(pubnet[0]))
         self.network.associate_floating_ip(self.context, float_addr, fix_addr)
-        address = db.instance_get_floating_address(None, self.instance_id)
+        address = db.instance_get_floating_address(context.get_admin_context(), self.instance_id)
         self.assertEqual(address, float_addr)
         self.network.disassociate_floating_ip(self.context, float_addr)
-        address = db.instance_get_floating_address(None, self.instance_id)
+        address = db.instance_get_floating_address(context.get_admin_context(), self.instance_id)
         self.assertEqual(address, None)
         self.network.deallocate_floating_ip(self.context, float_addr)
         self.network.deallocate_fixed_ip(self.context, fix_addr)
@@ -178,7 +183,8 @@ class NetworkTestCase(test.TrialTestCase):
         lease_ip(address)
         lease_ip(address2)
         lease_ip(address3)
-        self.context.project = self.projects[i]
+        self.context._project = self.projects[i]
+        self.context.project_id = self.projects[i].id
         self.assertFalse(is_allocated_in_project(address,
                                                  self.projects[0].id))
         self.assertFalse(is_allocated_in_project(address2,
@@ -192,8 +198,9 @@ class NetworkTestCase(test.TrialTestCase):
         release_ip(address2)
         release_ip(address3)
         for instance_id in instance_ids:
-            db.instance_destroy(None, instance_id)
-        self.context.project = self.projects[0]
+            db.instance_destroy(context.get_admin_context(), instance_id)
+        self.context._project = self.projects[0]
+        self.context.project_id = self.projects[0].id
         self.network.deallocate_fixed_ip(self.context, first)
         self._deallocate_address(0, first)
         release_ip(first)
@@ -208,16 +215,17 @@ class NetworkTestCase(test.TrialTestCase):
     def test_too_many_networks(self):
         """Ensure error is raised if we run out of networks"""
         projects = []
-        networks_left = FLAGS.num_networks - db.network_count(None)
+        networks_left = (FLAGS.num_networks -
+                         db.network_count(context.get_admin_context()))
         for i in range(networks_left):
             project = self.manager.create_project('many%s' % i, self.user)
             projects.append(project)
-            db.project_get_network(None, project.id)
+            db.project_get_network(context.get_admin_context(), project.id)
         project = self.manager.create_project('last', self.user)
         projects.append(project)
         self.assertRaises(db.NoMoreNetworks,
                           db.project_get_network,
-                          None,
+                          context.get_admin_context(),
                           project.id)
         for project in projects:
             self.manager.delete_project(project)
@@ -246,18 +254,18 @@ class NetworkTestCase(test.TrialTestCase):
         There are ips reserved at the bottom and top of the range.
         services (network, gateway, CloudPipe, broadcast)
         """
-        network = db.project_get_network(None, self.projects[0].id)
+        network = db.project_get_network(context.get_admin_context(), self.projects[0].id)
         net_size = flags.FLAGS.network_size
-        total_ips = (db.network_count_available_ips(None, network['id']) +
-                     db.network_count_reserved_ips(None, network['id']) +
-                     db.network_count_allocated_ips(None, network['id']))
+        total_ips = (db.network_count_available_ips(context.get_admin_context(), network['id']) +
+                     db.network_count_reserved_ips(context.get_admin_context(), network['id']) +
+                     db.network_count_allocated_ips(context.get_admin_context(), network['id']))
         self.assertEqual(total_ips, net_size)
 
     def test_too_many_addresses(self):
         """Test for a NoMoreAddresses exception when all fixed ips are used.
""" - network = db.project_get_network(None, self.projects[0].id) - num_available_ips = db.network_count_available_ips(None, + network = db.project_get_network(context.get_admin_context(), self.projects[0].id) + num_available_ips = db.network_count_available_ips(context.get_admin_context(), network['id']) addresses = [] instance_ids = [] @@ -268,7 +276,7 @@ class NetworkTestCase(test.TrialTestCase): addresses.append(address) lease_ip(address) - self.assertEqual(db.network_count_available_ips(None, + self.assertEqual(db.network_count_available_ips(context.get_admin_context(), network['id']), 0) self.assertRaises(db.NoMoreAddresses, self.network.allocate_fixed_ip, @@ -278,17 +286,17 @@ class NetworkTestCase(test.TrialTestCase): for i in range(num_available_ips): self.network.deallocate_fixed_ip(self.context, addresses[i]) release_ip(addresses[i]) - db.instance_destroy(None, instance_ids[i]) - self.assertEqual(db.network_count_available_ips(None, + db.instance_destroy(context.get_admin_context(), instance_ids[i]) + self.assertEqual(db.network_count_available_ips(context.get_admin_context(), network['id']), num_available_ips) def is_allocated_in_project(address, project_id): """Returns true if address is in specified project""" - project_net = db.project_get_network(None, project_id) - network = db.fixed_ip_get_network(None, address) - instance = db.fixed_ip_get_instance(None, address) + project_net = db.project_get_network(context.get_admin_context(), project_id) + network = db.fixed_ip_get_network(context.get_admin_context(), address) + instance = db.fixed_ip_get_instance(context.get_admin_context(), address) # instance exists until release return instance is not None and network['id'] == project_net['id'] @@ -300,8 +308,8 @@ def binpath(script): def lease_ip(private_ip): """Run add command on dhcpbridge""" - network_ref = db.fixed_ip_get_network(None, private_ip) - instance_ref = db.fixed_ip_get_instance(None, private_ip) + network_ref = db.fixed_ip_get_network(context.get_admin_context(), private_ip) + instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'), instance_ref['mac_address'], private_ip) @@ -314,8 +322,8 @@ def lease_ip(private_ip): def release_ip(private_ip): """Run del command on dhcpbridge""" - network_ref = db.fixed_ip_get_network(None, private_ip) - instance_ref = db.fixed_ip_get_instance(None, private_ip) + network_ref = db.fixed_ip_get_network(context.get_admin_context(), private_ip) + instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'), instance_ref['mac_address'], private_ip) diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index 872f1ab23..f096ac6fe 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -32,6 +32,7 @@ from boto.s3.connection import S3Connection, OrdinaryCallingFormat from twisted.internet import reactor, threads, defer from twisted.web import http, server +from nova import context from nova import flags from nova import objectstore from nova import test @@ -70,13 +71,7 @@ class ObjectStoreTestCase(test.TrialTestCase): self.auth_manager.create_user('admin_user', admin=True) self.auth_manager.create_project('proj1', 'user1', 'a proj', ['user1']) self.auth_manager.create_project('proj2', 'user2', 'a proj', ['user2']) - - class Context(object): - """Dummy context for running tests.""" - user = None - project = None - - 
-        self.context = Context()
+        self.context = context.RequestContext('user1', 'proj1')
 
     def tearDown(self):  # pylint: disable-msg=C0103
         """Tear down users and projects."""
@@ -89,8 +84,6 @@ class ObjectStoreTestCase(test.TrialTestCase):
 
     def test_buckets(self):
         """Test the bucket API."""
-        self.context.user = self.auth_manager.get_user('user1')
-        self.context.project = self.auth_manager.get_project('proj1')
         objectstore.bucket.Bucket.create('new_bucket', self.context)
         bucket = objectstore.bucket.Bucket('new_bucket')
 
@@ -98,14 +91,12 @@ class ObjectStoreTestCase(test.TrialTestCase):
         self.assert_(bucket.is_authorized(self.context))
 
         # another user is not authorized
-        self.context.user = self.auth_manager.get_user('user2')
-        self.context.project = self.auth_manager.get_project('proj2')
-        self.assertFalse(bucket.is_authorized(self.context))
+        context2 = context.RequestContext('user2', 'proj2')
+        self.assertFalse(bucket.is_authorized(context2))
 
         # admin is authorized to use bucket
-        self.context.user = self.auth_manager.get_user('admin_user')
-        self.context.project = None
-        self.assertTrue(bucket.is_authorized(self.context))
+        admin_context = context.RequestContext('admin_user', None)
+        self.assertTrue(bucket.is_authorized(admin_context))
 
         # new buckets are empty
         self.assertTrue(bucket.list_keys()['Contents'] == [])
@@ -143,8 +134,6 @@ class ObjectStoreTestCase(test.TrialTestCase):
     def do_test_images(self, manifest_file, expect_kernel_and_ramdisk,
                        image_bucket, image_name):
         "Test the image API."
-        self.context.user = self.auth_manager.get_user('user1')
-        self.context.project = self.auth_manager.get_project('proj1')
 
         # create a bucket for our bundle
         objectstore.bucket.Bucket.create(image_bucket, self.context)
@@ -179,9 +168,8 @@ class ObjectStoreTestCase(test.TrialTestCase):
         self.assertFalse('ramdiskId' in my_img.metadata)
 
         # verify image permissions
-        self.context.user = self.auth_manager.get_user('user2')
-        self.context.project = self.auth_manager.get_project('proj2')
-        self.assertFalse(my_img.is_authorized(self.context))
+        context2 = context.RequestContext('user2', 'proj2')
+        self.assertFalse(my_img.is_authorized(context2))
 
         # change user-editable fields
         my_img.update_user_editable_fields({'display_name': 'my cool image'})
diff --git a/nova/tests/quota_unittest.py b/nova/tests/quota_unittest.py
index 370ccd506..72e44bf52 100644
--- a/nova/tests/quota_unittest.py
+++ b/nova/tests/quota_unittest.py
@@ -18,6 +18,7 @@
 
 import logging
 
+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
@@ -26,7 +27,6 @@
 from nova import test
 from nova import utils
 from nova.auth import manager
 from nova.api.ec2 import cloud
-from nova.api.ec2 import context
 
 FLAGS = flags.FLAGS
 
@@ -48,8 +48,8 @@ class QuotaTestCase(test.TrialTestCase):
         self.user = self.manager.create_user('admin', 'admin', 'admin', True)
         self.project = self.manager.create_project('admin', 'admin', 'admin')
         self.network = utils.import_object(FLAGS.network_manager)
-        self.context = context.APIRequestContext(project=self.project,
-                                                 user=self.user)
+        self.context = context.RequestContext(project=self.project,
+                                              user=self.user)
 
     def tearDown(self):  # pylint: disable-msg=C0103
         manager.AuthManager().delete_project(self.project)
@@ -94,7 +94,7 @@ class QuotaTestCase(test.TrialTestCase):
         for i in range(FLAGS.quota_instances):
             instance_id = self._create_instance()
             instance_ids.append(instance_id)
-        self.assertRaises(cloud.QuotaError, self.cloud.run_instances, 
+        self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
                           self.context,
                           min_count=1,
                           max_count=1,
@@ -106,7 +106,7 @@ class QuotaTestCase(test.TrialTestCase):
         instance_ids = []
         instance_id = self._create_instance(cores=4)
         instance_ids.append(instance_id)
-        self.assertRaises(cloud.QuotaError, self.cloud.run_instances, 
+        self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
                           self.context,
                           min_count=1,
                           max_count=1,
@@ -139,9 +139,9 @@ class QuotaTestCase(test.TrialTestCase):
     def test_too_many_addresses(self):
         address = '192.168.0.100'
         try:
-            db.floating_ip_get_by_address(None, address)
+            db.floating_ip_get_by_address(context.get_admin_context(), address)
         except exception.NotFound:
-            db.floating_ip_create(None, {'address': address,
+            db.floating_ip_create(context.get_admin_context(), {'address': address,
                                          'host': FLAGS.host})
         float_addr = self.network.allocate_floating_ip(self.context,
                                                        self.project.id)
diff --git a/nova/tests/rpc_unittest.py b/nova/tests/rpc_unittest.py
index 9652841f2..5d2bb1046 100644
--- a/nova/tests/rpc_unittest.py
+++ b/nova/tests/rpc_unittest.py
@@ -22,6 +22,7 @@ import logging
 
 from twisted.internet import defer
 
+from nova import context
 from nova import flags
 from nova import rpc
 from nova import test
@@ -40,14 +41,24 @@ class RpcTestCase(test.TrialTestCase):
                                             topic='test',
                                             proxy=self.receiver)
         self.consumer.attach_to_twisted()
+        self.context= context.get_admin_context()
 
     def test_call_succeed(self):
         """Get a value through rpc call"""
         value = 42
-        result = yield rpc.call_twisted('test', {"method": "echo",
+        result = yield rpc.call_twisted(self.context,
+                                        'test', {"method": "echo",
                                                  "args": {"value": value}})
         self.assertEqual(value, result)
 
+    def test_context_passed(self):
+        """Makes sure a context is passed through rpc call"""
+        value = 42
+        result = yield rpc.call_twisted(self.context,
+                                        'test', {"method": "context",
+                                                 "args": {"value": value}})
+        self.assertEqual(self.context.to_dict(), result)
+
     def test_call_exception(self):
         """Test that exception gets passed back properly
 
@@ -56,11 +67,13 @@ class RpcTestCase(test.TrialTestCase):
           to an int in the test.
""" value = 42 - self.assertFailure(rpc.call_twisted('test', {"method": "fail", + self.assertFailure(rpc.call_twisted(self.context, + 'test', {"method": "fail", "args": {"value": value}}), rpc.RemoteError) try: - yield rpc.call_twisted('test', {"method": "fail", + yield rpc.call_twisted(self.context, + 'test', {"method": "fail", "args": {"value": value}}) self.fail("should have thrown rpc.RemoteError") except rpc.RemoteError as exc: @@ -73,12 +86,19 @@ class TestReceiver(object): Uses static methods because we aren't actually storing any state""" @staticmethod - def echo(value): + def echo(context, value): """Simply returns whatever value is sent in""" logging.debug("Received %s", value) return defer.succeed(value) @staticmethod - def fail(value): + def context(context, value): + """Returns dictionary version of context""" + logging.debug("Received %s", context) + return defer.succeed(context.to_dict()) + + @staticmethod + def fail(context, value): """Raises an exception with the value sent in""" raise Exception(value) + diff --git a/nova/tests/scheduler_unittest.py b/nova/tests/scheduler_unittest.py index 80100fc2f..379f8cdc8 100644 --- a/nova/tests/scheduler_unittest.py +++ b/nova/tests/scheduler_unittest.py @@ -19,6 +19,7 @@ Tests For Scheduler """ +from nova import context from nova import db from nova import flags from nova import service @@ -50,22 +51,24 @@ class SchedulerTestCase(test.TrialTestCase): def test_fallback(self): scheduler = manager.SchedulerManager() self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True) - rpc.cast('topic.fallback_host', + ctxt = context.get_admin_context() + rpc.cast(ctxt, + 'topic.fallback_host', {'method': 'noexist', - 'args': {'context': None, - 'num': 7}}) + 'args': {'num': 7}}) self.mox.ReplayAll() - scheduler.noexist(None, 'topic', num=7) + scheduler.noexist(ctxt, 'topic', num=7) def test_named_method(self): scheduler = manager.SchedulerManager() self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True) - rpc.cast('topic.named_host', + ctxt = context.get_admin_context() + rpc.cast(ctxt, + 'topic.named_host', {'method': 'named_method', - 'args': {'context': None, - 'num': 7}}) + 'args': {'num': 7}}) self.mox.ReplayAll() - scheduler.named_method(None, 'topic', num=7) + scheduler.named_method(ctxt, 'topic', num=7) class SimpleDriverTestCase(test.TrialTestCase): @@ -79,11 +82,10 @@ class SimpleDriverTestCase(test.TrialTestCase): volume_driver='nova.volume.driver.FakeAOEDriver', scheduler_driver='nova.scheduler.simple.SimpleScheduler') self.scheduler = manager.SchedulerManager() - self.context = None self.manager = auth_manager.AuthManager() self.user = self.manager.create_user('fake', 'fake', 'fake') self.project = self.manager.create_project('fake', 'fake', 'fake') - self.context = None + self.context = context.get_admin_context() def tearDown(self): # pylint: disable-msg=C0103 self.manager.delete_user(self.user) diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py index 6afeec377..61db52742 100644 --- a/nova/tests/service_unittest.py +++ b/nova/tests/service_unittest.py @@ -24,6 +24,7 @@ import mox from twisted.application.app import startApplication +from nova import context from nova import exception from nova import flags from nova import rpc @@ -47,6 +48,7 @@ class ServiceTestCase(test.BaseTestCase): def setUp(self): # pylint: disable=C0103 super(ServiceTestCase, self).setUp() self.mox.StubOutWithMock(service, 'db') + self.context = context.get_admin_context() def test_create(self): host = 'foo' @@ -90,10 
                           'report_count': 0,
                           'id': 1}
 
-        service.db.service_get_by_args(None,
+        service.db.service_get_by_args(mox.IgnoreArg(),
                                        host,
                                        binary).AndRaise(exception.NotFound())
-        service.db.service_create(None,
+        service.db.service_create(mox.IgnoreArg(),
                                   service_create).AndReturn(service_ref)
         self.mox.ReplayAll()
 
@@ -113,10 +115,10 @@ class ServiceTestCase(test.BaseTestCase):
                        'report_count': 0,
                        'id': 1}
         service.db.__getattr__('report_state')
-        service.db.service_get_by_args(None,
+        service.db.service_get_by_args(self.context,
                                        host,
                                        binary).AndReturn(service_ref)
-        service.db.service_update(None, service_ref['id'],
+        service.db.service_update(self.context, service_ref['id'],
                                   mox.ContainsKeyValue('report_count', 1))
 
         self.mox.ReplayAll()
@@ -135,13 +137,13 @@ class ServiceTestCase(test.BaseTestCase):
                        'id': 1}
 
         service.db.__getattr__('report_state')
-        service.db.service_get_by_args(None,
+        service.db.service_get_by_args(self.context,
                                        host,
                                        binary).AndRaise(exception.NotFound())
-        service.db.service_create(None,
+        service.db.service_create(self.context,
                                   service_create).AndReturn(service_ref)
-        service.db.service_get(None, service_ref['id']).AndReturn(service_ref)
-        service.db.service_update(None, service_ref['id'],
+        service.db.service_get(self.context, service_ref['id']).AndReturn(service_ref)
+        service.db.service_update(self.context, service_ref['id'],
                                   mox.ContainsKeyValue('report_count', 1))
 
         self.mox.ReplayAll()
@@ -157,7 +159,7 @@ class ServiceTestCase(test.BaseTestCase):
                        'id': 1}
 
         service.db.__getattr__('report_state')
-        service.db.service_get_by_args(None,
+        service.db.service_get_by_args(self.context,
                                        host,
                                        binary).AndRaise(Exception())
 
@@ -176,10 +178,10 @@ class ServiceTestCase(test.BaseTestCase):
                        'id': 1}
 
         service.db.__getattr__('report_state')
-        service.db.service_get_by_args(None,
+        service.db.service_get_by_args(self.context,
                                        host,
                                        binary).AndReturn(service_ref)
-        service.db.service_update(None, service_ref['id'],
+        service.db.service_update(self.context, service_ref['id'],
                                   mox.ContainsKeyValue('report_count', 1))
 
         self.mox.ReplayAll()
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index edcdba425..76af5cabd 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -17,11 +17,11 @@
 from xml.etree.ElementTree import fromstring as xml_to_tree
 from xml.dom.minidom import parseString as xml_to_dom
 
+from nova import context
 from nova import db
 from nova import flags
 from nova import test
 from nova import utils
-from nova.api import context
 from nova.api.ec2 import cloud
 from nova.auth import manager
 from nova.virt import libvirt_conn
@@ -51,9 +51,9 @@ class LibvirtConnTestCase(test.TrialTestCase):
                     'bridge'        : 'br101',
                     'instance_type' : 'm1.small'}
 
-        instance_ref = db.instance_create(None, instance)
-        user_context = context.APIRequestContext(project=self.project,
-                                                 user=self.user)
+        user_context = context.RequestContext(project=self.project,
+                                              user=self.user)
+        instance_ref = db.instance_create(user_context, instance)
         network_ref = self.network.get_network(user_context)
         self.network.set_network_host(context.get_admin_context(),
                                       network_ref['id'])
@@ -61,9 +61,10 @@ class LibvirtConnTestCase(test.TrialTestCase):
 
         fixed_ip = { 'address'    : ip,
                      'network_id' : network_ref['id'] }
 
-        fixed_ip_ref = db.fixed_ip_create(None, fixed_ip)
-        db.fixed_ip_update(None, ip, { 'allocated'   : True,
-                                       'instance_id' : instance_ref['id'] })
+        ctxt = context.get_admin_context()
+        fixed_ip_ref = db.fixed_ip_create(ctxt, fixed_ip)
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index edcdba425..76af5cabd 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -17,11 +17,11 @@ from xml.etree.ElementTree import fromstring as xml_to_tree
 from xml.dom.minidom import parseString as xml_to_dom

+from nova import context
 from nova import db
 from nova import flags
 from nova import test
 from nova import utils
-from nova.api import context
 from nova.api.ec2 import cloud
 from nova.auth import manager
 from nova.virt import libvirt_conn
@@ -51,9 +51,9 @@ class LibvirtConnTestCase(test.TrialTestCase):
                      'bridge' : 'br101',
                      'instance_type' : 'm1.small'}

-        instance_ref = db.instance_create(None, instance)
-        user_context = context.APIRequestContext(project=self.project,
-                                                 user=self.user)
+        user_context = context.RequestContext(project=self.project,
+                                              user=self.user)
+        instance_ref = db.instance_create(user_context, instance)
         network_ref = self.network.get_network(user_context)
         self.network.set_network_host(context.get_admin_context(),
                                       network_ref['id'])
@@ -61,9 +61,10 @@ class LibvirtConnTestCase(test.TrialTestCase):
         fixed_ip = { 'address' : ip,
                      'network_id' : network_ref['id'] }

-        fixed_ip_ref = db.fixed_ip_create(None, fixed_ip)
-        db.fixed_ip_update(None, ip, { 'allocated' : True,
-                                       'instance_id' : instance_ref['id'] })
+        ctxt = context.get_admin_context()
+        fixed_ip_ref = db.fixed_ip_create(ctxt, fixed_ip)
+        db.fixed_ip_update(ctxt, ip, {'allocated': True,
+                                      'instance_id': instance_ref['id'] })

         type_uri_map = { 'qemu' : ('qemu:///system',
                          [(lambda t: t.find('.').get('type'), 'qemu'),
@@ -132,7 +133,7 @@ class NWFilterTestCase(test.TrialTestCase):
         self.manager = manager.AuthManager()
         self.user = self.manager.create_user('fake', 'fake', 'fake',
                                              admin=True)
         self.project = self.manager.create_project('fake', 'fake', 'fake')
-        self.context = context.APIRequestContext(self.user, self.project)
+        self.context = context.RequestContext(self.user, self.project)

         self.fake_libvirt_connection = Mock()
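The import move above retires `nova.api.context.APIRequestContext` in favor of a top-level `nova.context.RequestContext`. The real module is not part of this diff, so the following is only a rough sketch of the interface these tests lean on; the fields, keyword names, and defaults shown here are guesses:

```python
# Illustrative stand-in for the nova.context API used by these tests.
class RequestContext(object):
    def __init__(self, user=None, project=None, is_admin=False):
        self.user = user
        self.project = project
        self.is_admin = is_admin

    def to_dict(self):
        # Serializable form, e.g. for shipping over rpc as seen earlier.
        return {'user': self.user,
                'project': self.project,
                'is_admin': self.is_admin}


def get_admin_context():
    """Unscoped context for trusted internal work (fixtures, daemons)."""
    return RequestContext(is_admin=True)
```

The test setUp then reads as a statement of privilege: `get_admin_context()` for unrestricted fixture manipulation, `RequestContext(user=..., project=...)` for calls that should be scoped like a real API request.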
diff --git a/nova/tests/volume_unittest.py b/nova/tests/volume_unittest.py
index 1d665b502..8e2fa11c1 100644
--- a/nova/tests/volume_unittest.py
+++ b/nova/tests/volume_unittest.py
@@ -22,6 +22,7 @@ import logging

 from twisted.internet import defer

+from nova import context
 from nova import exception
 from nova import db
 from nova import flags
@@ -39,7 +40,7 @@ class VolumeTestCase(test.TrialTestCase):
         self.compute = utils.import_object(FLAGS.compute_manager)
         self.flags(connection_type='fake')
         self.volume = utils.import_object(FLAGS.volume_manager)
-        self.context = None
+        self.context = context.get_admin_context()

     @staticmethod
     def _create_volume(size='0'):
@@ -51,19 +52,19 @@ class VolumeTestCase(test.TrialTestCase):
         vol['availability_zone'] = FLAGS.storage_availability_zone
         vol['status'] = "creating"
         vol['attach_status'] = "detached"
-        return db.volume_create(None, vol)['id']
+        return db.volume_create(context.get_admin_context(), vol)['id']

     @defer.inlineCallbacks
     def test_create_delete_volume(self):
         """Test volume can be created and deleted"""
         volume_id = self._create_volume()
         yield self.volume.create_volume(self.context, volume_id)
-        self.assertEqual(volume_id, db.volume_get(None, volume_id).id)
+        self.assertEqual(volume_id, db.volume_get(context.get_admin_context(), volume_id).id)

         yield self.volume.delete_volume(self.context, volume_id)
         self.assertRaises(exception.NotFound,
                           db.volume_get,
-                          None,
+                          self.context,
                           volume_id)

     @defer.inlineCallbacks
@@ -92,7 +93,7 @@ class VolumeTestCase(test.TrialTestCase):
         self.assertFailure(self.volume.create_volume(self.context,
                                                      volume_id),
                            db.NoMoreBlades)
-        db.volume_destroy(None, volume_id)
+        db.volume_destroy(context.get_admin_context(), volume_id)
         for volume_id in vols:
             yield self.volume.delete_volume(self.context, volume_id)

@@ -113,12 +114,13 @@ class VolumeTestCase(test.TrialTestCase):
         volume_id = self._create_volume()
         yield self.volume.create_volume(self.context, volume_id)
         if FLAGS.fake_tests:
-            db.volume_attached(None, volume_id, instance_id, mountpoint)
+            db.volume_attached(self.context, volume_id, instance_id, mountpoint)
         else:
-            yield self.compute.attach_volume(instance_id,
+            yield self.compute.attach_volume(self.context,
+                                             instance_id,
                                              volume_id,
                                              mountpoint)
-        vol = db.volume_get(None, volume_id)
+        vol = db.volume_get(context.get_admin_context(), volume_id)
         self.assertEqual(vol['status'], "in-use")
         self.assertEqual(vol['attach_status'], "attached")
         self.assertEqual(vol['mountpoint'], mountpoint)
@@ -128,17 +130,18 @@ class VolumeTestCase(test.TrialTestCase):
         self.assertFailure(self.volume.delete_volume(self.context, volume_id),
                            exception.Error)
         if FLAGS.fake_tests:
-            db.volume_detached(None, volume_id)
+            db.volume_detached(self.context, volume_id)
         else:
-            yield self.compute.detach_volume(instance_id,
+            yield self.compute.detach_volume(self.context,
+                                             instance_id,
                                              volume_id)
-        vol = db.volume_get(None, volume_id)
+        vol = db.volume_get(self.context, volume_id)
         self.assertEqual(vol['status'], "available")

         yield self.volume.delete_volume(self.context, volume_id)
         self.assertRaises(exception.Error,
                           db.volume_get,
-                          None,
+                          self.context,
                           volume_id)
         db.instance_destroy(self.context, instance_id)

@@ -151,7 +154,7 @@
         def _check(volume_id):
             """Make sure blades aren't duplicated"""
             volume_ids.append(volume_id)
-            (shelf_id, blade_id) = db.volume_get_shelf_and_blade(None,
+            (shelf_id, blade_id) = db.volume_get_shelf_and_blade(context.get_admin_context(),
                                                                  volume_id)
             shelf_blade = '%s.%s' % (shelf_id, blade_id)
             self.assert_(shelf_blade not in shelf_blades)
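These volume tests lean on `twisted.trial`'s `assertFailure` to assert that a Deferred errbacks with a given exception type (`db.NoMoreBlades` above). A runnable miniature of the same pattern, with a stand-in `create_volume` in place of the real manager call:

```python
from twisted.internet import defer
from twisted.trial import unittest


def create_volume(too_many):
    """Stand-in for the manager call; fails like db.NoMoreBlades would."""
    if too_many:
        return defer.fail(ValueError('no more blades'))
    return defer.succeed('vol-1')


class FailureTestCase(unittest.TestCase):
    def test_too_many(self):
        # assertFailure returns a Deferred that succeeds only if
        # create_volume errbacks with one of the listed exception types.
        return self.assertFailure(create_volume(True), ValueError)
```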
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index ce97ef1eb..d8d36ff65 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -30,6 +30,7 @@ from twisted.internet import defer
 from twisted.internet import task
 from twisted.internet import threads

+from nova import context
 from nova import db
 from nova import exception
 from nova import flags
@@ -152,12 +153,13 @@ class LibvirtConnection(object):
         def _wait_for_shutdown():
             try:
                 state = self.get_info(instance['name'])['state']
-                db.instance_set_state(None, instance['id'], state)
+                db.instance_set_state(context.get_admin_context(),
+                                      instance['id'], state)
                 if state == power_state.SHUTDOWN:
                     timer.stop()
                     d.callback(None)
             except Exception:
-                db.instance_set_state(None,
+                db.instance_set_state(context.get_admin_context(),
                                       instance['id'],
                                       power_state.SHUTDOWN)
                 timer.stop()
@@ -202,14 +204,15 @@ class LibvirtConnection(object):
         def _wait_for_reboot():
             try:
                 state = self.get_info(instance['name'])['state']
-                db.instance_set_state(None, instance['id'], state)
+                db.instance_set_state(context.get_admin_context(),
+                                      instance['id'], state)
                 if state == power_state.RUNNING:
                     logging.debug('instance %s: rebooted', instance['name'])
                     timer.stop()
                     d.callback(None)
             except Exception, exn:
                 logging.error('_wait_for_reboot failed: %s', exn)
-                db.instance_set_state(None,
+                db.instance_set_state(context.get_admin_context(),
                                       instance['id'],
                                       power_state.SHUTDOWN)
                 timer.stop()
@@ -222,7 +225,7 @@ class LibvirtConnection(object):
     @exception.wrap_exception
     def spawn(self, instance):
         xml = self.to_xml(instance)
-        db.instance_set_state(None,
+        db.instance_set_state(context.get_admin_context(),
                               instance['id'],
                               power_state.NOSTATE,
                               'launching')
@@ -238,7 +241,8 @@ class LibvirtConnection(object):
         def _wait_for_boot():
             try:
                 state = self.get_info(instance['name'])['state']
-                db.instance_set_state(None, instance['id'], state)
+                db.instance_set_state(context.get_admin_context(),
+                                      instance['id'], state)
                 if state == power_state.RUNNING:
                     logging.debug('instance %s: booted', instance['name'])
                     timer.stop()
@@ -246,7 +250,7 @@ class LibvirtConnection(object):
             except:
                 logging.exception('instance %s: failed to boot',
                                   instance['name'])
-                db.instance_set_state(None,
+                db.instance_set_state(context.get_admin_context(),
                                       instance['id'],
                                       power_state.SHUTDOWN)
                 timer.stop()
@@ -272,7 +276,7 @@ class LibvirtConnection(object):
         fp = open(fpath, 'a+')
         fp.write(data)
         return fpath
-    
+
     def _dump_file(self, fpath):
         fp = open(fpath, 'r+')
         contents = fp.read()
@@ -333,9 +337,11 @@ class LibvirtConnection(object):
             key = str(inst['key_data'])
         net = None
-        network_ref = db.network_get_by_instance(None, inst['id'])
+        network_ref = db.network_get_by_instance(context.get_admin_context(),
+                                                 inst['id'])
         if network_ref['injected']:
-            address = db.instance_get_fixed_address(None, inst['id'])
+            address = db.instance_get_fixed_address(context.get_admin_context(),
+                                                    inst['id'])
             with open(FLAGS.injected_network_template) as f:
                 net = f.read() % {'address': address,
                                   'netmask': network_ref['netmask'],
@@ -366,11 +372,12 @@ class LibvirtConnection(object):
     def to_xml(self, instance):
         # TODO(termie): cache?
         logging.debug('instance %s: starting toXML method', instance['name'])
-        network = db.project_get_network(None,
+        network = db.project_get_network(context.get_admin_context(),
                                          instance['project_id'])
         # FIXME(vish): stick this in db
         instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']]
-        ip_address = db.instance_get_fixed_address({}, instance['id'])
+        ip_address = db.instance_get_fixed_address(context.get_admin_context(),
+                                                   instance['id'])
         # Assume that the gateway also acts as the dhcp server.
         dhcp_server = network['gateway']
         xml_info = {'type': FLAGS.libvirt_type,
@@ -642,7 +649,8 @@ class NWFilterFirewall(object):
                    ) % instance['name']

         if FLAGS.allow_project_net_traffic:
-            network_ref = db.project_get_network({}, instance['project_id'])
+            network_ref = db.project_get_network(context.get_admin_context(),
+                                                 instance['project_id'])
             net, mask = self._get_net_and_mask(network_ref['cidr'])
             project_filter = self.nova_project_filter(instance['project_id'],
                                                       net, mask)
@@ -667,7 +675,8 @@ class NWFilterFirewall(object):

     def security_group_to_nwfilter_xml(self, security_group_id):
-        security_group = db.security_group_get({}, security_group_id)
+        security_group = db.security_group_get(context.get_admin_context(),
+                                               security_group_id)
         rule_xml = ""
         for rule in security_group.rules:
             rule_xml += "<rule action='accept' direction='in' priority='300'>"
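`_wait_for_shutdown`, `_wait_for_reboot`, and `_wait_for_boot` above all follow the same shape: a `twisted.internet.task.LoopingCall` polls domain state, records it in the db, and fires a Deferred once the target state is reached. Distilled into a standalone helper (the helper itself is hypothetical, not Nova API; only the LoopingCall pattern is taken from the code above):

```python
from twisted.internet import defer, task


def wait_for_state(get_state, target, interval=2.0):
    """Poll get_state() until it returns `target`, then fire the Deferred.

    get_state is any zero-argument callable; an exception from it stops
    the timer and errbacks the Deferred, mirroring the handlers above.
    """
    d = defer.Deferred()

    def _poll():
        try:
            if get_state() == target:
                timer.stop()
                d.callback(None)
        except Exception:
            timer.stop()
            d.errback()  # wraps the current exception in a Failure

    timer = task.LoopingCall(_poll)
    timer.start(interval)
    return d
```

The repeated `context.get_admin_context()` calls inside the pollers also explain the service-test change earlier: each call builds a fresh context object, so a mock expectation can only match it with `mox.IgnoreArg()`.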
+ "Try number %s", tries) + yield self._execute("sleep %s" % tries ** 2) + + + @defer.inlineCallbacks def create_volume(self, volume_name, size): """Creates a logical volume""" # NOTE(vish): makes sure that the volume group exists @@ -49,22 +71,22 @@ class AOEDriver(object): sizestr = '100M' else: sizestr = '%sG' % size - yield self._execute( - "sudo lvcreate -L %s -n %s %s" % (sizestr, - volume_name, - FLAGS.volume_group)) + yield self._try_execute("sudo lvcreate -L %s -n %s %s" % + (sizestr, + volume_name, + FLAGS.volume_group)) @defer.inlineCallbacks def delete_volume(self, volume_name): """Deletes a logical volume""" - yield self._execute( - "sudo lvremove -f %s/%s" % (FLAGS.volume_group, - volume_name)) + yield self._try_execute("sudo lvremove -f %s/%s" % + (FLAGS.volume_group, + volume_name)) @defer.inlineCallbacks def create_export(self, volume_name, shelf_id, blade_id): """Creates an export for a logical volume""" - yield self._execute( + yield self._try_execute( "sudo vblade-persist setup %s %s %s /dev/%s/%s" % (shelf_id, blade_id, @@ -81,16 +103,22 @@ class AOEDriver(object): @defer.inlineCallbacks def remove_export(self, _volume_name, shelf_id, blade_id): """Removes an export for a logical volume""" - yield self._execute( - "sudo vblade-persist stop %s %s" % (shelf_id, blade_id)) - yield self._execute( - "sudo vblade-persist destroy %s %s" % (shelf_id, blade_id)) + yield self._try_execute("sudo vblade-persist stop %s %s" % + (shelf_id, blade_id)) + yield self._try_execute("sudo vblade-persist destroy %s %s" % + (shelf_id, blade_id)) @defer.inlineCallbacks def ensure_exports(self): """Runs all existing exports""" - # NOTE(ja): wait for blades to appear - yield self._execute("sleep 5") + # NOTE(vish): The standard _try_execute does not work here + # because these methods throw errors if other + # volumes on this host are in the process of + # being created. The good news is the command + # still works for the other volumes, so we + # just wait a bit for the current volume to + # be ready and ignore any errors. 
+ yield self._execute("sleep 2") yield self._execute("sudo vblade-persist auto all", check_exit_code=False) yield self._execute("sudo vblade-persist start all", diff --git a/nova/volume/manager.py b/nova/volume/manager.py index 8508f27b2..2874459f9 100644 --- a/nova/volume/manager.py +++ b/nova/volume/manager.py @@ -62,11 +62,12 @@ class AOEManager(manager.Manager): for shelf_id in xrange(FLAGS.num_shelves): for blade_id in xrange(FLAGS.blades_per_shelf): dev = {'shelf_id': shelf_id, 'blade_id': blade_id} - self.db.export_device_create(context, dev) + self.db.export_device_create_safe(context, dev) @defer.inlineCallbacks def create_volume(self, context, volume_id): """Creates and exports the volume""" + context = context.elevated() logging.info("volume %s: creating", volume_id) volume_ref = self.db.volume_get(context, volume_id) @@ -95,20 +96,22 @@ class AOEManager(manager.Manager): yield self.driver.ensure_exports() now = datetime.datetime.utcnow() - self.db.volume_update(context, volume_id, {'status': 'available', - 'launched_at': now}) + self.db.volume_update(context, + volume_ref['id'], {'status': 'available', + 'launched_at': now}) logging.debug("volume %s: created successfully", volume_id) defer.returnValue(volume_id) @defer.inlineCallbacks def delete_volume(self, context, volume_id): """Deletes and unexports volume""" - logging.debug("Deleting volume with id of: %s", volume_id) + context = context.elevated() volume_ref = self.db.volume_get(context, volume_id) if volume_ref['attach_status'] == "attached": raise exception.Error("Volume is still attached") if volume_ref['host'] != self.host: raise exception.Error("Volume is not local to this node") + logging.debug("Deleting volume with id of: %s", volume_id) shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context, volume_id) yield self.driver.remove_export(volume_ref['ec2_id'], @@ -124,6 +127,7 @@ class AOEManager(manager.Manager): Returns path to device. """ + context = context.elevated() volume_ref = self.db.volume_get(context, volume_id) yield self.driver.discover_volume(volume_ref['ec2_id']) shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context, diff --git a/run_tests.py b/run_tests.py index 0b27ec6cf..b1a3f1d66 100644 --- a/run_tests.py +++ b/run_tests.py @@ -45,7 +45,6 @@ import sys from twisted.scripts import trial as trial_script -from nova import datastore from nova import flags from nova import twistd @@ -86,12 +85,6 @@ if __name__ == '__main__': # TODO(termie): these should make a call instead of doing work on import if FLAGS.fake_tests: from nova.tests.fake_flags import * - # use db 8 for fake tests - FLAGS.redis_db = 8 - if FLAGS.flush_db: - logging.info("Flushing redis datastore") - r = datastore.Redis.instance() - r.flushdb() else: from nova.tests.real_flags import * diff --git a/run_tests.sh b/run_tests.sh index ec727d094..a11dcd7cc 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -55,7 +55,7 @@ else else echo -e "No virtual environment found...create one? (Y/n) \c" read use_ve - if [ "x$use_ve" = "xY" ]; then + if [ "x$use_ve" = "xY" -o "x$use_ve" = "x" -o "x$use_ve" = "xy" ]; then # Install the virtualenv and run the test suite in it python tools/install_venv.py else |
