diff options
| author | Ryan Lane <rlane@wikimedia.org> | 2011-03-03 23:04:11 +0000 |
|---|---|---|
| committer | Ryan Lane <rlane@wikimedia.org> | 2011-03-03 23:04:11 +0000 |
| commit | df3a65793ec7bb9d85d2a3da47fbbfb9e97d03d4 (patch) | |
| tree | aab853348a2df9f5cdb26a77dd836d4f2083f119 /nova | |
| parent | 4c50ddee48971c76f0f6252295747b89de5d3697 (diff) | |
| parent | 7ca1669603132e3afd14606dda3f95ccbce08a41 (diff) | |
| download | nova-df3a65793ec7bb9d85d2a3da47fbbfb9e97d03d4.tar.gz nova-df3a65793ec7bb9d85d2a3da47fbbfb9e97d03d4.tar.xz nova-df3a65793ec7bb9d85d2a3da47fbbfb9e97d03d4.zip | |
Merge from trunk
Diffstat (limited to 'nova')
122 files changed, 8194 insertions, 1843 deletions
diff --git a/nova/__init__.py b/nova/__init__.py index 8745617bc..256db55a9 100644 --- a/nova/__init__.py +++ b/nova/__init__.py @@ -30,5 +30,3 @@ .. moduleauthor:: Manish Singh <yosh@gimp.org> .. moduleauthor:: Andy Smith <andy@anarkystic.com> """ - -from exception import * diff --git a/nova/adminclient.py b/nova/adminclient.py index c614b274c..fc3c5c5fe 100644 --- a/nova/adminclient.py +++ b/nova/adminclient.py @@ -23,6 +23,8 @@ import base64 import boto import boto.exception import httplib +import re +import string from boto.ec2.regioninfo import RegionInfo @@ -165,19 +167,20 @@ class HostInfo(object): **Fields Include** - * Disk stats - * Running Instances - * Memory stats - * CPU stats - * Network address info - * Firewall info - * Bridge and devices - + * Hostname + * Compute service status + * Volume service status + * Instance count + * Volume count """ def __init__(self, connection=None): self.connection = connection self.hostname = None + self.compute = None + self.volume = None + self.instance_count = 0 + self.volume_count = 0 def __repr__(self): return 'Host:%s' % self.hostname @@ -188,7 +191,39 @@ class HostInfo(object): # this is needed by the sax parser, so ignore the ugly name def endElement(self, name, value, connection): - setattr(self, name, value) + fixed_name = string.lower(re.sub(r'([A-Z])', r'_\1', name)) + setattr(self, fixed_name, value) + + +class Vpn(object): + """ + Information about a Vpn, as parsed through SAX + + **Fields Include** + + * instance_id + * project_id + * public_ip + * public_port + * created_at + * internal_ip + * state + """ + + def __init__(self, connection=None): + self.connection = connection + self.instance_id = None + self.project_id = None + + def __repr__(self): + return 'Vpn:%s:%s' % (self.project_id, self.instance_id) + + def startElement(self, name, attrs, connection): + return None + + def endElement(self, name, value, connection): + fixed_name = string.lower(re.sub(r'([A-Z])', r'_\1', name)) + setattr(self, fixed_name, value) class InstanceType(object): @@ -422,6 +457,16 @@ class NovaAdminClient(object): zip = self.apiconn.get_object('GenerateX509ForUser', params, UserInfo) return zip.file + def start_vpn(self, project): + """ + Starts the vpn for a user + """ + return self.apiconn.get_object('StartVpn', {'Project': project}, Vpn) + + def get_vpns(self): + """Return a list of vpn with project name""" + return self.apiconn.get_list('DescribeVpns', {}, [('item', Vpn)]) + def get_hosts(self): return self.apiconn.get_list('DescribeHosts', {}, [('item', HostInfo)]) diff --git a/nova/api/direct.py b/nova/api/direct.py index 208b6d086..dfca250e0 100644 --- a/nova/api/direct.py +++ b/nova/api/direct.py @@ -187,7 +187,7 @@ class ServiceWrapper(wsgi.Controller): def __init__(self, service_handle): self.service_handle = service_handle - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): arg_dict = req.environ['wsgiorg.routing_args'][1] action = arg_dict['action'] @@ -206,7 +206,7 @@ class ServiceWrapper(wsgi.Controller): params = dict([(str(k), v) for (k, v) in params.iteritems()]) result = method(context, **params) if type(result) is dict or type(result) is list: - return self._serialize(result, req) + return self._serialize(result, req.best_match_content_type()) else: return result @@ -218,7 +218,7 @@ class Proxy(object): self.prefix = prefix def __do_request(self, path, context, **kwargs): - req = webob.Request.blank(path) + req = wsgi.Request.blank(path) req.method = 'POST' req.body = urllib.urlencode({'json': utils.dumps(kwargs)}) req.environ['openstack.context'] = context diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index ddcdc673c..fccebca5d 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -20,8 +20,6 @@ Starting point for routing EC2 requests. """ -import datetime -import routes import webob import webob.dec import webob.exc @@ -55,25 +53,22 @@ flags.DEFINE_list('lockout_memcached_servers', None, class RequestLogging(wsgi.Middleware): """Access-Log akin logging for all EC2 API requests.""" - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): + start = utils.utcnow() rv = req.get_response(self.application) - self.log_request_completion(rv, req) + self.log_request_completion(rv, req, start) return rv - def log_request_completion(self, response, request): + def log_request_completion(self, response, request, start): controller = request.environ.get('ec2.controller', None) if controller: controller = controller.__class__.__name__ action = request.environ.get('ec2.action', None) ctxt = request.environ.get('ec2.context', None) - seconds = 'X' - microseconds = 'X' - if ctxt: - delta = datetime.datetime.utcnow() - \ - ctxt.timestamp - seconds = delta.seconds - microseconds = delta.microseconds + delta = utils.utcnow() - start + seconds = delta.seconds + microseconds = delta.microseconds LOG.info( "%s.%ss %s %s %s %s:%s %s [%s] %s %s", seconds, @@ -117,7 +112,7 @@ class Lockout(wsgi.Middleware): debug=0) super(Lockout, self).__init__(application) - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): access_key = str(req.params['AWSAccessKeyId']) failures_key = "authfailures-%s" % access_key @@ -146,7 +141,7 @@ class Authenticate(wsgi.Middleware): """Authenticate an EC2 request and add 'ec2.context' to WSGI environ.""" - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): # Read request signature and access id. try: @@ -195,7 +190,7 @@ class Requestify(wsgi.Middleware): super(Requestify, self).__init__(app) self.controller = utils.import_class(controller)() - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): non_args = ['Action', 'Signature', 'AWSAccessKeyId', 'SignatureMethod', 'SignatureVersion', 'Version', 'Timestamp'] @@ -203,6 +198,12 @@ class Requestify(wsgi.Middleware): try: # Raise KeyError if omitted action = req.params['Action'] + # Fix bug lp:720157 for older (version 1) clients + version = req.params['SignatureVersion'] + if int(version) == 1: + non_args.remove('SignatureMethod') + if 'SignatureMethod' in args: + args.pop('SignatureMethod') for non_arg in non_args: # Remove, but raise KeyError if omitted args.pop(non_arg) @@ -233,7 +234,7 @@ class Authorizer(wsgi.Middleware): super(Authorizer, self).__init__(application) self.action_roles = { 'CloudController': { - 'DescribeAvailabilityzones': ['all'], + 'DescribeAvailabilityZones': ['all'], 'DescribeRegions': ['all'], 'DescribeSnapshots': ['all'], 'DescribeKeyPairs': ['all'], @@ -274,7 +275,7 @@ class Authorizer(wsgi.Middleware): }, } - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): context = req.environ['ec2.context'] controller = req.environ['ec2.request'].controller.__class__.__name__ @@ -295,7 +296,7 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user_id, role) for role in roles) @@ -308,7 +309,7 @@ class Executor(wsgi.Application): response, or a 400 upon failure. """ - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): context = req.environ['ec2.context'] api_request = req.environ['ec2.request'] @@ -370,7 +371,7 @@ class Executor(wsgi.Application): class Versions(wsgi.Application): - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): """Respond to a request for all EC2 versions.""" # available api versions diff --git a/nova/api/ec2/admin.py b/nova/api/ec2/admin.py index 735951082..d9a4ef999 100644 --- a/nova/api/ec2/admin.py +++ b/nova/api/ec2/admin.py @@ -21,14 +21,17 @@ Admin API controller, exposed through http via the api worker. """ import base64 +import datetime from nova import db from nova import exception +from nova import flags from nova import log as logging +from nova import utils from nova.auth import manager -from nova.compute import instance_types +FLAGS = flags.FLAGS LOG = logging.getLogger('nova.api.ec2.admin') @@ -55,22 +58,54 @@ def project_dict(project): return {} -def host_dict(host): +def host_dict(host, compute_service, instances, volume_service, volumes, now): """Convert a host model object to a result dict""" - if host: - return host.state - else: - return {} + rv = {'hostanme': host, 'instance_count': len(instances), + 'volume_count': len(volumes)} + if compute_service: + latest = compute_service['updated_at'] or compute_service['created_at'] + delta = now - latest + if delta.seconds <= FLAGS.service_down_time: + rv['compute'] = 'up' + else: + rv['compute'] = 'down' + if volume_service: + latest = volume_service['updated_at'] or volume_service['created_at'] + delta = now - latest + if delta.seconds <= FLAGS.service_down_time: + rv['volume'] = 'up' + else: + rv['volume'] = 'down' + return rv -def instance_dict(name, inst): - return {'name': name, +def instance_dict(inst): + return {'name': inst['name'], 'memory_mb': inst['memory_mb'], 'vcpus': inst['vcpus'], 'disk_gb': inst['local_gb'], 'flavor_id': inst['flavorid']} +def vpn_dict(project, vpn_instance): + rv = {'project_id': project.id, + 'public_ip': project.vpn_ip, + 'public_port': project.vpn_port} + if vpn_instance: + rv['instance_id'] = vpn_instance['ec2_id'] + rv['created_at'] = utils.isotime(vpn_instance['created_at']) + address = vpn_instance.get('fixed_ip', None) + if address: + rv['internal_ip'] = address['address'] + if utils.vpn_ping(project.vpn_ip, project.vpn_port): + rv['state'] = 'running' + else: + rv['state'] = 'down' + else: + rv['state'] = 'pending' + return rv + + class AdminController(object): """ API Controller for users, hosts, nodes, and workers. @@ -79,9 +114,9 @@ class AdminController(object): def __str__(self): return 'AdminController' - def describe_instance_types(self, _context, **_kwargs): - return {'instanceTypeSet': [instance_dict(n, v) for n, v in - instance_types.INSTANCE_TYPES.iteritems()]} + def describe_instance_types(self, context, **_kwargs): + """Returns all active instance types data (vcpus, memory, etc.)""" + return {'instanceTypeSet': [db.instance_type_get_all(context)]} def describe_user(self, _context, name, **_kwargs): """Returns user data, including access and secret keys.""" @@ -223,19 +258,68 @@ class AdminController(object): raise exception.ApiError(_('operation must be add or remove')) return True + def _vpn_for(self, context, project_id): + """Get the VPN instance for a project ID.""" + for instance in db.instance_get_all_by_project(context, project_id): + if (instance['image_id'] == FLAGS.vpn_image_id + and not instance['state_description'] in + ['shutting_down', 'shutdown']): + return instance + + def start_vpn(self, context, project): + instance = self._vpn_for(context, project) + if not instance: + # NOTE(vish) import delayed because of __init__.py + from nova.cloudpipe import pipelib + pipe = pipelib.CloudPipe() + try: + pipe.launch_vpn_instance(project) + except db.NoMoreNetworks: + raise exception.ApiError("Unable to claim IP for VPN instance" + ", ensure it isn't running, and try " + "again in a few minutes") + instance = self._vpn_for(context, project) + return {'instance_id': instance['ec2_id']} + + def describe_vpns(self, context): + vpns = [] + for project in manager.AuthManager().get_projects(): + instance = self._vpn_for(context, project.id) + vpns.append(vpn_dict(project, instance)) + return {'items': vpns} + # FIXME(vish): these host commands don't work yet, perhaps some of the # required data can be retrieved from service objects? - def describe_hosts(self, _context, **_kwargs): + def describe_hosts(self, context, **_kwargs): """Returns status info for all nodes. Includes: - * Disk Space - * Instance List - * RAM used - * CPU used - * DHCP servers running - * Iptables / bridges + * Hostname + * Compute (up, down, None) + * Instance count + * Volume (up, down, None) + * Volume Count """ - return {'hostSet': [host_dict(h) for h in db.host_get_all()]} + services = db.service_get_all(context) + now = datetime.datetime.utcnow() + hosts = [] + rv = [] + for host in [service['host'] for service in services]: + if not host in hosts: + hosts.append(host) + for host in hosts: + compute = [s for s in services if s['host'] == host \ + and s['binary'] == 'nova-compute'] + if compute: + compute = compute[0] + instances = db.instance_get_all_by_host(context, host) + volume = [s for s in services if s['host'] == host \ + and s['binary'] == 'nova-volume'] + if volume: + volume = volume[0] + volumes = db.volume_get_all_by_host(context, host) + rv.append(host_dict(host, compute, instances, volume, volumes, + now)) + return {'hosts': rv} def describe_host(self, _context, name, **_kwargs): """Returns status info for single node.""" diff --git a/nova/api/ec2/apirequest.py b/nova/api/ec2/apirequest.py index 7e72d67fb..d7ad08d2f 100644 --- a/nova/api/ec2/apirequest.py +++ b/nova/api/ec2/apirequest.py @@ -20,6 +20,7 @@ APIRequest class """ +import datetime import re # TODO(termie): replace minidom with etree from xml.dom import minidom @@ -45,8 +46,29 @@ def _underscore_to_xmlcase(str): return res[:1].lower() + res[1:] +def _database_to_isoformat(datetimeobj): + """Return a xs:dateTime parsable string from datatime""" + return datetimeobj.strftime("%Y-%m-%dT%H:%M:%SZ") + + def _try_convert(value): - """Return a non-string if possible""" + """Return a non-string from a string or unicode, if possible. + + ============= ===================================================== + When value is returns + ============= ===================================================== + zero-length '' + 'None' None + 'True' True + 'False' False + '0', '-0' 0 + 0xN, -0xN int from hex (postitive) (N is any number) + 0bN, -0bN int from binary (positive) (N is any number) + * try conversion to int, float, complex, fallback value + + """ + if len(value) == 0: + return '' if value == 'None': return None if value == 'True': @@ -171,6 +193,9 @@ class APIRequest(object): self._render_dict(xml, data_el, data.__dict__) elif isinstance(data, bool): data_el.appendChild(xml.createTextNode(str(data).lower())) + elif isinstance(data, datetime.datetime): + data_el.appendChild( + xml.createTextNode(_database_to_isoformat(data))) elif data != None: data_el.appendChild(xml.createTextNode(str(data))) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 00d044e95..b1917e9ea 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -39,7 +39,9 @@ from nova import log as logging from nova import network from nova import utils from nova import volume +from nova.api.ec2 import ec2utils from nova.compute import instance_types +from nova.image import s3 FLAGS = flags.FLAGS @@ -73,30 +75,19 @@ def _gen_key(context, user_id, key_name): return {'private_key': private_key, 'fingerprint': fingerprint} -def ec2_id_to_id(ec2_id): - """Convert an ec2 ID (i-[base 16 number]) to an instance id (int)""" - return int(ec2_id.split('-')[-1], 16) - - -def id_to_ec2_id(instance_id, template='i-%08x'): - """Convert an instance ID (int) to an ec2 ID (i-[base 16 number])""" - return template % instance_id - - class CloudController(object): """ CloudController provides the critical dispatch between inbound API calls through the endpoint and messages sent to the other nodes. """ def __init__(self): - self.image_service = utils.import_object(FLAGS.image_service) + self.image_service = s3.S3ImageService() self.network_api = network.API() self.volume_api = volume.API() self.compute_api = compute.API( network_api=self.network_api, - image_service=self.image_service, volume_api=self.volume_api, - hostname_factory=id_to_ec2_id) + hostname_factory=ec2utils.id_to_ec2_id) self.setup() def __str__(self): @@ -115,7 +106,7 @@ class CloudController(object): start = os.getcwd() os.chdir(FLAGS.ca_path) # TODO(vish): Do this with M2Crypto instead - utils.runthis(_("Generating root CA: %s"), "sh genrootca.sh") + utils.runthis(_("Generating root CA: %s"), "sh", "genrootca.sh") os.chdir(start) def _get_mpi_data(self, context, project_id): @@ -154,11 +145,14 @@ class CloudController(object): availability_zone = self._get_availability_zone_by_host(ctxt, host) floating_ip = db.instance_get_floating_address(ctxt, instance_ref['id']) - ec2_id = id_to_ec2_id(instance_ref['id']) + ec2_id = ec2utils.id_to_ec2_id(instance_ref['id']) + image_ec2_id = self._image_ec2_id(instance_ref['image_id'], 'machine') + k_ec2_id = self._image_ec2_id(instance_ref['kernel_id'], 'kernel') + r_ec2_id = self._image_ec2_id(instance_ref['ramdisk_id'], 'ramdisk') data = { 'user-data': base64.b64decode(instance_ref['user_data']), 'meta-data': { - 'ami-id': instance_ref['image_id'], + 'ami-id': image_ec2_id, 'ami-launch-index': instance_ref['launch_index'], 'ami-manifest-path': 'FIXME', 'block-device-mapping': { @@ -173,12 +167,12 @@ class CloudController(object): 'instance-type': instance_ref['instance_type'], 'local-hostname': hostname, 'local-ipv4': address, - 'kernel-id': instance_ref['kernel_id'], + 'kernel-id': k_ec2_id, + 'ramdisk-id': r_ec2_id, 'placement': {'availability-zone': availability_zone}, 'public-hostname': hostname, 'public-ipv4': floating_ip or '', 'public-keys': keys, - 'ramdisk-id': instance_ref['ramdisk_id'], 'reservation-id': instance_ref['reservation_id'], 'security-groups': '', 'mpi': mpi}} @@ -198,8 +192,9 @@ class CloudController(object): return self._describe_availability_zones(context, **kwargs) def _describe_availability_zones(self, context, **kwargs): - enabled_services = db.service_get_all(context) - disabled_services = db.service_get_all(context, True) + ctxt = context.elevated() + enabled_services = db.service_get_all(ctxt) + disabled_services = db.service_get_all(ctxt, True) available_zones = [] for zone in [service.availability_zone for service in enabled_services]: @@ -282,7 +277,7 @@ class CloudController(object): 'description': 'fixme'}]} def describe_key_pairs(self, context, key_name=None, **kwargs): - key_pairs = db.key_pair_get_all_by_user(context, context.user.id) + key_pairs = db.key_pair_get_all_by_user(context, context.user_id) if not key_name is None: key_pairs = [x for x in key_pairs if x['name'] in key_name] @@ -290,18 +285,18 @@ class CloudController(object): for key_pair in key_pairs: # filter out the vpn keys suffix = FLAGS.vpn_key_suffix - if context.user.is_admin() or \ + if context.is_admin or \ not key_pair['name'].endswith(suffix): result.append({ 'keyName': key_pair['name'], 'keyFingerprint': key_pair['fingerprint'], }) - return {'keypairsSet': result} + return {'keySet': result} def create_key_pair(self, context, key_name, **kwargs): LOG.audit(_("Create key pair %s"), key_name, context=context) - data = _gen_key(context, context.user.id, key_name) + data = _gen_key(context, context.user_id, key_name) return {'keyName': key_name, 'keyFingerprint': data['fingerprint'], 'keyMaterial': data['private_key']} @@ -310,7 +305,7 @@ class CloudController(object): def delete_key_pair(self, context, key_name, **kwargs): LOG.audit(_("Delete key pair %s"), key_name, context=context) try: - db.key_pair_destroy(context, context.user.id, key_name) + db.key_pair_destroy(context, context.user_id, key_name) except exception.NotFound: # aws returns true even if the key doesn't exist pass @@ -318,16 +313,23 @@ class CloudController(object): def describe_security_groups(self, context, group_name=None, **kwargs): self.compute_api.ensure_default_security_group(context) - if context.user.is_admin(): + if group_name: + groups = [] + for name in group_name: + group = db.security_group_get_by_name(context, + context.project_id, + name) + groups.append(group) + elif context.is_admin: groups = db.security_group_get_all(context) else: groups = db.security_group_get_by_project(context, context.project_id) groups = [self._format_security_group(context, g) for g in groups] - if not group_name is None: - groups = [g for g in groups if g.name in group_name] - return {'securityGroupInfo': groups} + return {'securityGroupInfo': + list(sorted(groups, + key=lambda k: (k['ownerId'], k['groupName'])))} def _format_security_group(self, context, group): g = {} @@ -492,7 +494,7 @@ class CloudController(object): if db.security_group_exists(context, context.project_id, group_name): raise exception.ApiError(_('group %s already exists') % group_name) - group = {'user_id': context.user.id, + group = {'user_id': context.user_id, 'project_id': context.project_id, 'name': group_name, 'description': group_description} @@ -512,9 +514,12 @@ class CloudController(object): def get_console_output(self, context, instance_id, **kwargs): LOG.audit(_("Get console output for instance %s"), instance_id, context=context) - # instance_id is passed in as a list of instances - ec2_id = instance_id[0] - instance_id = ec2_id_to_id(ec2_id) + # instance_id may be passed in as a list of instances + if type(instance_id) == list: + ec2_id = instance_id[0] + else: + ec2_id = instance_id + instance_id = ec2utils.ec2_id_to_id(ec2_id) output = self.compute_api.get_console_output( context, instance_id=instance_id) now = datetime.datetime.utcnow() @@ -524,14 +529,15 @@ class CloudController(object): def get_ajax_console(self, context, instance_id, **kwargs): ec2_id = instance_id[0] - internal_id = ec2_id_to_id(ec2_id) - return self.compute_api.get_ajax_console(context, internal_id) + instance_id = ec2utils.ec2_id_to_id(ec2_id) + return self.compute_api.get_ajax_console(context, + instance_id=instance_id) def describe_volumes(self, context, volume_id=None, **kwargs): if volume_id: volumes = [] for ec2_id in volume_id: - internal_id = ec2_id_to_id(ec2_id) + internal_id = ec2utils.ec2_id_to_id(ec2_id) volume = self.volume_api.get(context, internal_id) volumes.append(volume) else: @@ -544,11 +550,11 @@ class CloudController(object): instance_data = None if volume.get('instance', None): instance_id = volume['instance']['id'] - instance_ec2_id = id_to_ec2_id(instance_id) + instance_ec2_id = ec2utils.id_to_ec2_id(instance_id) instance_data = '%s[%s]' % (instance_ec2_id, volume['instance']['host']) v = {} - v['volumeId'] = id_to_ec2_id(volume['id'], 'vol-%08x') + v['volumeId'] = ec2utils.id_to_ec2_id(volume['id'], 'vol-%08x') v['status'] = volume['status'] v['size'] = volume['size'] v['availabilityZone'] = volume['availability_zone'] @@ -566,8 +572,7 @@ class CloudController(object): 'device': volume['mountpoint'], 'instanceId': instance_ec2_id, 'status': 'attached', - 'volumeId': id_to_ec2_id(volume['id'], - 'vol-%08x')}] + 'volumeId': v['volumeId']}] else: v['attachmentSet'] = [{}] @@ -586,12 +591,12 @@ class CloudController(object): return {'volumeSet': [self._format_volume(context, dict(volume))]} def delete_volume(self, context, volume_id, **kwargs): - volume_id = ec2_id_to_id(volume_id) + volume_id = ec2utils.ec2_id_to_id(volume_id) self.volume_api.delete(context, volume_id=volume_id) return True def update_volume(self, context, volume_id, **kwargs): - volume_id = ec2_id_to_id(volume_id) + volume_id = ec2utils.ec2_id_to_id(volume_id) updatable_fields = ['display_name', 'display_description'] changes = {} for field in updatable_fields: @@ -602,8 +607,8 @@ class CloudController(object): return True def attach_volume(self, context, volume_id, instance_id, device, **kwargs): - volume_id = ec2_id_to_id(volume_id) - instance_id = ec2_id_to_id(instance_id) + volume_id = ec2utils.ec2_id_to_id(volume_id) + instance_id = ec2utils.ec2_id_to_id(instance_id) msg = _("Attach volume %(volume_id)s to instance %(instance_id)s" " at %(device)s") % locals() LOG.audit(msg, context=context) @@ -614,22 +619,22 @@ class CloudController(object): volume = self.volume_api.get(context, volume_id) return {'attachTime': volume['attach_time'], 'device': volume['mountpoint'], - 'instanceId': id_to_ec2_id(instance_id), + 'instanceId': ec2utils.id_to_ec2_id(instance_id), 'requestId': context.request_id, 'status': volume['attach_status'], - 'volumeId': id_to_ec2_id(volume_id, 'vol-%08x')} + 'volumeId': ec2utils.id_to_ec2_id(volume_id, 'vol-%08x')} def detach_volume(self, context, volume_id, **kwargs): - volume_id = ec2_id_to_id(volume_id) + volume_id = ec2utils.ec2_id_to_id(volume_id) LOG.audit(_("Detach volume %s"), volume_id, context=context) volume = self.volume_api.get(context, volume_id) instance = self.compute_api.detach_volume(context, volume_id=volume_id) return {'attachTime': volume['attach_time'], 'device': volume['mountpoint'], - 'instanceId': id_to_ec2_id(instance['id']), + 'instanceId': ec2utils.id_to_ec2_id(instance['id']), 'requestId': context.request_id, 'status': volume['attach_status'], - 'volumeId': id_to_ec2_id(volume_id, 'vol-%08x')} + 'volumeId': ec2utils.id_to_ec2_id(volume_id, 'vol-%08x')} def _convert_to_set(self, lst, label): if lst == None or lst == []: @@ -663,20 +668,21 @@ class CloudController(object): if instance_id: instances = [] for ec2_id in instance_id: - internal_id = ec2_id_to_id(ec2_id) - instance = self.compute_api.get(context, internal_id) + internal_id = ec2utils.ec2_id_to_id(ec2_id) + instance = self.compute_api.get(context, + instance_id=internal_id) instances.append(instance) else: instances = self.compute_api.get_all(context, **kwargs) for instance in instances: - if not context.user.is_admin(): + if not context.is_admin: if instance['image_id'] == FLAGS.vpn_image_id: continue i = {} instance_id = instance['id'] - ec2_id = id_to_ec2_id(instance_id) + ec2_id = ec2utils.id_to_ec2_id(instance_id) i['instanceId'] = ec2_id - i['imageId'] = instance['image_id'] + i['imageId'] = self._image_ec2_id(instance['image_id']) i['instanceState'] = { 'code': instance['state'], 'name': instance['state_description']} @@ -697,7 +703,7 @@ class CloudController(object): i['dnsName'] = i['publicDnsName'] or i['privateDnsName'] i['keyName'] = instance['key_name'] - if context.user.is_admin(): + if context.is_admin: i['keyName'] = '%s (%s, %s)' % (i['keyName'], instance['project_id'], instance['host']) @@ -731,7 +737,7 @@ class CloudController(object): def format_addresses(self, context): addresses = [] - if context.user.is_admin(): + if context.is_admin: iterator = db.floating_ip_get_all(context) else: iterator = db.floating_ip_get_all_by_project(context, @@ -742,10 +748,10 @@ class CloudController(object): if (floating_ip_ref['fixed_ip'] and floating_ip_ref['fixed_ip']['instance']): instance_id = floating_ip_ref['fixed_ip']['instance']['id'] - ec2_id = id_to_ec2_id(instance_id) + ec2_id = ec2utils.id_to_ec2_id(instance_id) address_rv = {'public_ip': address, 'instance_id': ec2_id} - if context.user.is_admin(): + if context.is_admin: details = "%s (%s)" % (address_rv['instance_id'], floating_ip_ref['project_id']) address_rv['instance_id'] = details @@ -765,7 +771,7 @@ class CloudController(object): def associate_address(self, context, instance_id, public_ip, **kwargs): LOG.audit(_("Associate address %(public_ip)s to" " instance %(instance_id)s") % locals(), context=context) - instance_id = ec2_id_to_id(instance_id) + instance_id = ec2utils.ec2_id_to_id(instance_id) self.compute_api.associate_floating_ip(context, instance_id=instance_id, address=public_ip) @@ -778,13 +784,19 @@ class CloudController(object): def run_instances(self, context, **kwargs): max_count = int(kwargs.get('max_count', 1)) + if kwargs.get('kernel_id'): + kernel = self._get_image(context, kwargs['kernel_id']) + kwargs['kernel_id'] = kernel['id'] + if kwargs.get('ramdisk_id'): + ramdisk = self._get_image(context, kwargs['ramdisk_id']) + kwargs['ramdisk_id'] = ramdisk['id'] instances = self.compute_api.create(context, instance_type=instance_types.get_by_type( kwargs.get('instance_type', None)), - image_id=kwargs['image_id'], + image_id=self._get_image(context, kwargs['image_id'])['id'], min_count=int(kwargs.get('min_count', max_count)), max_count=max_count, - kernel_id=kwargs.get('kernel_id', None), + kernel_id=kwargs.get('kernel_id'), ramdisk_id=kwargs.get('ramdisk_id'), display_name=kwargs.get('display_name'), display_description=kwargs.get('display_description'), @@ -801,7 +813,7 @@ class CloudController(object): instance_id is a kwarg so its name cannot be modified.""" LOG.debug(_("Going to start terminating instances")) for ec2_id in instance_id: - instance_id = ec2_id_to_id(ec2_id) + instance_id = ec2utils.ec2_id_to_id(ec2_id) self.compute_api.delete(context, instance_id=instance_id) return True @@ -809,49 +821,103 @@ class CloudController(object): """instance_id is a list of instance ids""" LOG.audit(_("Reboot instance %r"), instance_id, context=context) for ec2_id in instance_id: - instance_id = ec2_id_to_id(ec2_id) + instance_id = ec2utils.ec2_id_to_id(ec2_id) self.compute_api.reboot(context, instance_id=instance_id) return True def rescue_instance(self, context, instance_id, **kwargs): """This is an extension to the normal ec2_api""" - instance_id = ec2_id_to_id(instance_id) + instance_id = ec2utils.ec2_id_to_id(instance_id) self.compute_api.rescue(context, instance_id=instance_id) return True def unrescue_instance(self, context, instance_id, **kwargs): """This is an extension to the normal ec2_api""" - instance_id = ec2_id_to_id(instance_id) + instance_id = ec2utils.ec2_id_to_id(instance_id) self.compute_api.unrescue(context, instance_id=instance_id) return True - def update_instance(self, context, ec2_id, **kwargs): + def update_instance(self, context, instance_id, **kwargs): updatable_fields = ['display_name', 'display_description'] changes = {} for field in updatable_fields: if field in kwargs: changes[field] = kwargs[field] if changes: - instance_id = ec2_id_to_id(ec2_id) + instance_id = ec2utils.ec2_id_to_id(instance_id) self.compute_api.update(context, instance_id=instance_id, **kwargs) return True + _type_prefix_map = {'machine': 'ami', + 'kernel': 'aki', + 'ramdisk': 'ari'} + + def _image_ec2_id(self, image_id, image_type='machine'): + prefix = self._type_prefix_map[image_type] + template = prefix + '-%08x' + return ec2utils.id_to_ec2_id(int(image_id), template=template) + + def _get_image(self, context, ec2_id): + try: + internal_id = ec2utils.ec2_id_to_id(ec2_id) + return self.image_service.show(context, internal_id) + except exception.NotFound: + return self.image_service.show_by_name(context, ec2_id) + + def _format_image(self, image): + """Convert from format defined by BaseImageService to S3 format.""" + i = {} + image_type = image['properties'].get('type') + ec2_id = self._image_ec2_id(image.get('id'), image_type) + name = image.get('name') + if name: + i['imageId'] = "%s (%s)" % (ec2_id, name) + else: + i['imageId'] = ec2_id + kernel_id = image['properties'].get('kernel_id') + if kernel_id: + i['kernelId'] = self._image_ec2_id(kernel_id, 'kernel') + ramdisk_id = image['properties'].get('ramdisk_id') + if ramdisk_id: + i['ramdiskId'] = self._image_ec2_id(ramdisk_id, 'ramdisk') + i['imageOwnerId'] = image['properties'].get('owner_id') + i['imageLocation'] = image['properties'].get('image_location') + i['imageState'] = image['properties'].get('image_state') + i['type'] = image_type + i['isPublic'] = str(image['properties'].get('is_public', '')) == 'True' + i['architecture'] = image['properties'].get('architecture') + return i + def describe_images(self, context, image_id=None, **kwargs): - # Note: image_id is a list! - images = self.image_service.index(context) + # NOTE: image_id is a list! if image_id: - images = filter(lambda x: x['imageId'] in image_id, images) + images = [] + for ec2_id in image_id: + try: + image = self._get_image(context, ec2_id) + except exception.NotFound: + raise exception.NotFound(_('Image %s not found') % + ec2_id) + images.append(image) + else: + images = self.image_service.detail(context) + images = [self._format_image(i) for i in images] return {'imagesSet': images} def deregister_image(self, context, image_id, **kwargs): LOG.audit(_("De-registering image %s"), image_id, context=context) - self.image_service.deregister(context, image_id) + image = self._get_image(context, image_id) + internal_id = image['id'] + self.image_service.delete(context, internal_id) return {'imageId': image_id} def register_image(self, context, image_location=None, **kwargs): if image_location is None and 'name' in kwargs: image_location = kwargs['name'] - image_id = self.image_service.register(context, image_location) + metadata = {'properties': {'image_location': image_location}} + image = self.image_service.create(context, metadata) + image_id = self._image_ec2_id(image['id'], + image['properties']['type']) msg = _("Registered image %(image_location)s with" " id %(image_id)s") % locals() LOG.audit(msg, context=context) @@ -862,11 +928,11 @@ class CloudController(object): raise exception.ApiError(_('attribute not supported: %s') % attribute) try: - image = self.image_service.show(context, image_id) - except IndexError: - raise exception.ApiError(_('invalid id: %s') % image_id) - result = {'image_id': image_id, 'launchPermission': []} - if image['isPublic']: + image = self._get_image(context, image_id) + except exception.NotFound: + raise exception.NotFound(_('Image %s not found') % image_id) + result = {'imageId': image_id, 'launchPermission': []} + if image['properties']['is_public']: result['launchPermission'].append({'group': 'all'}) return result @@ -883,8 +949,18 @@ class CloudController(object): if not operation_type in ['add', 'remove']: raise exception.ApiError(_('operation_type must be add or remove')) LOG.audit(_("Updating image %s publicity"), image_id, context=context) - return self.image_service.modify(context, image_id, operation_type) + + try: + image = self._get_image(context, image_id) + except exception.NotFound: + raise exception.NotFound(_('Image %s not found') % image_id) + internal_id = image['id'] + del(image['id']) + raise Exception(image) + image['properties']['is_public'] = (operation_type == 'add') + return self.image_service.update(context, internal_id, image) def update_image(self, context, image_id, **kwargs): - result = self.image_service.update(context, image_id, dict(kwargs)) + internal_id = ec2utils.ec2_id_to_id(image_id) + result = self.image_service.update(context, internal_id, dict(kwargs)) return result diff --git a/nova/api/ec2/ec2utils.py b/nova/api/ec2/ec2utils.py new file mode 100644 index 000000000..3b34f6ea5 --- /dev/null +++ b/nova/api/ec2/ec2utils.py @@ -0,0 +1,32 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nova import exception + + +def ec2_id_to_id(ec2_id): + """Convert an ec2 ID (i-[base 16 number]) to an instance id (int)""" + try: + return int(ec2_id.split('-')[-1], 16) + except ValueError: + raise exception.NotFound(_("Id %s Not Found") % ec2_id) + + +def id_to_ec2_id(instance_id, template='i-%08x'): + """Convert an instance ID (int) to an ec2 ID (i-[base 16 number])""" + return template % instance_id diff --git a/nova/api/ec2/metadatarequesthandler.py b/nova/api/ec2/metadatarequesthandler.py index 6fb441656..28f99b0ef 100644 --- a/nova/api/ec2/metadatarequesthandler.py +++ b/nova/api/ec2/metadatarequesthandler.py @@ -65,7 +65,7 @@ class MetadataRequestHandler(wsgi.Application): data = data[item] return data - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): cc = cloud.CloudController() remote_address = req.remote_addr diff --git a/nova/api/openstack/__init__.py b/nova/api/openstack/__init__.py index 056c7dd27..ab9dbb780 100644 --- a/nova/api/openstack/__init__.py +++ b/nova/api/openstack/__init__.py @@ -34,6 +34,7 @@ from nova.api.openstack import flavors from nova.api.openstack import images from nova.api.openstack import servers from nova.api.openstack import shared_ip_groups +from nova.api.openstack import zones LOG = logging.getLogger('nova.api.openstack') @@ -46,7 +47,7 @@ flags.DEFINE_bool('allow_admin_api', class FaultWrapper(wsgi.Middleware): """Calls down the middleware stack, making exceptions into faults.""" - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): try: return req.get_response(self.application) @@ -73,12 +74,20 @@ class APIRouter(wsgi.Router): server_members = {'action': 'POST'} if FLAGS.allow_admin_api: LOG.debug(_("Including admin operations in API.")) + server_members['pause'] = 'POST' server_members['unpause'] = 'POST' - server_members["diagnostics"] = "GET" - server_members["actions"] = "GET" + server_members['diagnostics'] = 'GET' + server_members['actions'] = 'GET' server_members['suspend'] = 'POST' server_members['resume'] = 'POST' + server_members['rescue'] = 'POST' + server_members['unrescue'] = 'POST' + server_members['reset_network'] = 'POST' + server_members['inject_network_info'] = 'POST' + + mapper.resource("zone", "zones", controller=zones.Controller(), + collection={'detail': 'GET', 'info': 'GET'}), mapper.resource("server", "servers", controller=servers.Controller(), collection={'detail': 'GET'}, @@ -106,7 +115,7 @@ class APIRouter(wsgi.Router): class Versions(wsgi.Application): - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): """Respond to a request for all OpenStack API versions.""" response = { @@ -115,4 +124,6 @@ class Versions(wsgi.Application): metadata = { "application/xml": { "attributes": dict(version=["status", "id"])}} - return wsgi.Serializer(req.environ, metadata).to_content_type(response) + + content_type = req.best_match_content_type() + return wsgi.Serializer(metadata).serialize(response, content_type) diff --git a/nova/api/openstack/auth.py b/nova/api/openstack/auth.py index 1dfdd5318..de8905f46 100644 --- a/nova/api/openstack/auth.py +++ b/nova/api/openstack/auth.py @@ -26,6 +26,7 @@ import webob.dec from nova import auth from nova import context from nova import db +from nova import exception from nova import flags from nova import manager from nova import utils @@ -45,7 +46,7 @@ class AuthMiddleware(wsgi.Middleware): self.auth = auth.manager.AuthManager() super(AuthMiddleware, self).__init__(application) - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): if not self.has_authentication(req): return self.authenticate(req) @@ -103,11 +104,14 @@ class AuthMiddleware(wsgi.Middleware): 2 days ago. """ ctxt = context.get_admin_context() - token = self.db.auth_get_token(ctxt, token_hash) + try: + token = self.db.auth_token_get(ctxt, token_hash) + except exception.NotFound: + return None if token: delta = datetime.datetime.now() - token.created_at if delta.days >= 2: - self.db.auth_destroy_token(ctxt, token) + self.db.auth_token_destroy(ctxt, token.token_hash) else: return self.auth.get_user(token.user_id) return None @@ -117,7 +121,7 @@ class AuthMiddleware(wsgi.Middleware): username - string key - string API key - req - webob.Request object + req - wsgi.Request object """ ctxt = context.get_admin_context() user = self.auth.get_user_from_access_key(key) @@ -131,6 +135,6 @@ class AuthMiddleware(wsgi.Middleware): token_dict['server_management_url'] = req.url token_dict['storage_url'] = '' token_dict['user_id'] = user.id - token = self.db.auth_create_token(ctxt, token_dict) + token = self.db.auth_token_create(ctxt, token_dict) return token, user return None, None diff --git a/nova/api/openstack/backup_schedules.py b/nova/api/openstack/backup_schedules.py index 197125d86..7abb5f884 100644 --- a/nova/api/openstack/backup_schedules.py +++ b/nova/api/openstack/backup_schedules.py @@ -15,7 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -import logging import time from webob import exc diff --git a/nova/api/openstack/common.py b/nova/api/openstack/common.py index 6d2fa16e8..74ac21024 100644 --- a/nova/api/openstack/common.py +++ b/nova/api/openstack/common.py @@ -15,25 +15,41 @@ # License for the specific language governing permissions and limitations # under the License. -from nova import exception +import webob.exc +from nova import exception -def limited(items, req): - """Return a slice of items according to requested offset and limit. - items - a sliceable - req - wobob.Request possibly containing offset and limit GET variables. - offset is where to start in the list, and limit is the maximum number - of items to return. +def limited(items, request, max_limit=1000): + """ + Return a slice of items according to requested offset and limit. - If limit is not specified, 0, or > 1000, defaults to 1000. + @param items: A sliceable entity + @param request: `wsgi.Request` possibly containing 'offset' and 'limit' + GET variables. 'offset' is where to start in the list, + and 'limit' is the maximum number of items to return. If + 'limit' is not specified, 0, or > max_limit, we default + to max_limit. Negative values for either offset or limit + will cause exc.HTTPBadRequest() exceptions to be raised. + @kwarg max_limit: The maximum number of items to return from 'items' """ + try: + offset = int(request.GET.get('offset', 0)) + except ValueError: + raise webob.exc.HTTPBadRequest(_('offset param must be an integer')) + + try: + limit = int(request.GET.get('limit', max_limit)) + except ValueError: + raise webob.exc.HTTPBadRequest(_('limit param must be an integer')) + + if limit < 0: + raise webob.exc.HTTPBadRequest(_('limit param must be positive')) + + if offset < 0: + raise webob.exc.HTTPBadRequest(_('offset param must be positive')) - offset = int(req.GET.get('offset', 0)) - limit = int(req.GET.get('limit', 0)) - if not limit: - limit = 1000 - limit = min(1000, limit) + limit = min(max_limit, limit or max_limit) range_end = offset + limit return items[offset:range_end] diff --git a/nova/api/openstack/consoles.py b/nova/api/openstack/consoles.py index 9ebdbe710..8c291c2eb 100644 --- a/nova/api/openstack/consoles.py +++ b/nova/api/openstack/consoles.py @@ -65,7 +65,7 @@ class Controller(wsgi.Controller): def create(self, req, server_id): """Creates a new console""" - #info = self._deserialize(req.body, req) + #info = self._deserialize(req.body, req.get_content_type()) self.console_api.create_console( req.environ['nova.context'], int(server_id)) diff --git a/nova/api/openstack/faults.py b/nova/api/openstack/faults.py index 224a7ef0b..2fd733299 100644 --- a/nova/api/openstack/faults.py +++ b/nova/api/openstack/faults.py @@ -42,7 +42,7 @@ class Fault(webob.exc.HTTPException): """Create a Fault for the given webob.exc.exception.""" self.wrapped_exc = exception - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): """Generate a WSGI response based on the exception passed to ctor.""" # Replace the body with fault details. @@ -57,6 +57,7 @@ class Fault(webob.exc.HTTPException): fault_data[fault_name]['retryAfter'] = retry # 'code' is an attribute on the fault tag itself metadata = {'application/xml': {'attributes': {fault_name: 'code'}}} - serializer = wsgi.Serializer(req.environ, metadata) - self.wrapped_exc.body = serializer.to_content_type(fault_data) + serializer = wsgi.Serializer(metadata) + content_type = req.best_match_content_type() + self.wrapped_exc.body = serializer.serialize(fault_data, content_type) return self.wrapped_exc diff --git a/nova/api/openstack/flavors.py b/nova/api/openstack/flavors.py index f620d4107..f3d040ba3 100644 --- a/nova/api/openstack/flavors.py +++ b/nova/api/openstack/flavors.py @@ -17,6 +17,8 @@ from webob import exc +from nova import db +from nova import context from nova.api.openstack import faults from nova.api.openstack import common from nova.compute import instance_types @@ -39,19 +41,19 @@ class Controller(wsgi.Controller): def detail(self, req): """Return all flavors in detail.""" - items = [self.show(req, id)['flavor'] for id in self._all_ids()] - items = common.limited(items, req) + items = [self.show(req, id)['flavor'] for id in self._all_ids(req)] return dict(flavors=items) def show(self, req, id): """Return data about the given flavor id.""" - for name, val in instance_types.INSTANCE_TYPES.iteritems(): - if val['flavorid'] == int(id): - item = dict(ram=val['memory_mb'], disk=val['local_gb'], - id=val['flavorid'], name=name) - return dict(flavor=item) + ctxt = req.environ['nova.context'] + values = db.instance_type_get_by_flavor_id(ctxt, id) + return dict(flavor=values) raise faults.Fault(exc.HTTPNotFound()) - def _all_ids(self): + def _all_ids(self, req): """Return the list of all flavorids.""" - return [i['flavorid'] for i in instance_types.INSTANCE_TYPES.values()] + ctxt = req.environ['nova.context'] + inst_types = db.instance_type_get_all(ctxt) + flavor_ids = [inst_types[i]['flavorid'] for i in inst_types.keys()] + return sorted(flavor_ids) diff --git a/nova/api/openstack/images.py b/nova/api/openstack/images.py index 9d56bc508..98f0dd96b 100644 --- a/nova/api/openstack/images.py +++ b/nova/api/openstack/images.py @@ -15,8 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -import logging - from webob import exc from nova import compute @@ -153,7 +151,7 @@ class Controller(wsgi.Controller): def create(self, req): context = req.environ['nova.context'] - env = self._deserialize(req.body, req) + env = self._deserialize(req.body, req.get_content_type()) instance_id = env["image"]["serverId"] name = env["image"]["name"] diff --git a/nova/api/openstack/ratelimiting/__init__.py b/nova/api/openstack/ratelimiting/__init__.py index cbb4b897e..88ffc3246 100644 --- a/nova/api/openstack/ratelimiting/__init__.py +++ b/nova/api/openstack/ratelimiting/__init__.py @@ -57,7 +57,7 @@ class RateLimitingMiddleware(wsgi.Middleware): self.limiter = WSGIAppProxy(service_host) super(RateLimitingMiddleware, self).__init__(application) - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): """Rate limit the request. @@ -183,7 +183,7 @@ class WSGIApp(object): """Create the WSGI application using the given Limiter instance.""" self.limiter = limiter - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=wsgi.Request) def __call__(self, req): parts = req.path_info.split('/') # format: /limiter/<username>/<urlencoded action> diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py index 17c5519a1..dc28a0782 100644 --- a/nova/api/openstack/servers.py +++ b/nova/api/openstack/servers.py @@ -1,5 +1,3 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - # Copyright 2010 OpenStack LLC. # All Rights Reserved. # @@ -15,6 +13,7 @@ # License for the specific language governing permissions and limitations # under the License. +import hashlib import json import traceback @@ -35,7 +34,6 @@ import nova.api.openstack LOG = logging.getLogger('server') -LOG.setLevel(logging.DEBUG) FLAGS = flags.FLAGS @@ -53,7 +51,8 @@ def _translate_detail_keys(inst): power_state.PAUSED: 'paused', power_state.SHUTDOWN: 'active', power_state.SHUTOFF: 'active', - power_state.CRASHED: 'error'} + power_state.CRASHED: 'error', + power_state.FAILED: 'error'} inst_dict = {} mapped_keys = dict(status='state', imageId='image_id', @@ -64,8 +63,24 @@ def _translate_detail_keys(inst): inst_dict['status'] = power_mapping[inst_dict['status']] inst_dict['addresses'] = dict(public=[], private=[]) - inst_dict['metadata'] = {} + + # grab single private fixed ip + private_ips = utils.get_from_path(inst, 'fixed_ip/address') + inst_dict['addresses']['private'] = private_ips + + # grab all public floating ips + public_ips = utils.get_from_path(inst, 'fixed_ip/floating_ips/address') + inst_dict['addresses']['public'] = public_ips + + # Return the metadata as a dictionary + metadata = {} + for item in inst['metadata']: + metadata[item['key']] = item['value'] + inst_dict['metadata'] = metadata + inst_dict['hostId'] = '' + if inst['host']: + inst_dict['hostId'] = hashlib.sha224(inst['host']).hexdigest() return dict(server=inst_dict) @@ -83,7 +98,7 @@ class Controller(wsgi.Controller): 'application/xml': { "attributes": { "server": ["id", "imageId", "name", "flavorId", "hostId", - "status", "progress"]}}} + "status", "progress", "adminPass"]}}} def __init__(self): self.compute_api = compute.API() @@ -124,38 +139,35 @@ class Controller(wsgi.Controller): return faults.Fault(exc.HTTPNotFound()) return exc.HTTPAccepted() - def _get_kernel_ramdisk_from_image(self, req, image_id): - """ - Machine images are associated with Kernels and Ramdisk images via - metadata stored in Glance as 'image_properties' - """ - def lookup(param): - _image_id = image_id - try: - return image['properties'][param] - except KeyError: - raise exception.NotFound( - _("%(param)s property not found for image %(_image_id)s") % - locals()) - - image_id = str(image_id) - image = self._image_service.show(req.environ['nova.context'], image_id) - return lookup('kernel_id'), lookup('ramdisk_id') - def create(self, req): """ Creates a new server for a given user """ - env = self._deserialize(req.body, req) + env = self._deserialize(req.body, req.get_content_type()) if not env: return faults.Fault(exc.HTTPUnprocessableEntity()) - key_pair = auth_manager.AuthManager.get_key_pairs( - req.environ['nova.context'])[0] + context = req.environ['nova.context'] + key_pairs = auth_manager.AuthManager.get_key_pairs(context) + if not key_pairs: + raise exception.NotFound(_("No keypairs defined")) + key_pair = key_pairs[0] + image_id = common.get_image_id_from_image_hash(self._image_service, - req.environ['nova.context'], env['server']['imageId']) + context, env['server']['imageId']) kernel_id, ramdisk_id = self._get_kernel_ramdisk_from_image( req, image_id) + + # Metadata is a list, not a Dictionary, because we allow duplicate keys + # (even though JSON can't encode this) + # In future, we may not allow duplicate keys. + # However, the CloudServers API is not definitive on this front, + # and we want to be compatible. + metadata = [] + if env['server'].get('metadata'): + for k, v in env['server']['metadata'].items(): + metadata.append({'key': k, 'value': v}) + instances = self.compute_api.create( - req.environ['nova.context'], + context, instance_types.get_by_flavor_id(env['server']['flavorId']), image_id, kernel_id=kernel_id, @@ -163,12 +175,24 @@ class Controller(wsgi.Controller): display_name=env['server']['name'], display_description=env['server']['name'], key_name=key_pair['name'], - key_data=key_pair['public_key']) - return _translate_keys(instances[0]) + key_data=key_pair['public_key'], + metadata=metadata, + onset_files=env.get('onset_files', [])) + + server = _translate_keys(instances[0]) + password = "%s%s" % (server['server']['name'][:4], + utils.generate_password(12)) + server['server']['adminPass'] = password + self.compute_api.set_admin_password(context, server['server']['id'], + password) + return server def update(self, req, id): """ Updates the server name or password """ - inst_dict = self._deserialize(req.body, req) + if len(req.body) == 0: + raise exc.HTTPUnprocessableEntity() + + inst_dict = self._deserialize(req.body, req.get_content_type()) if not inst_dict: return faults.Fault(exc.HTTPUnprocessableEntity()) @@ -189,10 +213,58 @@ class Controller(wsgi.Controller): return exc.HTTPNoContent() def action(self, req, id): - """ Multi-purpose method used to reboot, rebuild, and - resize a server """ - input_dict = self._deserialize(req.body, req) - #TODO(sandy): rebuild/resize not supported. + """Multi-purpose method used to reboot, rebuild, or + resize a server""" + + actions = { + 'reboot': self._action_reboot, + 'resize': self._action_resize, + 'confirmResize': self._action_confirm_resize, + 'revertResize': self._action_revert_resize, + 'rebuild': self._action_rebuild, + } + + input_dict = self._deserialize(req.body, req.get_content_type()) + for key in actions.keys(): + if key in input_dict: + return actions[key](input_dict, req, id) + return faults.Fault(exc.HTTPNotImplemented()) + + def _action_confirm_resize(self, input_dict, req, id): + try: + self.compute_api.confirm_resize(req.environ['nova.context'], id) + except Exception, e: + LOG.exception(_("Error in confirm-resize %s"), e) + return faults.Fault(exc.HTTPBadRequest()) + return exc.HTTPNoContent() + + def _action_revert_resize(self, input_dict, req, id): + try: + self.compute_api.revert_resize(req.environ['nova.context'], id) + except Exception, e: + LOG.exception(_("Error in revert-resize %s"), e) + return faults.Fault(exc.HTTPBadRequest()) + return exc.HTTPAccepted() + + def _action_rebuild(self, input_dict, req, id): + return faults.Fault(exc.HTTPNotImplemented()) + + def _action_resize(self, input_dict, req, id): + """ Resizes a given instance to the flavor size requested """ + try: + if 'resize' in input_dict and 'flavorId' in input_dict['resize']: + flavor_id = input_dict['resize']['flavorId'] + self.compute_api.resize(req.environ['nova.context'], id, + flavor_id) + else: + LOG.exception(_("Missing arguments for resize")) + return faults.Fault(exc.HTTPUnprocessableEntity()) + except Exception, e: + LOG.exception(_("Error in resize %s"), e) + return faults.Fault(exc.HTTPBadRequest()) + return faults.Fault(exc.HTTPAccepted()) + + def _action_reboot(self, input_dict, req, id): try: reboot_type = input_dict['reboot']['type'] except Exception: @@ -249,6 +321,34 @@ class Controller(wsgi.Controller): return faults.Fault(exc.HTTPUnprocessableEntity()) return exc.HTTPAccepted() + def reset_network(self, req, id): + """ + Reset networking on an instance (admin only). + + """ + context = req.environ['nova.context'] + try: + self.compute_api.reset_network(context, id) + except: + readable = traceback.format_exc() + LOG.exception(_("Compute.api::reset_network %s"), readable) + return faults.Fault(exc.HTTPUnprocessableEntity()) + return exc.HTTPAccepted() + + def inject_network_info(self, req, id): + """ + Inject network info for an instance (admin only). + + """ + context = req.environ['nova.context'] + try: + self.compute_api.inject_network_info(context, id) + except: + readable = traceback.format_exc() + LOG.exception(_("Compute.api::inject_network_info %s"), readable) + return faults.Fault(exc.HTTPUnprocessableEntity()) + return exc.HTTPAccepted() + def pause(self, req, id): """ Permit Admins to Pause the server. """ ctxt = req.environ['nova.context'] @@ -293,6 +393,28 @@ class Controller(wsgi.Controller): return faults.Fault(exc.HTTPUnprocessableEntity()) return exc.HTTPAccepted() + def rescue(self, req, id): + """Permit users to rescue the server.""" + context = req.environ["nova.context"] + try: + self.compute_api.rescue(context, id) + except: + readable = traceback.format_exc() + LOG.exception(_("compute.api::rescue %s"), readable) + return faults.Fault(exc.HTTPUnprocessableEntity()) + return exc.HTTPAccepted() + + def unrescue(self, req, id): + """Permit users to unrescue the server.""" + context = req.environ["nova.context"] + try: + self.compute_api.unrescue(context, id) + except: + readable = traceback.format_exc() + LOG.exception(_("compute.api::unrescue %s"), readable) + return faults.Fault(exc.HTTPUnprocessableEntity()) + return exc.HTTPAccepted() + def get_ajax_console(self, req, id): """ Returns a url to an instance's ajaxterm console. """ try: @@ -320,3 +442,37 @@ class Controller(wsgi.Controller): action=item.action, error=item.error)) return dict(actions=actions) + + def _get_kernel_ramdisk_from_image(self, req, image_id): + """Retrevies kernel and ramdisk IDs from Glance + + Only 'machine' (ami) type use kernel and ramdisk outside of the + image. + """ + # FIXME(sirp): Since we're retrieving the kernel_id from an + # image_property, this means only Glance is supported. + # The BaseImageService needs to expose a consistent way of accessing + # kernel_id and ramdisk_id + image = self._image_service.show(req.environ['nova.context'], image_id) + + if image['status'] != 'active': + raise exception.Invalid( + _("Cannot build from image %(image_id)s, status not active") % + locals()) + + if image['disk_format'] != 'ami': + return None, None + + try: + kernel_id = image['properties']['kernel_id'] + except KeyError: + raise exception.NotFound( + _("Kernel not found for image %(image_id)s") % locals()) + + try: + ramdisk_id = image['properties']['ramdisk_id'] + except KeyError: + raise exception.NotFound( + _("Ramdisk not found for image %(image_id)s") % locals()) + + return kernel_id, ramdisk_id diff --git a/nova/api/openstack/shared_ip_groups.py b/nova/api/openstack/shared_ip_groups.py index bd3cc23a8..5d78f9377 100644 --- a/nova/api/openstack/shared_ip_groups.py +++ b/nova/api/openstack/shared_ip_groups.py @@ -15,8 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -import logging - from webob import exc from nova import wsgi diff --git a/nova/api/openstack/zones.py b/nova/api/openstack/zones.py new file mode 100644 index 000000000..8fe84275a --- /dev/null +++ b/nova/api/openstack/zones.py @@ -0,0 +1,95 @@ +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import common + +from nova import flags +from nova import wsgi +from nova import db +from nova.scheduler import api + + +FLAGS = flags.FLAGS + + +def _filter_keys(item, keys): + """ + Filters all model attributes except for keys + item is a dict + + """ + return dict((k, v) for k, v in item.iteritems() if k in keys) + + +def _exclude_keys(item, keys): + return dict((k, v) for k, v in item.iteritems() if k not in keys) + + +def _scrub_zone(zone): + return _filter_keys(zone, ('id', 'api_url')) + + +class Controller(wsgi.Controller): + + _serialization_metadata = { + 'application/xml': { + "attributes": { + "zone": ["id", "api_url", "name", "capabilities"]}}} + + def index(self, req): + """Return all zones in brief""" + # Ask the ZoneManager in the Scheduler for most recent data, + # or fall-back to the database ... + items = api.API().get_zone_list(req.environ['nova.context']) + if not items: + items = db.zone_get_all(req.environ['nova.context']) + + items = common.limited(items, req) + items = [_exclude_keys(item, ['username', 'password']) + for item in items] + return dict(zones=items) + + def detail(self, req): + """Return all zones in detail""" + return self.index(req) + + def info(self, req): + """Return name and capabilities for this zone.""" + return dict(zone=dict(name=FLAGS.zone_name, + capabilities=FLAGS.zone_capabilities)) + + def show(self, req, id): + """Return data about the given zone id""" + zone_id = int(id) + zone = db.zone_get(req.environ['nova.context'], zone_id) + return dict(zone=_scrub_zone(zone)) + + def delete(self, req, id): + zone_id = int(id) + db.zone_delete(req.environ['nova.context'], zone_id) + return {} + + def create(self, req): + context = req.environ['nova.context'] + env = self._deserialize(req.body, req.get_content_type()) + zone = db.zone_create(context, env["zone"]) + return dict(zone=_scrub_zone(zone)) + + def update(self, req, id): + context = req.environ['nova.context'] + env = self._deserialize(req.body, req.get_content_type()) + zone_id = int(id) + zone = db.zone_update(context, zone_id, env["zone"]) + return dict(zone=_scrub_zone(zone)) diff --git a/nova/auth/ldapdriver.py b/nova/auth/ldapdriver.py index e652f1caa..5da7751a0 100644 --- a/nova/auth/ldapdriver.py +++ b/nova/auth/ldapdriver.py @@ -74,6 +74,25 @@ LOG = logging.getLogger("nova.ldapdriver") # in which we may want to change the interface a bit more. +def _clean(attr): + """Clean attr for insertion into ldap""" + if attr is None: + return None + if type(attr) is unicode: + return str(attr) + return attr + + +def sanitize(fn): + """Decorator to sanitize all args""" + def _wrapped(self, *args, **kwargs): + args = [_clean(x) for x in args] + kwargs = dict((k, _clean(v)) for (k, v) in kwargs) + return fn(self, *args, **kwargs) + _wrapped.func_name = fn.func_name + return _wrapped + + class LdapDriver(object): """Ldap Auth driver @@ -106,23 +125,27 @@ class LdapDriver(object): self.conn.unbind_s() return False + @sanitize def get_user(self, uid): """Retrieve user by id""" attr = self.__get_ldap_user(uid) return self.__to_user(attr) + @sanitize def get_user_from_access_key(self, access): """Retrieve user by access key""" query = '(accessKey=%s)' % access dn = FLAGS.ldap_user_subtree return self.__to_user(self.__find_object(dn, query)) + @sanitize def get_project(self, pid): """Retrieve project by id""" dn = self.__project_to_dn(pid) attr = self.__find_object(dn, LdapDriver.project_pattern) return self.__to_project(attr) + @sanitize def get_users(self): """Retrieve list of users""" attrs = self.__find_objects(FLAGS.ldap_user_subtree, @@ -134,6 +157,7 @@ class LdapDriver(object): users.append(user) return users + @sanitize def get_projects(self, uid=None): """Retrieve list of projects""" pattern = LdapDriver.project_pattern @@ -143,6 +167,7 @@ class LdapDriver(object): pattern) return [self.__to_project(attr) for attr in attrs] + @sanitize def create_user(self, name, access_key, secret_key, is_admin): """Create a user""" if self.__user_exists(name): @@ -196,6 +221,7 @@ class LdapDriver(object): self.conn.add_s(self.__uid_to_dn(name), attr) return self.__to_user(dict(attr)) + @sanitize def create_project(self, name, manager_uid, description=None, member_uids=None): """Create a project""" @@ -231,6 +257,7 @@ class LdapDriver(object): self.conn.add_s(dn, attr) return self.__to_project(dict(attr)) + @sanitize def modify_project(self, project_id, manager_uid=None, description=None): """Modify an existing project""" if not manager_uid and not description: @@ -249,21 +276,25 @@ class LdapDriver(object): dn = self.__project_to_dn(project_id) self.conn.modify_s(dn, attr) + @sanitize def add_to_project(self, uid, project_id): """Add user to project""" dn = self.__project_to_dn(project_id) return self.__add_to_group(uid, dn) + @sanitize def remove_from_project(self, uid, project_id): """Remove user from project""" dn = self.__project_to_dn(project_id) return self.__remove_from_group(uid, dn) + @sanitize def is_in_project(self, uid, project_id): """Check if user is in project""" dn = self.__project_to_dn(project_id) return self.__is_in_group(uid, dn) + @sanitize def has_role(self, uid, role, project_id=None): """Check if user has role @@ -273,6 +304,7 @@ class LdapDriver(object): role_dn = self.__role_to_dn(role, project_id) return self.__is_in_group(uid, role_dn) + @sanitize def add_role(self, uid, role, project_id=None): """Add role for user (or user and project)""" role_dn = self.__role_to_dn(role, project_id) @@ -283,11 +315,13 @@ class LdapDriver(object): else: return self.__add_to_group(uid, role_dn) + @sanitize def remove_role(self, uid, role, project_id=None): """Remove role for user (or user and project)""" role_dn = self.__role_to_dn(role, project_id) return self.__remove_from_group(uid, role_dn) + @sanitize def get_user_roles(self, uid, project_id=None): """Retrieve list of roles for user (or user and project)""" if project_id is None: @@ -307,6 +341,7 @@ class LdapDriver(object): roles = self.__find_objects(project_dn, query) return [role['cn'][0] for role in roles] + @sanitize def delete_user(self, uid): """Delete a user""" if not self.__user_exists(uid): @@ -332,12 +367,14 @@ class LdapDriver(object): # Delete entry self.conn.delete_s(self.__uid_to_dn(uid)) + @sanitize def delete_project(self, project_id): """Delete a project""" project_dn = self.__project_to_dn(project_id) self.__delete_roles(project_dn) self.__delete_group(project_dn) + @sanitize def modify_user(self, uid, access_key=None, secret_key=None, admin=None): """Modify an existing user""" if not access_key and not secret_key and admin is None: diff --git a/nova/auth/novarc.template b/nova/auth/novarc.template index c53a4acdc..cda2ecc28 100644 --- a/nova/auth/novarc.template +++ b/nova/auth/novarc.template @@ -10,7 +10,6 @@ export NOVA_CERT=${NOVA_KEY_DIR}/%(nova)s export EUCALYPTUS_CERT=${NOVA_CERT} # euca-bundle-image seems to require this set alias ec2-bundle-image="ec2-bundle-image --cert ${EC2_CERT} --privatekey ${EC2_PRIVATE_KEY} --user 42 --ec2cert ${NOVA_CERT}" alias ec2-upload-bundle="ec2-upload-bundle -a ${EC2_ACCESS_KEY} -s ${EC2_SECRET_KEY} --url ${S3_URL} --ec2cert ${NOVA_CERT}" -export CLOUD_SERVERS_API_KEY="%(access)s" -export CLOUD_SERVERS_USERNAME="%(user)s" -export CLOUD_SERVERS_URL="%(os)s" - +export NOVA_API_KEY="%(access)s" +export NOVA_USERNAME="%(user)s" +export NOVA_URL="%(os)s" diff --git a/nova/compute/api.py b/nova/compute/api.py index ac02dbcfa..f5638ba0b 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -67,10 +67,10 @@ class API(base.Base): """Get the network topic for an instance.""" try: instance = self.get(context, instance_id) - except exception.NotFound as e: + except exception.NotFound: LOG.warning(_("Instance %d was not found in get_network_topic"), instance_id) - raise e + raise host = instance['host'] if not host: @@ -85,11 +85,12 @@ class API(base.Base): min_count=1, max_count=1, display_name='', display_description='', key_name=None, key_data=None, security_group='default', - availability_zone=None, user_data=None): + availability_zone=None, user_data=None, metadata=[], + onset_files=None): """Create the number of instances requested if quota and other arguments check out ok.""" - type_data = instance_types.INSTANCE_TYPES[instance_type] + type_data = instance_types.get_instance_type(instance_type) num_instances = quota.allowed_instances(context, max_count, type_data) if num_instances < min_count: pid = context.project_id @@ -99,25 +100,48 @@ class API(base.Base): "run %s more instances of this type.") % num_instances, "InstanceLimitExceeded") - is_vpn = image_id == FLAGS.vpn_image_id - if not is_vpn: - image = self.image_service.show(context, image_id) - if kernel_id is None: - kernel_id = image.get('kernelId', None) - if ramdisk_id is None: - ramdisk_id = image.get('ramdiskId', None) - # No kernel and ramdisk for raw images - if kernel_id == str(FLAGS.null_kernel): - kernel_id = None - ramdisk_id = None - LOG.debug(_("Creating a raw instance")) - # Make sure we have access to kernel and ramdisk (if not raw) - logging.debug("Using Kernel=%s, Ramdisk=%s" % - (kernel_id, ramdisk_id)) - if kernel_id: - self.image_service.show(context, kernel_id) - if ramdisk_id: - self.image_service.show(context, ramdisk_id) + num_metadata = len(metadata) + quota_metadata = quota.allowed_metadata_items(context, num_metadata) + if quota_metadata < num_metadata: + pid = context.project_id + msg = (_("Quota exceeeded for %(pid)s," + " tried to set %(num_metadata)s metadata properties") + % locals()) + LOG.warn(msg) + raise quota.QuotaError(msg, "MetadataLimitExceeded") + + # Because metadata is stored in the DB, we hard-code the size limits + # In future, we may support more variable length strings, so we act + # as if this is quota-controlled for forwards compatibility + for metadata_item in metadata: + k = metadata_item['key'] + v = metadata_item['value'] + if len(k) > 255 or len(v) > 255: + pid = context.project_id + msg = (_("Quota exceeeded for %(pid)s," + " metadata property key or value too long") + % locals()) + LOG.warn(msg) + raise quota.QuotaError(msg, "MetadataLimitExceeded") + + image = self.image_service.show(context, image_id) + if kernel_id is None: + kernel_id = image['properties'].get('kernel_id', None) + if ramdisk_id is None: + ramdisk_id = image['properties'].get('ramdisk_id', None) + # FIXME(sirp): is there a way we can remove null_kernel? + # No kernel and ramdisk for raw images + if kernel_id == str(FLAGS.null_kernel): + kernel_id = None + ramdisk_id = None + LOG.debug(_("Creating a raw instance")) + # Make sure we have access to kernel and ramdisk (if not raw) + logging.debug("Using Kernel=%s, Ramdisk=%s" % + (kernel_id, ramdisk_id)) + if kernel_id: + self.image_service.show(context, kernel_id) + if ramdisk_id: + self.image_service.show(context, ramdisk_id) if security_group is None: security_group = ['default'] @@ -141,6 +165,7 @@ class API(base.Base): 'image_id': image_id, 'kernel_id': kernel_id or '', 'ramdisk_id': ramdisk_id or '', + 'state': 0, 'state_description': 'scheduling', 'user_id': context.user_id, 'project_id': context.project_id, @@ -155,8 +180,8 @@ class API(base.Base): 'key_name': key_name, 'key_data': key_data, 'locked': False, + 'metadata': metadata, 'availability_zone': availability_zone} - elevated = context.elevated() instances = [] LOG.debug(_("Going to run %s instances..."), num_instances) @@ -193,7 +218,8 @@ class API(base.Base): {"method": "run_instance", "args": {"topic": FLAGS.compute_topic, "instance_id": instance_id, - "availability_zone": availability_zone}}) + "availability_zone": availability_zone, + "onset_files": onset_files}}) for group_id in security_groups: self.trigger_security_group_members_refresh(elevated, group_id) @@ -293,13 +319,13 @@ class API(base.Base): LOG.debug(_("Going to try to terminate %s"), instance_id) try: instance = self.get(context, instance_id) - except exception.NotFound as e: - LOG.warning(_("Instance %d was not found during terminate"), + except exception.NotFound: + LOG.warning(_("Instance %s was not found during terminate"), instance_id) - raise e + raise if (instance['state_description'] == 'terminating'): - LOG.warning(_("Instance %d is already being terminated"), + LOG.warning(_("Instance %s is already being terminated"), instance_id) return @@ -379,6 +405,10 @@ class API(base.Base): kwargs = {'method': method, 'args': params} return rpc.call(context, queue, kwargs) + def _cast_scheduler_message(self, context, args): + """Generic handler for RPC calls to the scheduler""" + rpc.cast(context, FLAGS.scheduler_topic, args) + def snapshot(self, context, instance_id, name): """Snapshot the given instance. @@ -395,6 +425,45 @@ class API(base.Base): """Reboot the given instance.""" self._cast_compute_message('reboot_instance', context, instance_id) + def revert_resize(self, context, instance_id): + """Reverts a resize, deleting the 'new' instance in the process""" + context = context.elevated() + migration_ref = self.db.migration_get_by_instance_and_status(context, + instance_id, 'finished') + if not migration_ref: + raise exception.NotFound(_("No finished migrations found for " + "instance")) + + params = {'migration_id': migration_ref['id']} + self._cast_compute_message('revert_resize', context, instance_id, + migration_ref['dest_compute'], params=params) + + def confirm_resize(self, context, instance_id): + """Confirms a migration/resize, deleting the 'old' instance in the + process.""" + context = context.elevated() + migration_ref = self.db.migration_get_by_instance_and_status(context, + instance_id, 'finished') + if not migration_ref: + raise exception.NotFound(_("No finished migrations found for " + "instance")) + instance_ref = self.db.instance_get(context, instance_id) + params = {'migration_id': migration_ref['id']} + self._cast_compute_message('confirm_resize', context, instance_id, + migration_ref['source_compute'], params=params) + + self.db.migration_update(context, migration_id, + {'status': 'confirmed'}) + self.db.instance_update(context, instance_id, + {'host': migration_ref['dest_compute'], }) + + def resize(self, context, instance_id, flavor): + """Resize a running instance.""" + self._cast_scheduler_message(context, + {"method": "prep_resize", + "args": {"topic": FLAGS.compute_topic, + "instance_id": instance_id, }},) + def pause(self, context, instance_id): """Pause the given instance.""" self._cast_compute_message('pause_instance', context, instance_id) @@ -430,9 +499,14 @@ class API(base.Base): """Unrescue the given instance.""" self._cast_compute_message('unrescue_instance', context, instance_id) - def set_admin_password(self, context, instance_id): + def set_admin_password(self, context, instance_id, password=None): """Set the root/admin password for the given instance.""" - self._cast_compute_message('set_admin_password', context, instance_id) + self._cast_compute_message('set_admin_password', context, instance_id, + password) + + def inject_file(self, context, instance_id): + """Write a file to the given instance.""" + self._cast_compute_message('inject_file', context, instance_id) def get_ajax_console(self, context, instance_id): """Get a url to an AJAX Console""" @@ -444,7 +518,7 @@ class API(base.Base): {'method': 'authorize_ajax_console', 'args': {'token': output['token'], 'host': output['host'], 'port': output['port']}}) - return {'url': '%s?token=%s' % (FLAGS.ajax_console_proxy_url, + return {'url': '%s/?token=%s' % (FLAGS.ajax_console_proxy_url, output['token'])} def get_console_output(self, context, instance_id): @@ -466,6 +540,20 @@ class API(base.Base): instance = self.get(context, instance_id) return instance['locked'] + def reset_network(self, context, instance_id): + """ + Reset networking on the instance. + + """ + self._cast_compute_message('reset_network', context, instance_id) + + def inject_network_info(self, context, instance_id): + """ + Inject network info for the instance. + + """ + self._cast_compute_message('inject_network_info', context, instance_id) + def attach_volume(self, context, instance_id, volume_id, device): if not re.match("^/dev/[a-z]d[a-z]+$", device): raise exception.ApiError(_("Invalid device specified: %s. " diff --git a/nova/compute/instance_types.py b/nova/compute/instance_types.py index 196d6a8df..fa02a5dfa 100644 --- a/nova/compute/instance_types.py +++ b/nova/compute/instance_types.py @@ -4,6 +4,7 @@ # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. # Copyright (c) 2010 Citrix Systems, Inc. +# Copyright 2011 Ken Pepple # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -21,30 +22,120 @@ The built-in instance properties. """ -from nova import flags +from nova import context +from nova import db from nova import exception +from nova import flags +from nova import log as logging FLAGS = flags.FLAGS -INSTANCE_TYPES = { - 'm1.tiny': dict(memory_mb=512, vcpus=1, local_gb=0, flavorid=1), - 'm1.small': dict(memory_mb=2048, vcpus=1, local_gb=20, flavorid=2), - 'm1.medium': dict(memory_mb=4096, vcpus=2, local_gb=40, flavorid=3), - 'm1.large': dict(memory_mb=8192, vcpus=4, local_gb=80, flavorid=4), - 'm1.xlarge': dict(memory_mb=16384, vcpus=8, local_gb=160, flavorid=5)} +LOG = logging.getLogger('nova.instance_types') + + +def create(name, memory, vcpus, local_gb, flavorid, swap=0, + rxtx_quota=0, rxtx_cap=0): + """Creates instance types / flavors + arguments: name memory vcpus local_gb flavorid swap rxtx_quota rxtx_cap + """ + for option in [memory, vcpus, local_gb, flavorid]: + try: + int(option) + except ValueError: + raise exception.InvalidInputException( + _("create arguments must be positive integers")) + if (int(memory) <= 0) or (int(vcpus) <= 0) or (int(local_gb) < 0): + raise exception.InvalidInputException( + _("create arguments must be positive integers")) + + try: + db.instance_type_create( + context.get_admin_context(), + dict(name=name, + memory_mb=memory, + vcpus=vcpus, + local_gb=local_gb, + flavorid=flavorid, + swap=swap, + rxtx_quota=rxtx_quota, + rxtx_cap=rxtx_cap)) + except exception.DBError, e: + LOG.exception(_('DB error: %s' % e)) + raise exception.ApiError(_("Cannot create instance type: %s" % name)) + + +def destroy(name): + """Marks instance types / flavors as deleted + arguments: name""" + if name == None: + raise exception.InvalidInputException(_("No instance type specified")) + else: + try: + db.instance_type_destroy(context.get_admin_context(), name) + except exception.NotFound: + LOG.exception(_('Instance type %s not found for deletion' % name)) + raise exception.ApiError(_("Unknown instance type: %s" % name)) + + +def purge(name): + """Removes instance types / flavors from database + arguments: name""" + if name == None: + raise exception.InvalidInputException(_("No instance type specified")) + else: + try: + db.instance_type_purge(context.get_admin_context(), name) + except exception.NotFound: + LOG.exception(_('Instance type %s not found for purge' % name)) + raise exception.ApiError(_("Unknown instance type: %s" % name)) + + +def get_all_types(inactive=0): + """Retrieves non-deleted instance_types. + Pass true as argument if you want deleted instance types returned also.""" + return db.instance_type_get_all(context.get_admin_context(), inactive) + + +def get_all_flavors(): + """retrieves non-deleted flavors. alias for instance_types.get_all_types(). + Pass true as argument if you want deleted instance types returned also.""" + return get_all_types(context.get_admin_context()) + + +def get_instance_type(name): + """Retrieves single instance type by name""" + if name is None: + return FLAGS.default_instance_type + try: + ctxt = context.get_admin_context() + inst_type = db.instance_type_get_by_name(ctxt, name) + return inst_type + except exception.DBError: + raise exception.ApiError(_("Unknown instance type: %s" % name)) def get_by_type(instance_type): - """Build instance data structure and save it to the data store.""" + """retrieve instance type name""" if instance_type is None: return FLAGS.default_instance_type - if instance_type not in INSTANCE_TYPES: - raise exception.ApiError(_("Unknown instance type: %s"), - instance_type) - return instance_type + + try: + ctxt = context.get_admin_context() + inst_type = db.instance_type_get_by_name(ctxt, instance_type) + return inst_type['name'] + except exception.DBError, e: + LOG.exception(_('DB error: %s' % e)) + raise exception.ApiError(_("Unknown instance type: %s" %\ + instance_type)) def get_by_flavor_id(flavor_id): - for instance_type, details in INSTANCE_TYPES.iteritems(): - if details['flavorid'] == flavor_id: - return instance_type - return FLAGS.default_instance_type + """retrieve instance type's name by flavor_id""" + if flavor_id is None: + return FLAGS.default_instance_type + try: + ctxt = context.get_admin_context() + flavor = db.instance_type_get_by_flavor_id(ctxt, flavor_id) + return flavor['name'] + except exception.DBError, e: + LOG.exception(_('DB error: %s' % e)) + raise exception.ApiError(_("Unknown flavor: %s" % flavor_id)) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index f4418af26..b35216dd3 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -34,6 +34,7 @@ terminating it. :func:`nova.utils.import_object` """ +import base64 import datetime import random import string @@ -127,10 +128,10 @@ class ComputeManager(manager.Manager): info = self.driver.get_info(instance_ref['name']) state = info['state'] except exception.NotFound: - state = power_state.NOSTATE + state = power_state.FAILED self.db.instance_set_state(context, instance_id, state) - def get_console_topic(self, context, **_kwargs): + def get_console_topic(self, context, **kwargs): """Retrieves the console host for a project on this host Currently this is just set in the flags for each compute host.""" @@ -139,7 +140,7 @@ class ComputeManager(manager.Manager): FLAGS.console_topic, FLAGS.console_host) - def get_network_topic(self, context, **_kwargs): + def get_network_topic(self, context, **kwargs): """Retrieves the network host for a project on this host""" # TODO(vish): This method should be memoized. This will make # the call to get_network_host cheaper, so that @@ -158,21 +159,22 @@ class ComputeManager(manager.Manager): @exception.wrap_exception def refresh_security_group_rules(self, context, - security_group_id, **_kwargs): + security_group_id, **kwargs): """This call passes straight through to the virtualization driver.""" return self.driver.refresh_security_group_rules(security_group_id) @exception.wrap_exception def refresh_security_group_members(self, context, - security_group_id, **_kwargs): + security_group_id, **kwargs): """This call passes straight through to the virtualization driver.""" return self.driver.refresh_security_group_members(security_group_id) @exception.wrap_exception - def run_instance(self, context, instance_id, **_kwargs): + def run_instance(self, context, instance_id, **kwargs): """Launch a new instance with specified options.""" context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) + instance_ref.onset_files = kwargs.get('onset_files', []) if instance_ref['name'] in self.driver.list_instances(): raise exception.Error(_("Instance has already been created")) LOG.audit(_("instance %s: starting..."), instance_id, @@ -323,28 +325,43 @@ class ComputeManager(manager.Manager): """Set the root/admin password for an instance on this server.""" context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) - if instance_ref['state'] != power_state.RUNNING: - logging.warn('trying to reset the password on a non-running ' - 'instance: %s (state: %s expected: %s)', - instance_ref['id'], - instance_ref['state'], - power_state.RUNNING) - - logging.debug('instance %s: setting admin password', + instance_id = instance_ref['id'] + instance_state = instance_ref['state'] + expected_state = power_state.RUNNING + if instance_state != expected_state: + LOG.warn(_('trying to reset the password on a non-running ' + 'instance: %(instance_id)s (state: %(instance_state)s ' + 'expected: %(expected_state)s)') % locals()) + LOG.audit(_('instance %s: setting admin password'), instance_ref['name']) if new_pass is None: # Generate a random password - new_pass = self._generate_password(FLAGS.password_length) - + new_pass = utils.generate_password(FLAGS.password_length) self.driver.set_admin_password(instance_ref, new_pass) self._update_state(context, instance_id) - def _generate_password(self, length=20): - """Generate a random sequence of letters and digits - to be used as a password. - """ - chrs = string.letters + string.digits - return "".join([random.choice(chrs) for i in xrange(length)]) + @exception.wrap_exception + @checks_instance_lock + def inject_file(self, context, instance_id, path, file_contents): + """Write a file to the specified path on an instance on this server""" + context = context.elevated() + instance_ref = self.db.instance_get(context, instance_id) + instance_id = instance_ref['id'] + instance_state = instance_ref['state'] + expected_state = power_state.RUNNING + if instance_state != expected_state: + LOG.warn(_('trying to inject a file into a non-running ' + 'instance: %(instance_id)s (state: %(instance_state)s ' + 'expected: %(expected_state)s)') % locals()) + # Files/paths *should* be base64-encoded at this point, but + # double-check to make sure. + b64_path = utils.ensure_b64_encoding(path) + b64_contents = utils.ensure_b64_encoding(file_contents) + plain_path = base64.b64decode(b64_path) + nm = instance_ref['name'] + msg = _('instance %(nm)s: injecting file to %(plain_path)s') % locals() + LOG.audit(msg) + self.driver.inject_file(instance_ref, b64_path, b64_contents) @exception.wrap_exception @checks_instance_lock @@ -353,12 +370,19 @@ class ComputeManager(manager.Manager): context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) LOG.audit(_('instance %s: rescuing'), instance_id, context=context) - self.db.instance_set_state(context, - instance_id, - power_state.NOSTATE, - 'rescuing') + self.db.instance_set_state( + context, + instance_id, + power_state.NOSTATE, + 'rescuing') self.network_manager.setup_compute_network(context, instance_id) - self.driver.rescue(instance_ref) + self.driver.rescue( + instance_ref, + lambda result: self._update_state_callback( + self, + context, + instance_id, + result)) self._update_state(context, instance_id) @exception.wrap_exception @@ -368,11 +392,18 @@ class ComputeManager(manager.Manager): context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) LOG.audit(_('instance %s: unrescuing'), instance_id, context=context) - self.db.instance_set_state(context, - instance_id, - power_state.NOSTATE, - 'unrescuing') - self.driver.unrescue(instance_ref) + self.db.instance_set_state( + context, + instance_id, + power_state.NOSTATE, + 'unrescuing') + self.driver.unrescue( + instance_ref, + lambda result: self._update_state_callback( + self, + context, + instance_id, + result)) self._update_state(context, instance_id) @staticmethod @@ -382,6 +413,110 @@ class ComputeManager(manager.Manager): @exception.wrap_exception @checks_instance_lock + def confirm_resize(self, context, instance_id, migration_id): + """Destroys the source instance""" + context = context.elevated() + instance_ref = self.db.instance_get(context, instance_id) + migration_ref = self.db.migration_get(context, migration_id) + self.driver.destroy(instance_ref) + + @exception.wrap_exception + @checks_instance_lock + def revert_resize(self, context, instance_id, migration_id): + """Destroys the new instance on the destination machine, + reverts the model changes, and powers on the old + instance on the source machine""" + instance_ref = self.db.instance_get(context, instance_id) + migration_ref = self.db.migration_get(context, migration_id) + + #TODO(mdietz): we may want to split these into separate methods. + if migration_ref['source_compute'] == FLAGS.host: + self.driver._start(instance_ref) + self.db.migration_update(context, migration_id, + {'status': 'reverted'}) + else: + self.driver.destroy(instance_ref) + topic = self.db.queue_get_for(context, FLAGS.compute_topic, + instance_ref['host']) + rpc.cast(context, topic, + {'method': 'revert_resize', + 'args': { + 'migration_id': migration_ref['id'], + 'instance_id': instance_id, }, + }) + + @exception.wrap_exception + @checks_instance_lock + def prep_resize(self, context, instance_id): + """Initiates the process of moving a running instance to another + host, possibly changing the RAM and disk size in the process""" + context = context.elevated() + instance_ref = self.db.instance_get(context, instance_id) + if instance_ref['host'] == FLAGS.host: + raise exception.Error(_( + 'Migration error: destination same as source!')) + + migration_ref = self.db.migration_create(context, + {'instance_id': instance_id, + 'source_compute': instance_ref['host'], + 'dest_compute': FLAGS.host, + 'dest_host': self.driver.get_host_ip_addr(), + 'status': 'pre-migrating'}) + LOG.audit(_('instance %s: migrating to '), instance_id, + context=context) + topic = self.db.queue_get_for(context, FLAGS.compute_topic, + instance_ref['host']) + rpc.cast(context, topic, + {'method': 'resize_instance', + 'args': { + 'migration_id': migration_ref['id'], + 'instance_id': instance_id, }, + }) + + @exception.wrap_exception + @checks_instance_lock + def resize_instance(self, context, instance_id, migration_id): + """Starts the migration of a running instance to another host""" + migration_ref = self.db.migration_get(context, migration_id) + instance_ref = self.db.instance_get(context, instance_id) + self.db.migration_update(context, migration_id, + {'status': 'migrating', }) + + disk_info = self.driver.migrate_disk_and_power_off(instance_ref, + migration_ref['dest_host']) + self.db.migration_update(context, migration_id, + {'status': 'post-migrating', }) + + #TODO(mdietz): This is where we would update the VM record + #after resizing + service = self.db.service_get_by_host_and_topic(context, + migration_ref['dest_compute'], FLAGS.compute_topic) + topic = self.db.queue_get_for(context, FLAGS.compute_topic, + migration_ref['dest_compute']) + rpc.cast(context, topic, + {'method': 'finish_resize', + 'args': { + 'migration_id': migration_id, + 'instance_id': instance_id, + 'disk_info': disk_info, }, + }) + + @exception.wrap_exception + @checks_instance_lock + def finish_resize(self, context, instance_id, migration_id, disk_info): + """Completes the migration process by setting up the newly transferred + disk and turning on the instance on its new host machine""" + migration_ref = self.db.migration_get(context, migration_id) + instance_ref = self.db.instance_get(context, + migration_ref['instance_id']) + + self.driver.finish_resize(instance_ref, disk_info) + + self.db.migration_update(context, migration_id, + {'status': 'finished', }) + + @exception.wrap_exception + @checks_instance_lock def pause_instance(self, context, instance_id): """Pause an instance on this server.""" context = context.elevated() @@ -498,6 +633,30 @@ class ComputeManager(manager.Manager): instance_ref = self.db.instance_get(context, instance_id) return instance_ref['locked'] + @checks_instance_lock + def reset_network(self, context, instance_id): + """ + Reset networking on the instance. + + """ + context = context.elevated() + instance_ref = self.db.instance_get(context, instance_id) + LOG.debug(_('instance %s: reset network'), instance_id, + context=context) + self.driver.reset_network(instance_ref) + + @checks_instance_lock + def inject_network_info(self, context, instance_id): + """ + Inject network info for the instance. + + """ + context = context.elevated() + instance_ref = self.db.instance_get(context, instance_id) + LOG.debug(_('instance %s: inject network info'), instance_id, + context=context) + self.driver.inject_network_info(instance_ref) + @exception.wrap_exception def get_console_output(self, context, instance_id): """Send the console output for an instance.""" @@ -511,7 +670,7 @@ class ComputeManager(manager.Manager): def get_ajax_console(self, context, instance_id): """Return connection information for an ajax console""" context = context.elevated() - logging.debug(_("instance %s: getting ajax console"), instance_id) + LOG.debug(_("instance %s: getting ajax console"), instance_id) instance_ref = self.db.instance_get(context, instance_id) return self.driver.get_ajax_console(instance_ref) diff --git a/nova/compute/power_state.py b/nova/compute/power_state.py index 37039d2ec..adfc2dff0 100644 --- a/nova/compute/power_state.py +++ b/nova/compute/power_state.py @@ -27,6 +27,7 @@ SHUTDOWN = 0x04 SHUTOFF = 0x05 CRASHED = 0x06 SUSPENDED = 0x07 +FAILED = 0x08 def name(code): @@ -38,5 +39,6 @@ def name(code): SHUTDOWN: 'shutdown', SHUTOFF: 'shutdown', CRASHED: 'crashed', - SUSPENDED: 'suspended'} + SUSPENDED: 'suspended', + FAILED: 'failed to spawn'} return d[code] diff --git a/nova/console/manager.py b/nova/console/manager.py index 5697e7cb1..57c75cf4f 100644 --- a/nova/console/manager.py +++ b/nova/console/manager.py @@ -20,11 +20,11 @@ Console Proxy Service """ import functools -import logging import socket from nova import exception from nova import flags +from nova import log as logging from nova import manager from nova import rpc from nova import utils diff --git a/nova/console/xvp.py b/nova/console/xvp.py index ee66dac46..68d8c8565 100644 --- a/nova/console/xvp.py +++ b/nova/console/xvp.py @@ -20,7 +20,6 @@ XVP (Xenserver VNC Proxy) driver. """ import fcntl -import logging import os import signal import subprocess @@ -31,6 +30,7 @@ from nova import context from nova import db from nova import exception from nova import flags +from nova import log as logging from nova import utils flags.DEFINE_string('console_xvp_conf_template', @@ -133,10 +133,10 @@ class XVPConsoleProxy(object): return logging.debug(_("Starting xvp")) try: - utils.execute('xvp -p %s -c %s -l %s' % - (FLAGS.console_xvp_pid, - FLAGS.console_xvp_conf, - FLAGS.console_xvp_log)) + utils.execute('xvp', + '-p', FLAGS.console_xvp_pid, + '-c', FLAGS.console_xvp_conf, + '-l', FLAGS.console_xvp_log) except exception.ProcessExecutionError, err: logging.error(_("Error starting xvp: %s") % err) @@ -190,5 +190,5 @@ class XVPConsoleProxy(object): flag = '-x' #xvp will blow up on passwords that are too long (mdragon) password = password[:maxlen] - out, err = utils.execute('xvp %s' % flag, process_input=password) + out, err = utils.execute('xvp', flag, process_input=password) return out.strip() diff --git a/nova/context.py b/nova/context.py index f2669c9f1..0256bf448 100644 --- a/nova/context.py +++ b/nova/context.py @@ -28,7 +28,6 @@ from nova import utils class RequestContext(object): - def __init__(self, user, project, is_admin=None, read_deleted=False, remote_address=None, timestamp=None, request_id=None): if hasattr(user, 'id'): @@ -53,7 +52,7 @@ class RequestContext(object): self.read_deleted = read_deleted self.remote_address = remote_address if not timestamp: - timestamp = datetime.datetime.utcnow() + timestamp = utils.utcnow() if isinstance(timestamp, str) or isinstance(timestamp, unicode): timestamp = utils.parse_isotime(timestamp) self.timestamp = timestamp @@ -101,7 +100,7 @@ class RequestContext(object): return cls(**values) def elevated(self, read_deleted=False): - """Return a version of this context with admin flag set""" + """Return a version of this context with admin flag set.""" return RequestContext(self.user_id, self.project_id, True, diff --git a/nova/crypto.py b/nova/crypto.py index a34b940f5..2a8d4abca 100644 --- a/nova/crypto.py +++ b/nova/crypto.py @@ -105,8 +105,10 @@ def generate_key_pair(bits=1024): tmpdir = tempfile.mkdtemp() keyfile = os.path.join(tmpdir, 'temp') - utils.execute('ssh-keygen -q -b %d -N "" -f %s' % (bits, keyfile)) - (out, err) = utils.execute('ssh-keygen -q -l -f %s.pub' % (keyfile)) + utils.execute('ssh-keygen', '-q', '-b', bits, '-N', '', + '-f', keyfile) + (out, err) = utils.execute('ssh-keygen', '-q', '-l', '-f', + '%s.pub' % (keyfile)) fingerprint = out.split(' ')[1] private_key = open(keyfile).read() public_key = open(keyfile + '.pub').read() @@ -118,7 +120,8 @@ def generate_key_pair(bits=1024): # bio = M2Crypto.BIO.MemoryBuffer() # key.save_pub_key_bio(bio) # public_key = bio.read() - # public_key, err = execute('ssh-keygen -y -f /dev/stdin', private_key) + # public_key, err = execute('ssh-keygen', '-y', '-f', + # '/dev/stdin', private_key) return (private_key, public_key, fingerprint) @@ -143,9 +146,10 @@ def revoke_cert(project_id, file_name): start = os.getcwd() os.chdir(ca_folder(project_id)) # NOTE(vish): potential race condition here - utils.execute("openssl ca -config ./openssl.cnf -revoke '%s'" % file_name) - utils.execute("openssl ca -gencrl -config ./openssl.cnf -out '%s'" % - FLAGS.crl_file) + utils.execute('openssl', 'ca', '-config', './openssl.cnf', '-revoke', + file_name) + utils.execute('openssl', 'ca', '-gencrl', '-config', './openssl.cnf', + '-out', FLAGS.crl_file) os.chdir(start) @@ -193,9 +197,9 @@ def generate_x509_cert(user_id, project_id, bits=1024): tmpdir = tempfile.mkdtemp() keyfile = os.path.abspath(os.path.join(tmpdir, 'temp.key')) csrfile = os.path.join(tmpdir, 'temp.csr') - utils.execute("openssl genrsa -out %s %s" % (keyfile, bits)) - utils.execute("openssl req -new -key %s -out %s -batch -subj %s" % - (keyfile, csrfile, subject)) + utils.execute('openssl', 'genrsa', '-out', keyfile, str(bits)) + utils.execute('openssl', 'req', '-new', '-key', keyfile, '-out', csrfile, + '-batch', '-subj', subject) private_key = open(keyfile).read() csr = open(csrfile).read() shutil.rmtree(tmpdir) @@ -212,8 +216,8 @@ def _ensure_project_folder(project_id): if not os.path.exists(ca_path(project_id)): start = os.getcwd() os.chdir(ca_folder()) - utils.execute("sh geninter.sh %s %s" % - (project_id, _project_cert_subject(project_id))) + utils.execute('sh', 'geninter.sh', project_id, + _project_cert_subject(project_id)) os.chdir(start) @@ -228,8 +232,8 @@ def generate_vpn_files(project_id): start = os.getcwd() os.chdir(ca_folder()) # TODO(vish): the shell scripts could all be done in python - utils.execute("sh genvpn.sh %s %s" % - (project_id, _vpn_cert_subject(project_id))) + utils.execute('sh', 'genvpn.sh', + project_id, _vpn_cert_subject(project_id)) with open(csr_fn, "r") as csrfile: csr_text = csrfile.read() (serial, signed_csr) = sign_csr(csr_text, project_id) @@ -259,9 +263,10 @@ def _sign_csr(csr_text, ca_folder): start = os.getcwd() # Change working dir to CA os.chdir(ca_folder) - utils.execute("openssl ca -batch -out %s -config " - "./openssl.cnf -infiles %s" % (outbound, inbound)) - out, _err = utils.execute("openssl x509 -in %s -serial -noout" % outbound) + utils.execute('openssl', 'ca', '-batch', '-out', outbound, '-config', + './openssl.cnf', '-infiles', inbound) + out, _err = utils.execute('openssl', 'x509', '-in', outbound, + '-serial', '-noout') serial = out.rpartition("=")[2] os.chdir(start) with open(outbound, "r") as crtfile: diff --git a/nova/db/api.py b/nova/db/api.py index 789cb8ebb..aa86f0af1 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -80,13 +80,18 @@ def service_destroy(context, instance_id): def service_get(context, service_id): - """Get an service or raise if it does not exist.""" + """Get a service or raise if it does not exist.""" return IMPL.service_get(context, service_id) +def service_get_by_host_and_topic(context, host, topic): + """Get a service by host it's on and topic it listens to""" + return IMPL.service_get_by_host_and_topic(context, host, topic) + + def service_get_all(context, disabled=False): - """Get all service.""" - return IMPL.service_get_all(context, None, disabled) + """Get all services.""" + return IMPL.service_get_all(context, disabled) def service_get_all_by_topic(context, topic): @@ -254,6 +259,28 @@ def floating_ip_get_by_address(context, address): #################### +def migration_update(context, id, values): + """Update a migration instance""" + return IMPL.migration_update(context, id, values) + + +def migration_create(context, values): + """Create a migration record""" + return IMPL.migration_create(context, values) + + +def migration_get(context, migration_id): + """Finds a migration by the id""" + return IMPL.migration_get(context, migration_id) + + +def migration_get_by_instance_and_status(context, instance_id, status): + """Finds a migration by the instance id its migrating""" + return IMPL.migration_get_by_instance_and_status(context, instance_id, + status) + +#################### + def fixed_ip_associate(context, address, instance_id): """Associate fixed ip to instance. @@ -288,11 +315,21 @@ def fixed_ip_disassociate_all_by_timeout(context, host, time): return IMPL.fixed_ip_disassociate_all_by_timeout(context, host, time) +def fixed_ip_get_all(context): + """Get all defined fixed ips.""" + return IMPL.fixed_ip_get_all(context) + + def fixed_ip_get_by_address(context, address): """Get a fixed ip by address or raise if it does not exist.""" return IMPL.fixed_ip_get_by_address(context, address) +def fixed_ip_get_all_by_instance(context, instance_id): + """Get fixed ips by instance or raise if none exist.""" + return IMPL.fixed_ip_get_all_by_instance(context, instance_id) + + def fixed_ip_get_instance(context, address): """Get an instance for a fixed ip by address.""" return IMPL.fixed_ip_get_instance(context, address) @@ -480,6 +517,13 @@ def network_create_safe(context, values): return IMPL.network_create_safe(context, values) +def network_delete_safe(context, network_id): + """Delete network with key network_id. + This method assumes that the network is not associated with any project + """ + return IMPL.network_delete_safe(context, network_id) + + def network_create_fixed_ips(context, network_id, num_vpn_clients): """Create the ips for the network, reserving sepecified ips.""" return IMPL.network_create_fixed_ips(context, network_id, num_vpn_clients) @@ -500,6 +544,11 @@ def network_get(context, network_id): return IMPL.network_get(context, network_id) +def network_get_all(context): + """Return all defined networks.""" + return IMPL.network_get_all(context) + + # pylint: disable-msg=C0103 def network_get_associated_fixed_ips(context, network_id): """Get all network's ips that have been associated.""" @@ -511,11 +560,21 @@ def network_get_by_bridge(context, bridge): return IMPL.network_get_by_bridge(context, bridge) +def network_get_by_cidr(context, cidr): + """Get a network by cidr or raise if it does not exist""" + return IMPL.network_get_by_cidr(context, cidr) + + def network_get_by_instance(context, instance_id): """Get a network by instance id or raise if it does not exist.""" return IMPL.network_get_by_instance(context, instance_id) +def network_get_all_by_instance(context, instance_id): + """Get all networks by instance id or raise if none exist.""" + return IMPL.network_get_all_by_instance(context, instance_id) + + def network_get_index(context, network_id): """Get non-conflicting index for network.""" return IMPL.network_get_index(context, network_id) @@ -556,7 +615,7 @@ def project_get_network(context, project_id, associate=True): """ - return IMPL.project_get_network(context, project_id) + return IMPL.project_get_network(context, project_id, associate) def project_get_network_v6(context, project_id): @@ -610,19 +669,24 @@ def iscsi_target_create_safe(context, values): ############### -def auth_destroy_token(context, token): +def auth_token_destroy(context, token_id): """Destroy an auth token.""" - return IMPL.auth_destroy_token(context, token) + return IMPL.auth_token_destroy(context, token_id) -def auth_get_token(context, token_hash): +def auth_token_get(context, token_hash): """Retrieves a token given the hash representing it.""" - return IMPL.auth_get_token(context, token_hash) + return IMPL.auth_token_get(context, token_hash) -def auth_create_token(context, token): +def auth_token_update(context, token_hash, values): + """Updates a token given the hash representing it.""" + return IMPL.auth_token_update(context, token_hash, values) + + +def auth_token_create(context, token): """Creates a new token.""" - return IMPL.auth_create_token(context, token) + return IMPL.auth_token_create(context, token) ################### @@ -980,3 +1044,66 @@ def console_get_all_by_instance(context, instance_id): def console_get(context, console_id, instance_id=None): """Get a specific console (possibly on a given instance).""" return IMPL.console_get(context, console_id, instance_id) + + + ################## + + +def instance_type_create(context, values): + """Create a new instance type""" + return IMPL.instance_type_create(context, values) + + +def instance_type_get_all(context, inactive=0): + """Get all instance types""" + return IMPL.instance_type_get_all(context, inactive) + + +def instance_type_get_by_name(context, name): + """Get instance type by name""" + return IMPL.instance_type_get_by_name(context, name) + + +def instance_type_get_by_flavor_id(context, id): + """Get instance type by name""" + return IMPL.instance_type_get_by_flavor_id(context, id) + + +def instance_type_destroy(context, name): + """Delete a instance type""" + return IMPL.instance_type_destroy(context, name) + + +def instance_type_purge(context, name): + """Purges (removes) an instance type from DB + Use instance_type_destroy for most cases + """ + return IMPL.instance_type_purge(context, name) + + +#################### + + +def zone_create(context, values): + """Create a new child Zone entry.""" + return IMPL.zone_create(context, values) + + +def zone_update(context, zone_id, values): + """Update a child Zone entry.""" + return IMPL.zone_update(context, values) + + +def zone_delete(context, zone_id): + """Delete a child Zone.""" + return IMPL.zone_delete(context, zone_id) + + +def zone_get(context, zone_id): + """Get a specific child Zone.""" + return IMPL.zone_get(context, zone_id) + + +def zone_get_all(context): + """Get all child Zones.""" + return IMPL.zone_get_all(context) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 85250d56e..3e94082df 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -136,15 +136,12 @@ def service_get(context, service_id, session=None): @require_admin_context -def service_get_all(context, session=None, disabled=False): - if not session: - session = get_session() - - result = session.query(models.Service).\ +def service_get_all(context, disabled=False): + session = get_session() + return session.query(models.Service).\ filter_by(deleted=can_read_deleted(context)).\ filter_by(disabled=disabled).\ all() - return result @require_admin_context @@ -158,6 +155,17 @@ def service_get_all_by_topic(context, topic): @require_admin_context +def service_get_by_host_and_topic(context, host, topic): + session = get_session() + return session.query(models.Service).\ + filter_by(deleted=False).\ + filter_by(disabled=False).\ + filter_by(host=host).\ + filter_by(topic=topic).\ + first() + + +@require_admin_context def service_get_all_by_host(context, host): session = get_session() return session.query(models.Service).\ @@ -579,10 +587,21 @@ def fixed_ip_disassociate_all_by_timeout(_context, host, time): 'AND instance_id IS NOT NULL ' 'AND allocated = 0', {'host': host, - 'time': time.isoformat()}) + 'time': time}) return result.rowcount +@require_admin_context +def fixed_ip_get_all(context, session=None): + if not session: + session = get_session() + result = session.query(models.FixedIp).all() + if not result: + raise exception.NotFound(_('No fixed ips defined')) + + return result + + @require_context def fixed_ip_get_by_address(context, address, session=None): if not session: @@ -609,6 +628,17 @@ def fixed_ip_get_instance(context, address): @require_context +def fixed_ip_get_all_by_instance(context, instance_id): + session = get_session() + rv = session.query(models.FixedIp).\ + filter_by(instance_id=instance_id).\ + filter_by(deleted=False) + if not rv: + raise exception.NotFound(_('No address for instance %s') % instance_id) + return rv + + +@require_context def fixed_ip_get_instance_v6(context, address): session = get_session() mac = utils.to_mac(address) @@ -693,6 +723,7 @@ def instance_get(context, instance_id, session=None): options(joinedload_all('security_groups.rules')).\ options(joinedload('volumes')).\ options(joinedload_all('fixed_ip.network')).\ + options(joinedload('metadata')).\ filter_by(id=instance_id).\ filter_by(deleted=can_read_deleted(context)).\ first() @@ -701,6 +732,7 @@ def instance_get(context, instance_id, session=None): options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload_all('security_groups.rules')).\ options(joinedload('volumes')).\ + options(joinedload('metadata')).\ filter_by(project_id=context.project_id).\ filter_by(id=instance_id).\ filter_by(deleted=False).\ @@ -719,6 +751,7 @@ def instance_get_all(context): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(deleted=can_read_deleted(context)).\ all() @@ -729,6 +762,7 @@ def instance_get_all_by_user(context, user_id): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(deleted=can_read_deleted(context)).\ filter_by(user_id=user_id).\ all() @@ -740,6 +774,7 @@ def instance_get_all_by_host(context, host): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(host=host).\ filter_by(deleted=can_read_deleted(context)).\ all() @@ -753,6 +788,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(project_id=project_id).\ filter_by(deleted=can_read_deleted(context)).\ all() @@ -766,6 +802,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(reservation_id=reservation_id).\ filter_by(deleted=can_read_deleted(context)).\ all() @@ -773,6 +810,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance).\ options(joinedload_all('fixed_ip.floating_ips')).\ options(joinedload('security_groups')).\ + options(joinedload_all('fixed_ip.network')).\ filter_by(project_id=context.project_id).\ filter_by(reservation_id=reservation_id).\ filter_by(deleted=False).\ @@ -1017,8 +1055,18 @@ def network_create_safe(context, values): @require_admin_context +def network_delete_safe(context, network_id): + session = get_session() + with session.begin(): + network_ref = network_get(context, network_id=network_id, \ + session=session) + session.delete(network_ref) + + +@require_admin_context def network_disassociate(context, network_id): - network_update(context, network_id, {'project_id': None}) + network_update(context, network_id, {'project_id': None, + 'host': None}) @require_admin_context @@ -1050,6 +1098,15 @@ def network_get(context, network_id, session=None): return result +@require_admin_context +def network_get_all(context): + session = get_session() + result = session.query(models.Network) + if not result: + raise exception.NotFound(_('No networks defined')) + return result + + # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 @@ -1080,6 +1137,18 @@ def network_get_by_bridge(context, bridge): @require_admin_context +def network_get_by_cidr(context, cidr): + session = get_session() + result = session.query(models.Network).\ + filter_by(cidr=cidr).first() + + if not result: + raise exception.NotFound(_('Network with cidr %s does not exist') % + cidr) + return result + + +@require_admin_context def network_get_by_instance(_context, instance_id): session = get_session() rv = session.query(models.Network).\ @@ -1094,6 +1163,19 @@ def network_get_by_instance(_context, instance_id): @require_admin_context +def network_get_all_by_instance(_context, instance_id): + session = get_session() + rv = session.query(models.Network).\ + filter_by(deleted=False).\ + join(models.Network.fixed_ips).\ + filter_by(instance_id=instance_id).\ + filter_by(deleted=False) + if not rv: + raise exception.NotFound(_('No network for instance %s') % instance_id) + return rv + + +@require_admin_context def network_set_host(context, network_id, host_id): session = get_session() with session.begin(): @@ -1212,16 +1294,20 @@ def iscsi_target_create_safe(context, values): @require_admin_context -def auth_destroy_token(_context, token): +def auth_token_destroy(context, token_id): session = get_session() - session.delete(token) + with session.begin(): + token_ref = auth_token_get(context, token_id, session=session) + token_ref.delete(session=session) @require_admin_context -def auth_get_token(_context, token_hash): - session = get_session() +def auth_token_get(context, token_hash, session=None): + if session is None: + session = get_session() tk = session.query(models.AuthToken).\ filter_by(token_hash=token_hash).\ + filter_by(deleted=can_read_deleted(context)).\ first() if not tk: raise exception.NotFound(_('Token %s does not exist') % token_hash) @@ -1229,7 +1315,16 @@ def auth_get_token(_context, token_hash): @require_admin_context -def auth_create_token(_context, token): +def auth_token_update(context, token_hash, values): + session = get_session() + with session.begin(): + token_ref = auth_token_get(context, token_hash, session=session) + token_ref.update(values) + token_ref.save(session=session) + + +@require_admin_context +def auth_token_create(_context, token): tk = models.AuthToken() tk.update(token) tk.save() @@ -1909,6 +2004,51 @@ def host_get_networks(context, host): all() +################### + + +@require_admin_context +def migration_create(context, values): + migration = models.Migration() + migration.update(values) + migration.save() + return migration + + +@require_admin_context +def migration_update(context, id, values): + session = get_session() + with session.begin(): + migration = migration_get(context, id, session=session) + migration.update(values) + migration.save(session=session) + return migration + + +@require_admin_context +def migration_get(context, id, session=None): + if not session: + session = get_session() + result = session.query(models.Migration).\ + filter_by(id=id).first() + if not result: + raise exception.NotFound(_("No migration found with id %s") + % migration_id) + return result + + +@require_admin_context +def migration_get_by_instance_and_status(context, instance_id, status): + session = get_session() + result = session.query(models.Migration).\ + filter_by(instance_id=instance_id).\ + filter_by(status=status).first() + if not result: + raise exception.NotFound(_("No migration found with instance id %s") + % migration_id) + return result + + ################## @@ -2008,3 +2148,139 @@ def console_get(context, console_id, instance_id=None): raise exception.NotFound(_("No console with id %(console_id)s" " %(idesc)s") % locals()) return result + + + ################## + + +@require_admin_context +def instance_type_create(_context, values): + try: + instance_type_ref = models.InstanceTypes() + instance_type_ref.update(values) + instance_type_ref.save() + except: + raise exception.DBError + return instance_type_ref + + +@require_context +def instance_type_get_all(context, inactive=0): + """ + Returns a dict describing all instance_types with name as key. + """ + session = get_session() + if inactive: + inst_types = session.query(models.InstanceTypes).\ + order_by("name").\ + all() + else: + inst_types = session.query(models.InstanceTypes).\ + filter_by(deleted=inactive).\ + order_by("name").\ + all() + if inst_types: + inst_dict = {} + for i in inst_types: + inst_dict[i['name']] = dict(i) + return inst_dict + else: + raise exception.NotFound + + +@require_context +def instance_type_get_by_name(context, name): + """Returns a dict describing specific instance_type""" + session = get_session() + inst_type = session.query(models.InstanceTypes).\ + filter_by(name=name).\ + first() + if not inst_type: + raise exception.NotFound(_("No instance type with name %s") % name) + else: + return dict(inst_type) + + +@require_context +def instance_type_get_by_flavor_id(context, id): + """Returns a dict describing specific flavor_id""" + session = get_session() + inst_type = session.query(models.InstanceTypes).\ + filter_by(flavorid=int(id)).\ + first() + if not inst_type: + raise exception.NotFound(_("No flavor with name %s") % id) + else: + return dict(inst_type) + + +@require_admin_context +def instance_type_destroy(context, name): + """ Marks specific instance_type as deleted""" + session = get_session() + instance_type_ref = session.query(models.InstanceTypes).\ + filter_by(name=name) + records = instance_type_ref.update(dict(deleted=1)) + if records == 0: + raise exception.NotFound + else: + return instance_type_ref + + +@require_admin_context +def instance_type_purge(context, name): + """ Removes specific instance_type from DB + Usually instance_type_destroy should be used + """ + session = get_session() + instance_type_ref = session.query(models.InstanceTypes).\ + filter_by(name=name) + records = instance_type_ref.delete() + if records == 0: + raise exception.NotFound + else: + return instance_type_ref + + +#################### + + +@require_admin_context +def zone_create(context, values): + zone = models.Zone() + zone.update(values) + zone.save() + return zone + + +@require_admin_context +def zone_update(context, zone_id, values): + zone = session.query(models.Zone).filter_by(id=zone_id).first() + if not zone: + raise exception.NotFound(_("No zone with id %(zone_id)s") % locals()) + zone.update(values) + zone.save() + return zone + + +@require_admin_context +def zone_delete(context, zone_id): + session = get_session() + with session.begin(): + session.execute('delete from zones ' + 'where id=:id', {'id': zone_id}) + + +@require_admin_context +def zone_get(context, zone_id): + session = get_session() + result = session.query(models.Zone).filter_by(id=zone_id).first() + if not result: + raise exception.NotFound(_("No zone with id %(zone_id)s") % locals()) + return result + + +@require_admin_context +def zone_get_all(context): + session = get_session() + return session.query(models.Zone).all() diff --git a/nova/db/sqlalchemy/migrate_repo/versions/001_austin.py b/nova/db/sqlalchemy/migrate_repo/versions/001_austin.py index 366944591..9e7ab3554 100644 --- a/nova/db/sqlalchemy/migrate_repo/versions/001_austin.py +++ b/nova/db/sqlalchemy/migrate_repo/versions/001_austin.py @@ -508,17 +508,19 @@ def upgrade(migrate_engine): # bind migrate_engine to your metadata meta.bind = migrate_engine - for table in (auth_tokens, export_devices, fixed_ips, floating_ips, - instances, key_pairs, networks, - projects, quotas, security_groups, security_group_inst_assoc, - security_group_rules, services, users, - user_project_association, user_project_role_association, - user_role_association, volumes): + tables = [auth_tokens, + instances, key_pairs, networks, fixed_ips, floating_ips, + quotas, security_groups, security_group_inst_assoc, + security_group_rules, services, users, projects, + user_project_association, user_project_role_association, + user_role_association, volumes, export_devices] + for table in tables: try: table.create() except Exception: logging.info(repr(table)) logging.exception('Exception while creating table') + meta.drop_all(tables=tables) raise diff --git a/nova/db/sqlalchemy/migrate_repo/versions/002_bexar.py b/nova/db/sqlalchemy/migrate_repo/versions/002_bexar.py index 699b837f8..413536a59 100644 --- a/nova/db/sqlalchemy/migrate_repo/versions/002_bexar.py +++ b/nova/db/sqlalchemy/migrate_repo/versions/002_bexar.py @@ -209,13 +209,16 @@ def upgrade(migrate_engine): # Upgrade operations go here. Don't create your own engine; # bind migrate_engine to your metadata meta.bind = migrate_engine - for table in (certificates, consoles, console_pools, instance_actions, - iscsi_targets): + + tables = [certificates, console_pools, consoles, instance_actions, + iscsi_targets] + for table in tables: try: table.create() except Exception: logging.info(repr(table)) logging.exception('Exception while creating table') + meta.drop_all(tables=tables) raise auth_tokens.c.user_id.alter(type=String(length=255, diff --git a/nova/db/sqlalchemy/migrate_repo/versions/003_add_label_to_networks.py b/nova/db/sqlalchemy/migrate_repo/versions/003_add_label_to_networks.py new file mode 100644 index 000000000..5ba7910f1 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/003_add_label_to_networks.py @@ -0,0 +1,51 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + + +networks = Table('networks', meta, + Column('id', Integer(), primary_key=True, nullable=False), + ) + + +# +# New Tables +# + + +# +# Tables to alter +# + +networks_label = Column( + 'label', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False)) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + networks.create_column(networks_label) diff --git a/nova/db/sqlalchemy/migrate_repo/versions/004_add_zone_tables.py b/nova/db/sqlalchemy/migrate_repo/versions/004_add_zone_tables.py new file mode 100644 index 000000000..ade981687 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/004_add_zone_tables.py @@ -0,0 +1,61 @@ +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + + +# +# New Tables +# +zones = Table('zones', meta, + Column('created_at', DateTime(timezone=False)), + Column('updated_at', DateTime(timezone=False)), + Column('deleted_at', DateTime(timezone=False)), + Column('deleted', Boolean(create_constraint=True, name=None)), + Column('id', Integer(), primary_key=True, nullable=False), + Column('api_url', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False)), + Column('username', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False)), + Column('password', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False)), + ) + + +# +# Tables to alter +# + +# (none currently) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + for table in (zones, ): + try: + table.create() + except Exception: + logging.info(repr(table)) diff --git a/nova/db/sqlalchemy/migrate_repo/versions/005_add_instance_metadata.py b/nova/db/sqlalchemy/migrate_repo/versions/005_add_instance_metadata.py new file mode 100644 index 000000000..4cb07e0d8 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/005_add_instance_metadata.py @@ -0,0 +1,78 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + + +# Just for the ForeignKey and column creation to succeed, these are not the +# actual definitions of instances or services. +instances = Table('instances', meta, + Column('id', Integer(), primary_key=True, nullable=False), + ) + +quotas = Table('quotas', meta, + Column('id', Integer(), primary_key=True, nullable=False), + ) + + +# +# New Tables +# + +instance_metadata_table = Table('instance_metadata', meta, + Column('created_at', DateTime(timezone=False)), + Column('updated_at', DateTime(timezone=False)), + Column('deleted_at', DateTime(timezone=False)), + Column('deleted', Boolean(create_constraint=True, name=None)), + Column('id', Integer(), primary_key=True, nullable=False), + Column('instance_id', + Integer(), + ForeignKey('instances.id'), + nullable=False), + Column('key', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False)), + Column('value', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False))) + + +# +# New columns +# +quota_metadata_items = Column('metadata_items', Integer()) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + for table in (instance_metadata_table, ): + try: + table.create() + except Exception: + logging.info(repr(table)) + logging.exception('Exception while creating table') + raise + + quotas.create_column(quota_metadata_items) diff --git a/nova/db/sqlalchemy/migrate_repo/versions/006_add_provider_data_to_volumes.py b/nova/db/sqlalchemy/migrate_repo/versions/006_add_provider_data_to_volumes.py new file mode 100644 index 000000000..705fc8ff3 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/006_add_provider_data_to_volumes.py @@ -0,0 +1,72 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Justin Santa Barbara. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + + +# Table stub-definitions +# Just for the ForeignKey and column creation to succeed, these are not the +# actual definitions of instances or services. +# +volumes = Table('volumes', meta, + Column('id', Integer(), primary_key=True, nullable=False), + ) + + +# +# New Tables +# +# None + +# +# Tables to alter +# +# None + +# +# Columns to add to existing tables +# + +volumes_provider_location = Column('provider_location', + String(length=256, + convert_unicode=False, + assert_unicode=None, + unicode_error=None, + _warn_on_bytestring=False)) + +volumes_provider_auth = Column('provider_auth', + String(length=256, + convert_unicode=False, + assert_unicode=None, + unicode_error=None, + _warn_on_bytestring=False)) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + + # Add columns to existing tables + volumes.create_column(volumes_provider_location) + volumes.create_column(volumes_provider_auth) diff --git a/nova/db/sqlalchemy/migrate_repo/versions/007_add_ipv6_to_fixed_ips.py b/nova/db/sqlalchemy/migrate_repo/versions/007_add_ipv6_to_fixed_ips.py new file mode 100644 index 000000000..427934d53 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/007_add_ipv6_to_fixed_ips.py @@ -0,0 +1,90 @@ +# Copyright 2011 OpenStack LLC +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + + +# Table stub-definitions +# Just for the ForeignKey and column creation to succeed, these are not the +# actual definitions of instances or services. +# +fixed_ips = Table( + "fixed_ips", + meta, + Column( + "id", + Integer(), + primary_key=True, + nullable=False)) + +# +# New Tables +# +# None + +# +# Tables to alter +# +# None + +# +# Columns to add to existing tables +# + +fixed_ips_addressV6 = Column( + "addressV6", + String( + length=255, + convert_unicode=False, + assert_unicode=None, + unicode_error=None, + _warn_on_bytestring=False)) + + +fixed_ips_netmaskV6 = Column( + "netmaskV6", + String( + length=3, + convert_unicode=False, + assert_unicode=None, + unicode_error=None, + _warn_on_bytestring=False)) + + +fixed_ips_gatewayV6 = Column( + "gatewayV6", + String( + length=255, + convert_unicode=False, + assert_unicode=None, + unicode_error=None, + _warn_on_bytestring=False)) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + + # Add columns to existing tables + fixed_ips.create_column(fixed_ips_addressV6) + fixed_ips.create_column(fixed_ips_netmaskV6) + fixed_ips.create_column(fixed_ips_gatewayV6) diff --git a/nova/db/sqlalchemy/migrate_repo/versions/008_add_instance_types.py b/nova/db/sqlalchemy/migrate_repo/versions/008_add_instance_types.py new file mode 100644 index 000000000..66609054e --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/008_add_instance_types.py @@ -0,0 +1,87 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Ken Pepple +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import * +from migrate import * + +from nova import api +from nova import db +from nova import log as logging + +import datetime + +meta = MetaData() + + +# +# New Tables +# +instance_types = Table('instance_types', meta, + Column('created_at', DateTime(timezone=False)), + Column('updated_at', DateTime(timezone=False)), + Column('deleted_at', DateTime(timezone=False)), + Column('deleted', Boolean(create_constraint=True, name=None)), + Column('name', + String(length=255, convert_unicode=False, assert_unicode=None, + unicode_error=None, _warn_on_bytestring=False), + unique=True), + Column('id', Integer(), primary_key=True, nullable=False), + Column('memory_mb', Integer(), nullable=False), + Column('vcpus', Integer(), nullable=False), + Column('local_gb', Integer(), nullable=False), + Column('flavorid', Integer(), nullable=False, unique=True), + Column('swap', Integer(), nullable=False, default=0), + Column('rxtx_quota', Integer(), nullable=False, default=0), + Column('rxtx_cap', Integer(), nullable=False, default=0)) + + +def upgrade(migrate_engine): + # Upgrade operations go here + # Don't create your own engine; bind migrate_engine + # to your metadata + meta.bind = migrate_engine + try: + instance_types.create() + except Exception: + logging.info(repr(table)) + logging.exception('Exception while creating instance_types table') + raise + + # Here are the old static instance types + INSTANCE_TYPES = { + 'm1.tiny': dict(memory_mb=512, vcpus=1, local_gb=0, flavorid=1), + 'm1.small': dict(memory_mb=2048, vcpus=1, local_gb=20, flavorid=2), + 'm1.medium': dict(memory_mb=4096, vcpus=2, local_gb=40, flavorid=3), + 'm1.large': dict(memory_mb=8192, vcpus=4, local_gb=80, flavorid=4), + 'm1.xlarge': dict(memory_mb=16384, vcpus=8, local_gb=160, flavorid=5)} + try: + i = instance_types.insert() + for name, values in INSTANCE_TYPES.iteritems(): + # FIXME(kpepple) should we be seeding created_at / updated_at ? + # now = datetime.datatime.utcnow() + i.execute({'name': name, 'memory_mb': values["memory_mb"], + 'vcpus': values["vcpus"], 'deleted': 0, + 'local_gb': values["local_gb"], + 'flavorid': values["flavorid"]}) + except Exception: + logging.info(repr(table)) + logging.exception('Exception while seeding instance_types table') + raise + + +def downgrade(migrate_engine): + # Operations to reverse the above upgrade go here. + for table in (instance_types): + table.drop() diff --git a/nova/db/sqlalchemy/migrate_repo/versions/009_add_instance_migrations.py b/nova/db/sqlalchemy/migrate_repo/versions/009_add_instance_migrations.py new file mode 100644 index 000000000..4fda525f1 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/009_add_instance_migrations.py @@ -0,0 +1,61 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License.from sqlalchemy import * + +from sqlalchemy import * +from migrate import * + +from nova import log as logging + + +meta = MetaData() + +# Just for the ForeignKey and column creation to succeed, these are not the +# actual definitions of instances or services. +instances = Table('instances', meta, + Column('id', Integer(), primary_key=True, nullable=False), + ) + +# +# New Tables +# + +migrations = Table('migrations', meta, + Column('created_at', DateTime(timezone=False)), + Column('updated_at', DateTime(timezone=False)), + Column('deleted_at', DateTime(timezone=False)), + Column('deleted', Boolean(create_constraint=True, name=None)), + Column('id', Integer(), primary_key=True, nullable=False), + Column('source_compute', String(255)), + Column('dest_compute', String(255)), + Column('dest_host', String(255)), + Column('instance_id', Integer, ForeignKey('instances.id'), + nullable=True), + Column('status', String(255)), + ) + + +def upgrade(migrate_engine): + # Upgrade operations go here. Don't create your own engine; + # bind migrate_engine to your metadata + meta.bind = migrate_engine + for table in (migrations, ): + try: + table.create() + except Exception: + logging.info(repr(table)) + logging.exception('Exception while creating table') + raise diff --git a/nova/db/sqlalchemy/migration.py b/nova/db/sqlalchemy/migration.py index 2a13c5466..d9e303599 100644 --- a/nova/db/sqlalchemy/migration.py +++ b/nova/db/sqlalchemy/migration.py @@ -17,12 +17,22 @@ # under the License. import os +import sys from nova import flags import sqlalchemy from migrate.versioning import api as versioning_api -from migrate.versioning import exceptions as versioning_exceptions + +try: + from migrate.versioning import exceptions as versioning_exceptions +except ImportError: + try: + # python-migration changed location of exceptions after 1.6.3 + # See LP Bug #717467 + from migrate import exceptions as versioning_exceptions + except ImportError: + sys.exit(_("python-migrate is not installed. Exiting.")) FLAGS = flags.FLAGS @@ -45,12 +55,12 @@ def db_version(): engine = sqlalchemy.create_engine(FLAGS.sql_connection, echo=False) meta.reflect(bind=engine) try: - for table in ('auth_tokens', 'export_devices', 'fixed_ips', - 'floating_ips', 'instances', + for table in ('auth_tokens', 'zones', 'export_devices', + 'fixed_ips', 'floating_ips', 'instances', 'key_pairs', 'networks', 'projects', 'quotas', 'security_group_instance_association', 'security_group_rules', 'security_groups', - 'services', + 'services', 'migrations', 'users', 'user_project_association', 'user_project_role_association', 'user_role_association', diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 7efb36c0e..6ef284e65 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -126,11 +126,16 @@ class Certificate(BASE, NovaBase): class Instance(BASE, NovaBase): """Represents a guest vm.""" __tablename__ = 'instances' + onset_files = [] + id = Column(Integer, primary_key=True, autoincrement=True) @property def name(self): - return FLAGS.instance_name_template % self.id + base_name = FLAGS.instance_name_template % self.id + if getattr(self, '_rescue', False): + base_name += "-rescue" + return base_name admin_pass = Column(String(255)) user_id = Column(String(255)) @@ -210,6 +215,20 @@ class InstanceActions(BASE, NovaBase): error = Column(Text) +class InstanceTypes(BASE, NovaBase): + """Represent possible instance_types or flavor of VM offered""" + __tablename__ = "instance_types" + id = Column(Integer, primary_key=True) + name = Column(String(255), unique=True) + memory_mb = Column(Integer) + vcpus = Column(Integer) + local_gb = Column(Integer) + flavorid = Column(Integer, unique=True) + swap = Column(Integer, nullable=False, default=0) + rxtx_quota = Column(Integer, nullable=False, default=0) + rxtx_cap = Column(Integer, nullable=False, default=0) + + class Volume(BASE, NovaBase): """Represents a block storage device that can be attached to a vm.""" __tablename__ = 'volumes' @@ -243,6 +262,9 @@ class Volume(BASE, NovaBase): display_name = Column(String(255)) display_description = Column(String(255)) + provider_location = Column(String(255)) + provider_auth = Column(String(255)) + class Quota(BASE, NovaBase): """Represents quota overrides for a project.""" @@ -256,6 +278,7 @@ class Quota(BASE, NovaBase): volumes = Column(Integer) gigabytes = Column(Integer) floating_ips = Column(Integer) + metadata_items = Column(Integer) class ExportDevice(BASE, NovaBase): @@ -366,6 +389,18 @@ class KeyPair(BASE, NovaBase): public_key = Column(Text) +class Migration(BASE, NovaBase): + """Represents a running host-to-host migration.""" + __tablename__ = 'migrations' + id = Column(Integer, primary_key=True, nullable=False) + source_compute = Column(String(255)) + dest_compute = Column(String(255)) + dest_host = Column(String(255)) + instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) + #TODO(_cerberus_): enum + status = Column(String(255)) + + class Network(BASE, NovaBase): """Represents a network.""" __tablename__ = 'networks' @@ -373,6 +408,7 @@ class Network(BASE, NovaBase): "vpn_public_port"), {'mysql_engine': 'InnoDB'}) id = Column(Integer, primary_key=True) + label = Column(String(255)) injected = Column(Boolean, default=False) cidr = Column(String(255), unique=True) @@ -432,6 +468,9 @@ class FixedIp(BASE, NovaBase): allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) + addressV6 = Column(String(255)) + netmaskV6 = Column(String(3)) + gatewayV6 = Column(String(255)) class User(BASE, NovaBase): @@ -535,6 +574,29 @@ class Console(BASE, NovaBase): pool = relationship(ConsolePool, backref=backref('consoles')) +class InstanceMetadata(BASE, NovaBase): + """Represents a metadata key/value pair for an instance""" + __tablename__ = 'instance_metadata' + id = Column(Integer, primary_key=True) + key = Column(String(255)) + value = Column(String(255)) + instance_id = Column(Integer, ForeignKey('instances.id'), nullable=False) + instance = relationship(Instance, backref="metadata", + foreign_keys=instance_id, + primaryjoin='and_(' + 'InstanceMetadata.instance_id == Instance.id,' + 'InstanceMetadata.deleted == False)') + + +class Zone(BASE, NovaBase): + """Represents a child zone of this zone.""" + __tablename__ = 'zones' + id = Column(Integer, primary_key=True) + api_url = Column(String(255)) + username = Column(String(255)) + password = Column(String(255)) + + def register_models(): """Register Models and create metadata. @@ -543,11 +605,12 @@ def register_models(): connection is lost and needs to be reestablished. """ from sqlalchemy import create_engine - models = (Service, Instance, InstanceActions, + models = (Service, Instance, InstanceActions, InstanceTypes, Volume, ExportDevice, IscsiTarget, FixedIp, FloatingIp, Network, SecurityGroup, SecurityGroupIngressRule, SecurityGroupInstanceAssociation, AuthToken, User, - Project, Certificate, ConsolePool, Console) # , Image, Host + Project, Certificate, ConsolePool, Console, Zone, + InstanceMetadata, Migration) engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index dc885f138..4a9a28f43 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -20,6 +20,7 @@ Session Handling for SQLAlchemy backend """ from sqlalchemy import create_engine +from sqlalchemy import pool from sqlalchemy.orm import sessionmaker from nova import exception @@ -37,9 +38,14 @@ def get_session(autocommit=True, expire_on_commit=False): global _MAKER if not _MAKER: if not _ENGINE: + kwargs = {'pool_recycle': FLAGS.sql_idle_timeout, + 'echo': False} + + if FLAGS.sql_connection.startswith('sqlite'): + kwargs['poolclass'] = pool.NullPool + _ENGINE = create_engine(FLAGS.sql_connection, - pool_recycle=FLAGS.sql_idle_timeout, - echo=False) + **kwargs) _MAKER = (sessionmaker(bind=_ENGINE, autocommit=autocommit, expire_on_commit=expire_on_commit)) diff --git a/nova/exception.py b/nova/exception.py index 7d65bd6a5..93c5fe3d7 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -88,6 +88,10 @@ class InvalidInputException(Error): pass +class InvalidContentType(Error): + pass + + class TimeoutException(Error): pass diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py index dd82a9366..a7dee8caf 100644 --- a/nova/fakerabbit.py +++ b/nova/fakerabbit.py @@ -48,7 +48,6 @@ class Exchange(object): nm = self.name LOG.debug(_('(%(nm)s) publish (key: %(routing_key)s)' ' %(message)s') % locals()) - routing_key = routing_key.split('.')[0] if routing_key in self._routes: for f in self._routes[routing_key]: LOG.debug(_('Publishing to route %s'), f) diff --git a/nova/flags.py b/nova/flags.py index 43bc174d2..9123e9ac7 100644 --- a/nova/flags.py +++ b/nova/flags.py @@ -160,9 +160,45 @@ class StrWrapper(object): raise KeyError(name) -FLAGS = FlagValues() -gflags.FLAGS = FLAGS -gflags.DEFINE_flag(gflags.HelpFlag(), FLAGS) +# Copied from gflags with small mods to get the naming correct. +# Originally gflags checks for the first module that is not gflags that is +# in the call chain, we want to check for the first module that is not gflags +# and not this module. +def _GetCallingModule(): + """Returns the name of the module that's calling into this module. + + We generally use this function to get the name of the module calling a + DEFINE_foo... function. + """ + # Walk down the stack to find the first globals dict that's not ours. + for depth in range(1, sys.getrecursionlimit()): + if not sys._getframe(depth).f_globals is globals(): + module_name = __GetModuleName(sys._getframe(depth).f_globals) + if module_name == 'gflags': + continue + if module_name is not None: + return module_name + raise AssertionError("No module was found") + + +# Copied from gflags because it is a private function +def __GetModuleName(globals_dict): + """Given a globals dict, returns the name of the module that defines it. + + Args: + globals_dict: A dictionary that should correspond to an environment + providing the values of the globals. + + Returns: + A string (the name of the module) or None (if the module could not + be identified. + """ + for name, module in sys.modules.iteritems(): + if getattr(module, '__dict__', None) is globals_dict: + if name == '__main__': + return sys.argv[0] + return name + return None def _wrapper(func): @@ -173,6 +209,11 @@ def _wrapper(func): return _wrapped +FLAGS = FlagValues() +gflags.FLAGS = FLAGS +gflags._GetCallingModule = _GetCallingModule + + DEFINE = _wrapper(gflags.DEFINE) DEFINE_string = _wrapper(gflags.DEFINE_string) DEFINE_integer = _wrapper(gflags.DEFINE_integer) @@ -185,8 +226,6 @@ DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist) DEFINE_multistring = _wrapper(gflags.DEFINE_multistring) DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int) DEFINE_flag = _wrapper(gflags.DEFINE_flag) - - HelpFlag = gflags.HelpFlag HelpshortFlag = gflags.HelpshortFlag HelpXMLFlag = gflags.HelpXMLFlag @@ -208,7 +247,7 @@ def _get_my_ip(): (addr, port) = csock.getsockname() csock.close() return addr - except socket.gaierror as ex: + except socket.error as ex: return "127.0.0.1" @@ -282,12 +321,17 @@ DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger') DEFINE_string('state_path', os.path.join(os.path.dirname(__file__), '../'), "Top-level directory for maintaining nova's state") +DEFINE_string('lock_path', os.path.join(os.path.dirname(__file__), '../'), + "Directory for lock files") +DEFINE_string('logdir', None, 'output to a per-service log file in named ' + 'directory') +DEFINE_string('sqlite_db', 'nova.sqlite', 'file name for sqlite') DEFINE_string('sql_connection', - 'sqlite:///$state_path/nova.sqlite', + 'sqlite:///$state_path/$sqlite_db', 'connection string for sql database') -DEFINE_string('sql_idle_timeout', - '3600', +DEFINE_integer('sql_idle_timeout', + 3600, 'timeout for idle sql database connections') DEFINE_integer('sql_max_retries', 12, 'sql connection attempts') DEFINE_integer('sql_retry_interval', 10, 'sql connection retry interval') @@ -304,7 +348,7 @@ DEFINE_string('scheduler_manager', 'nova.scheduler.manager.SchedulerManager', 'Manager for scheduler') # The service to use for image search and retrieval -DEFINE_string('image_service', 'nova.image.s3.S3ImageService', +DEFINE_string('image_service', 'nova.image.local.LocalImageService', 'The service to use for retrieving and searching for images.') DEFINE_string('host', socket.gethostname(), @@ -312,3 +356,7 @@ DEFINE_string('host', socket.gethostname(), DEFINE_string('node_availability_zone', 'nova', 'availability zone of this node') + +DEFINE_string('zone_name', 'nova', 'name of this zone') +DEFINE_string('zone_capabilities', 'kypervisor:xenserver;os:linux', + 'Key/Value tags which represent capabilities of this zone') diff --git a/nova/image/glance.py b/nova/image/glance.py index 593c4bce6..15fca69b8 100644 --- a/nova/image/glance.py +++ b/nova/image/glance.py @@ -17,9 +17,8 @@ """Implementation of an image service that uses Glance as the backend""" from __future__ import absolute_import -import httplib -import json -import urlparse + +from glance.common import exception as glance_exception from nova import exception from nova import flags @@ -53,31 +52,64 @@ class GlanceImageService(service.BaseImageService): """ return self.client.get_images_detailed() - def show(self, context, id): + def show(self, context, image_id): """ Returns a dict containing image data for the given opaque image id. """ - image = self.client.get_image_meta(id) - if image: - return image - raise exception.NotFound + try: + image = self.client.get_image_meta(image_id) + except glance_exception.NotFound: + raise exception.NotFound + return image - def create(self, context, data): + def show_by_name(self, context, name): + """ + Returns a dict containing image data for the given name. + """ + # TODO(vish): replace this with more efficient call when glance + # supports it. + images = self.detail(context) + image = None + for cantidate in images: + if name == cantidate.get('name'): + image = cantidate + break + if image is None: + raise exception.NotFound + return image + + def get(self, context, image_id, data): + """ + Calls out to Glance for metadata and data and writes data. + """ + try: + metadata, image_chunks = self.client.get_image(image_id) + except glance_exception.NotFound: + raise exception.NotFound + for chunk in image_chunks: + data.write(chunk) + return metadata + + def create(self, context, metadata, data=None): """ Store the image data and return the new image id. :raises AlreadyExists if the image already exist. """ - return self.client.add_image(image_meta=data) + return self.client.add_image(metadata, data) - def update(self, context, image_id, data): + def update(self, context, image_id, metadata, data=None): """Replace the contents of the given image with the new data. :raises NotFound if the image does not exist. """ - return self.client.update_image(image_id, data) + try: + result = self.client.update_image(image_id, metadata, data) + except glance_exception.NotFound: + raise exception.NotFound + return result def delete(self, context, image_id): """ @@ -86,7 +118,11 @@ class GlanceImageService(service.BaseImageService): :raises NotFound if the image does not exist. """ - return self.client.delete_image(image_id) + try: + result = self.client.delete_image(image_id) + except glance_exception.NotFound: + raise exception.NotFound + return result def delete_all(self): """ diff --git a/nova/image/local.py b/nova/image/local.py index f78b9aa89..c4ac3baaa 100644 --- a/nova/image/local.py +++ b/nova/image/local.py @@ -15,57 +15,110 @@ # License for the specific language governing permissions and limitations # under the License. -import cPickle as pickle +import json import os.path import random -import tempfile +import shutil +from nova import flags from nova import exception from nova.image import service -class LocalImageService(service.BaseImageService): +FLAGS = flags.FLAGS +flags.DEFINE_string('images_path', '$state_path/images', + 'path to decrypted images') + +class LocalImageService(service.BaseImageService): """Image service storing images to local disk. + It assumes that image_ids are integers. """ def __init__(self): - self._path = tempfile.mkdtemp() + self._path = FLAGS.images_path - def _path_to(self, image_id): - return os.path.join(self._path, str(image_id)) + def _path_to(self, image_id, fname='info.json'): + if fname: + return os.path.join(self._path, '%08x' % int(image_id), fname) + return os.path.join(self._path, '%08x' % int(image_id)) def _ids(self): """The list of all image ids.""" - return [int(i) for i in os.listdir(self._path)] + return [int(i, 16) for i in os.listdir(self._path)] def index(self, context): - return [dict(id=i['id'], name=i['name']) for i in self.detail(context)] + return [dict(image_id=i['id'], name=i.get('name')) + for i in self.detail(context)] def detail(self, context): - return [self.show(context, id) for id in self._ids()] + images = [] + for image_id in self._ids(): + try: + image = self.show(context, image_id) + images.append(image) + except exception.NotFound: + continue + return images + + def show(self, context, image_id): + try: + with open(self._path_to(image_id)) as metadata_file: + return json.load(metadata_file) + except (IOError, ValueError): + raise exception.NotFound - def show(self, context, id): + def show_by_name(self, context, name): + """Returns a dict containing image data for the given name.""" + # NOTE(vish): Not very efficient, but the local image service + # is for testing so it should be fine. + images = self.detail(context) + image = None + for cantidate in images: + if name == cantidate.get('name'): + image = cantidate + break + if image == None: + raise exception.NotFound + return image + + def get(self, context, image_id, data): + """Get image and metadata.""" try: - return pickle.load(open(self._path_to(id))) - except IOError: + with open(self._path_to(image_id)) as metadata_file: + metadata = json.load(metadata_file) + with open(self._path_to(image_id, 'image')) as image_file: + shutil.copyfileobj(image_file, data) + except (IOError, ValueError): raise exception.NotFound + return metadata - def create(self, context, data): - """Store the image data and return the new image id.""" - id = random.randint(0, 2 ** 31 - 1) - data['id'] = id - self.update(context, id, data) - return id + def create(self, context, metadata, data=None): + """Store the image data and return the new image.""" + image_id = random.randint(0, 2 ** 31 - 1) + image_path = self._path_to(image_id, None) + if not os.path.exists(image_path): + os.mkdir(image_path) + return self.update(context, image_id, metadata, data) - def update(self, context, image_id, data): + def update(self, context, image_id, metadata, data=None): """Replace the contents of the given image with the new data.""" + metadata['id'] = image_id try: - pickle.dump(data, open(self._path_to(image_id), 'w')) - except IOError: + if data: + location = self._path_to(image_id, 'image') + with open(location, 'w') as image_file: + shutil.copyfileobj(data, image_file) + # NOTE(vish): update metadata similarly to glance + metadata['status'] = 'active' + metadata['location'] = location + with open(self._path_to(image_id), 'w') as metadata_file: + json.dump(metadata, metadata_file) + except (IOError, ValueError): raise exception.NotFound + return metadata def delete(self, context, image_id): """Delete the given image. @@ -73,18 +126,11 @@ class LocalImageService(service.BaseImageService): """ try: - os.unlink(self._path_to(image_id)) - except IOError: + shutil.rmtree(self._path_to(image_id, None)) + except (IOError, ValueError): raise exception.NotFound def delete_all(self): """Clears out all images in local directory.""" - for id in self._ids(): - os.unlink(self._path_to(id)) - - def delete_imagedir(self): - """Deletes the local directory. - Raises OSError if directory is not empty. - - """ - os.rmdir(self._path) + for image_id in self._ids(): + shutil.rmtree(self._path_to(image_id, None)) diff --git a/nova/image/s3.py b/nova/image/s3.py index 08a40f191..85a2c651c 100644 --- a/nova/image/s3.py +++ b/nova/image/s3.py @@ -21,8 +21,13 @@ Proxy AMI-related calls from the cloud controller, to the running objectstore service. """ -import json -import urllib +import binascii +import eventlet +import os +import shutil +import tarfile +import tempfile +from xml.etree import ElementTree import boto.s3.connection @@ -31,74 +36,78 @@ from nova import flags from nova import utils from nova.auth import manager from nova.image import service +from nova.api.ec2 import ec2utils FLAGS = flags.FLAGS +flags.DEFINE_string('image_decryption_dir', '/tmp', + 'parent dir for tempdir used for image decryption') class S3ImageService(service.BaseImageService): + def __init__(self, service=None, *args, **kwargs): + if service == None: + service = utils.import_object(FLAGS.image_service) + self.service = service + self.service.__init__(*args, **kwargs) - def modify(self, context, image_id, operation): - self._conn(context).make_request( - method='POST', - bucket='_images', - query_args=self._qs({'image_id': image_id, - 'operation': operation})) - return True - - def update(self, context, image_id, attributes): - """update an image's attributes / info.json""" - attributes.update({"image_id": image_id}) - self._conn(context).make_request( - method='POST', - bucket='_images', - query_args=self._qs(attributes)) - return True - - def register(self, context, image_location): - """ rpc call to register a new image based from a manifest """ - image_id = utils.generate_uid('ami') - self._conn(context).make_request( - method='PUT', - bucket='_images', - query_args=self._qs({'image_location': image_location, - 'image_id': image_id})) - return image_id - - def _fix_image_id(self, images): - """S3 has imageId but OpenStack wants id""" - for image in images: - if 'imageId' in image: - image['id'] = image['imageId'] - return images + def create(self, context, metadata, data=None): + """metadata['properties'] should contain image_location""" + image = self._s3_create(context, metadata) + return image + + def delete(self, context, image_id): + # FIXME(vish): call to show is to check filter + self.show(context, image_id) + self.service.delete(context, image_id) + + def update(self, context, image_id, metadata, data=None): + # FIXME(vish): call to show is to check filter + self.show(context, image_id) + image = self.service.update(context, image_id, metadata, data) + return image def index(self, context): - """Return a list of all images that a user can see.""" - response = self._conn(context).make_request( - method='GET', - bucket='_images') - return self._fix_image_id(json.loads(response.read())) + images = self.service.index(context) + # FIXME(vish): index doesn't filter so we do it manually + return self._filter(context, images) + + def detail(self, context): + images = self.service.detail(context) + # FIXME(vish): detail doesn't filter so we do it manually + return self._filter(context, images) + + @classmethod + def _is_visible(cls, context, image): + return (context.is_admin + or context.project_id == image['properties']['owner_id'] + or image['properties']['is_public'] == 'True') + + @classmethod + def _filter(cls, context, images): + filtered = [] + for image in images: + if not cls._is_visible(context, image): + continue + filtered.append(image) + return filtered def show(self, context, image_id): - """return a image object if the context has permissions""" - if FLAGS.connection_type == 'fake': - return {'imageId': 'bar'} - result = self.index(context) - result = [i for i in result if i['imageId'] == image_id] - if not result: - raise exception.NotFound(_('Image %s could not be found') - % image_id) - image = result[0] + image = self.service.show(context, image_id) + if not self._is_visible(context, image): + raise exception.NotFound return image - def deregister(self, context, image_id): - """ unregister an image """ - self._conn(context).make_request( - method='DELETE', - bucket='_images', - query_args=self._qs({'image_id': image_id})) + def show_by_name(self, context, name): + image = self.service.show_by_name(context, name) + if not self._is_visible(context, image): + raise exception.NotFound + return image - def _conn(self, context): + @staticmethod + def _conn(context): + # TODO(vish): is there a better way to get creds to sign + # for the user? access = manager.AuthManager().get_access_key(context.user, context.project) secret = str(context.user.secret) @@ -110,8 +119,159 @@ class S3ImageService(service.BaseImageService): port=FLAGS.s3_port, host=FLAGS.s3_host) - def _qs(self, params): - pairs = [] - for key in params.keys(): - pairs.append(key + '=' + urllib.quote(params[key])) - return '&'.join(pairs) + @staticmethod + def _download_file(bucket, filename, local_dir): + key = bucket.get_key(filename) + local_filename = os.path.join(local_dir, filename) + key.get_contents_to_filename(local_filename) + return local_filename + + def _s3_create(self, context, metadata): + """Gets a manifext from s3 and makes an image""" + + image_path = tempfile.mkdtemp(dir=FLAGS.image_decryption_dir) + + image_location = metadata['properties']['image_location'] + bucket_name = image_location.split("/")[0] + manifest_path = image_location[len(bucket_name) + 1:] + bucket = self._conn(context).get_bucket(bucket_name) + key = bucket.get_key(manifest_path) + manifest = key.get_contents_as_string() + + manifest = ElementTree.fromstring(manifest) + image_format = 'ami' + image_type = 'machine' + + try: + kernel_id = manifest.find("machine_configuration/kernel_id").text + if kernel_id == 'true': + image_format = 'aki' + image_type = 'kernel' + kernel_id = None + except Exception: + kernel_id = None + + try: + ramdisk_id = manifest.find("machine_configuration/ramdisk_id").text + if ramdisk_id == 'true': + image_format = 'ari' + image_type = 'ramdisk' + ramdisk_id = None + except Exception: + ramdisk_id = None + + try: + arch = manifest.find("machine_configuration/architecture").text + except Exception: + arch = 'x86_64' + + properties = metadata['properties'] + properties['owner_id'] = context.project_id + properties['architecture'] = arch + + if kernel_id: + properties['kernel_id'] = ec2utils.ec2_id_to_id(kernel_id) + + if ramdisk_id: + properties['ramdisk_id'] = ec2utils.ec2_id_to_id(ramdisk_id) + + properties['is_public'] = False + properties['type'] = image_type + metadata.update({'disk_format': image_format, + 'container_format': image_format, + 'status': 'queued', + 'is_public': True, + 'properties': properties}) + metadata['properties']['image_state'] = 'pending' + image = self.service.create(context, metadata) + image_id = image['id'] + + def delayed_create(): + """This handles the fetching and decrypting of the part files.""" + parts = [] + for fn_element in manifest.find("image").getiterator("filename"): + part = self._download_file(bucket, fn_element.text, image_path) + parts.append(part) + + # NOTE(vish): this may be suboptimal, should we use cat? + encrypted_filename = os.path.join(image_path, 'image.encrypted') + with open(encrypted_filename, 'w') as combined: + for filename in parts: + with open(filename) as part: + shutil.copyfileobj(part, combined) + + metadata['properties']['image_state'] = 'decrypting' + self.service.update(context, image_id, metadata) + + hex_key = manifest.find("image/ec2_encrypted_key").text + encrypted_key = binascii.a2b_hex(hex_key) + hex_iv = manifest.find("image/ec2_encrypted_iv").text + encrypted_iv = binascii.a2b_hex(hex_iv) + + # FIXME(vish): grab key from common service so this can run on + # any host. + cloud_pk = os.path.join(FLAGS.ca_path, "private/cakey.pem") + + decrypted_filename = os.path.join(image_path, 'image.tar.gz') + self._decrypt_image(encrypted_filename, encrypted_key, + encrypted_iv, cloud_pk, decrypted_filename) + + metadata['properties']['image_state'] = 'untarring' + self.service.update(context, image_id, metadata) + + unz_filename = self._untarzip_image(image_path, decrypted_filename) + + metadata['properties']['image_state'] = 'uploading' + with open(unz_filename) as image_file: + self.service.update(context, image_id, metadata, image_file) + metadata['properties']['image_state'] = 'available' + self.service.update(context, image_id, metadata) + + shutil.rmtree(image_path) + + eventlet.spawn_n(delayed_create) + + return image + + @staticmethod + def _decrypt_image(encrypted_filename, encrypted_key, encrypted_iv, + cloud_private_key, decrypted_filename): + key, err = utils.execute('openssl', + 'rsautl', + '-decrypt', + '-inkey', '%s' % cloud_private_key, + process_input=encrypted_key, + check_exit_code=False) + if err: + raise exception.Error(_("Failed to decrypt private key: %s") + % err) + iv, err = utils.execute('openssl', + 'rsautl', + '-decrypt', + '-inkey', '%s' % cloud_private_key, + process_input=encrypted_iv, + check_exit_code=False) + if err: + raise exception.Error(_("Failed to decrypt initialization " + "vector: %s") % err) + + _out, err = utils.execute('openssl', 'enc', + '-d', '-aes-128-cbc', + '-in', '%s' % (encrypted_filename,), + '-K', '%s' % (key,), + '-iv', '%s' % (iv,), + '-out', '%s' % (decrypted_filename,), + check_exit_code=False) + if err: + raise exception.Error(_("Failed to decrypt image file " + "%(image_file)s: %(err)s") % + {'image_file': encrypted_filename, + 'err': err}) + + @staticmethod + def _untarzip_image(path, filename): + tar_file = tarfile.open(filename, "r|gz") + tar_file.extractall(path) + image_file = tar_file.getnames()[0] + tar_file.close() + return os.path.join(path, image_file) diff --git a/nova/image/service.py b/nova/image/service.py index ebee2228d..c09052cab 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -56,9 +56,9 @@ class BaseImageService(object): """ raise NotImplementedError - def show(self, context, id): + def show(self, context, image_id): """ - Returns a dict containing image data for the given opaque image id. + Returns a dict containing image metadata for the given opaque image id. :retval a mapping with the following signature: @@ -76,17 +76,27 @@ class BaseImageService(object): """ raise NotImplementedError - def create(self, context, data): + def get(self, context, data): """ - Store the image data and return the new image id. + Returns a dict containing image metadata and writes image data to data. + + :param data: a file-like object to hold binary image data + + :raises NotFound if the image does not exist + """ + raise NotImplementedError + + def create(self, context, metadata, data=None): + """ + Store the image metadata and data and return the new image id. :raises AlreadyExists if the image already exist. """ raise NotImplementedError - def update(self, context, image_id, data): - """Replace the contents of the given image with the new data. + def update(self, context, image_id, metadata, data=None): + """Update the given image with the new metadata and data. :raises NotFound if the image does not exist. diff --git a/nova/log.py b/nova/log.py index b541488bd..d194ab8f0 100644 --- a/nova/log.py +++ b/nova/log.py @@ -28,9 +28,11 @@ It also allows setting of formatting information through flags. import cStringIO +import inspect import json import logging import logging.handlers +import os import sys import traceback @@ -52,7 +54,7 @@ flags.DEFINE_string('logging_default_format_string', 'format string to use for log messages without context') flags.DEFINE_string('logging_debug_format_suffix', - 'from %(processName)s (pid=%(process)d) %(funcName)s' + 'from (pid=%(process)d) %(funcName)s' ' %(pathname)s:%(lineno)d', 'data to append to log format when level is DEBUG') @@ -63,6 +65,7 @@ flags.DEFINE_string('logging_exception_prefix', flags.DEFINE_list('default_log_levels', ['amqplib=WARN', 'sqlalchemy=WARN', + 'boto=WARN', 'eventlet.wsgi.server=WARN'], 'list of logger=LEVEL pairs') @@ -92,7 +95,7 @@ critical = logging.critical log = logging.log # handlers StreamHandler = logging.StreamHandler -FileHandler = logging.FileHandler +WatchedFileHandler = logging.handlers.WatchedFileHandler # logging.SysLogHandler is nicer than logging.logging.handler.SysLogHandler. SysLogHandler = logging.handlers.SysLogHandler @@ -111,22 +114,16 @@ def _dictify_context(context): return context -def basicConfig(): - logging.basicConfig() - for handler in logging.root.handlers: - handler.setFormatter(_formatter) - if FLAGS.verbose: - logging.root.setLevel(logging.DEBUG) - else: - logging.root.setLevel(logging.INFO) - if FLAGS.use_syslog: - syslog = SysLogHandler(address='/dev/log') - syslog.setFormatter(_formatter) - logging.root.addHandler(syslog) +def _get_binary_name(): + return os.path.basename(inspect.stack()[-1][1]) + + +def _get_log_file_path(binary=None): if FLAGS.logfile: - logfile = FileHandler(FLAGS.logfile) - logfile.setFormatter(_formatter) - logging.root.addHandler(logfile) + return FLAGS.logfile + if FLAGS.logdir: + binary = binary or _get_binary_name() + return '%s.log' % (os.path.join(FLAGS.logdir, binary),) class NovaLogger(logging.Logger): @@ -136,23 +133,19 @@ class NovaLogger(logging.Logger): This becomes the class that is instanciated by logging.getLogger. """ def __init__(self, name, level=NOTSET): - level_name = self._get_level_from_flags(name, FLAGS) - level = globals()[level_name] logging.Logger.__init__(self, name, level) + self.setup_from_flags() - def _get_level_from_flags(self, name, FLAGS): - # if exactly "nova", or a child logger, honor the verbose flag - if (name == "nova" or name.startswith("nova.")) and FLAGS.verbose: - return 'DEBUG' + def setup_from_flags(self): + """Setup logger from flags""" + level = NOTSET for pair in FLAGS.default_log_levels: - logger, _sep, level = pair.partition('=') + logger, _sep, level_name = pair.partition('=') # NOTE(todd): if we set a.b, we want a.b.c to have the same level # (but not a.bc, so we check the dot) - if name == logger: - return level - if name.startswith(logger) and name[len(logger)] == '.': - return level - return 'INFO' + if self.name == logger or self.name.startswith("%s." % logger): + level = globals()[level_name] + self.setLevel(level) def _log(self, level, msg, args, exc_info=None, extra=None, context=None): """Extract context from any log call""" @@ -161,12 +154,12 @@ class NovaLogger(logging.Logger): if context: extra.update(_dictify_context(context)) extra.update({"nova_version": version.version_string_with_vcs()}) - logging.Logger._log(self, level, msg, args, exc_info, extra) + return logging.Logger._log(self, level, msg, args, exc_info, extra) def addHandler(self, handler): """Each handler gets our custom formatter""" handler.setFormatter(_formatter) - logging.Logger.addHandler(self, handler) + return logging.Logger.addHandler(self, handler) def audit(self, msg, *args, **kwargs): """Shortcut for our AUDIT level""" @@ -193,23 +186,6 @@ class NovaLogger(logging.Logger): self.error(message, **kwargs) -def handle_exception(type, value, tb): - logging.root.critical(str(value), exc_info=(type, value, tb)) - - -sys.excepthook = handle_exception -logging.setLoggerClass(NovaLogger) - - -class NovaRootLogger(NovaLogger): - pass - -if not isinstance(logging.root, NovaRootLogger): - logging.root = NovaRootLogger("nova.root", WARNING) - NovaLogger.root = logging.root - NovaLogger.manager.root = logging.root - - class NovaFormatter(logging.Formatter): """ A nova.context.RequestContext aware formatter configured through flags. @@ -256,8 +232,76 @@ class NovaFormatter(logging.Formatter): _formatter = NovaFormatter() +class NovaRootLogger(NovaLogger): + def __init__(self, name, level=NOTSET): + self.logpath = None + self.filelog = None + self.streamlog = StreamHandler() + self.syslog = None + NovaLogger.__init__(self, name, level) + + def setup_from_flags(self): + """Setup logger from flags""" + global _filelog + if FLAGS.use_syslog: + self.syslog = SysLogHandler(address='/dev/log') + self.addHandler(self.syslog) + elif self.syslog: + self.removeHandler(self.syslog) + logpath = _get_log_file_path() + if logpath: + self.removeHandler(self.streamlog) + if logpath != self.logpath: + self.removeHandler(self.filelog) + self.filelog = WatchedFileHandler(logpath) + self.addHandler(self.filelog) + self.logpath = logpath + else: + self.removeHandler(self.filelog) + self.addHandler(self.streamlog) + if FLAGS.verbose: + self.setLevel(DEBUG) + else: + self.setLevel(INFO) + + +def handle_exception(type, value, tb): + extra = {} + if FLAGS.verbose: + extra['exc_info'] = (type, value, tb) + logging.root.critical(str(value), **extra) + + +def reset(): + """Resets logging handlers. Should be called if FLAGS changes.""" + for logger in NovaLogger.manager.loggerDict.itervalues(): + if isinstance(logger, NovaLogger): + logger.setup_from_flags() + + +def setup(): + """Setup nova logging.""" + if not isinstance(logging.root, NovaRootLogger): + logging._acquireLock() + for handler in logging.root.handlers: + logging.root.removeHandler(handler) + logging.root = NovaRootLogger("nova") + NovaLogger.root = logging.root + NovaLogger.manager.root = logging.root + for logger in NovaLogger.manager.loggerDict.itervalues(): + logger.root = logging.root + if isinstance(logger, logging.Logger): + NovaLogger.manager._fixupParents(logger) + NovaLogger.manager.loggerDict["nova"] = logging.root + logging._releaseLock() + sys.excepthook = handle_exception + reset() + + +root = logging.root +logging.setLoggerClass(NovaLogger) + + def audit(msg, *args, **kwargs): """Shortcut for logging to root log with sevrity 'AUDIT'.""" - if len(logging.root.handlers) == 0: - basicConfig() logging.root.log(AUDIT, msg, *args, **kwargs) diff --git a/nova/network/api.py b/nova/network/api.py index bf43acb51..4ee1148cb 100644 --- a/nova/network/api.py +++ b/nova/network/api.py @@ -21,6 +21,7 @@ Handles all requests relating to instances (guest vms). """ from nova import db +from nova import exception from nova import flags from nova import log as logging from nova import quota diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index de0e488ae..9f9d282b6 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -17,14 +17,17 @@ Implements vlans, bridges, and iptables rules using linux utilities. """ +import inspect import os +from eventlet import semaphore + from nova import db +from nova import exception from nova import flags from nova import log as logging from nova import utils - LOG = logging.getLogger("nova.linux_net") @@ -43,7 +46,7 @@ flags.DEFINE_string('dhcp_domain', flags.DEFINE_string('networks_path', '$state_path/networks', 'Location to keep network config files') -flags.DEFINE_string('public_interface', 'vlan1', +flags.DEFINE_string('public_interface', 'eth0', 'Interface for public IP addresses') flags.DEFINE_string('vlan_interface', 'eth0', 'network device for vlans') @@ -51,8 +54,8 @@ flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'), 'location of nova-dhcpbridge') flags.DEFINE_string('routing_source_ip', '$my_ip', 'Public IP of network host') -flags.DEFINE_bool('use_nova_chains', False, - 'use the nova_ routing chains instead of default') +flags.DEFINE_string('input_chain', 'INPUT', + 'chain to add nova_input to') flags.DEFINE_string('dns_server', None, 'if set, uses specific dns server for dnsmasq') @@ -60,111 +63,379 @@ flags.DEFINE_string('dmz_cidr', '10.128.0.0/24', 'dmz range that should be accepted') +binary_name = os.path.basename(inspect.stack()[-1][1]) + + +class IptablesRule(object): + """An iptables rule + + You shouldn't need to use this class directly, it's only used by + IptablesManager + """ + def __init__(self, chain, rule, wrap=True, top=False): + self.chain = chain + self.rule = rule + self.wrap = wrap + self.top = top + + def __eq__(self, other): + return ((self.chain == other.chain) and + (self.rule == other.rule) and + (self.top == other.top) and + (self.wrap == other.wrap)) + + def __ne__(self, other): + return not self == other + + def __str__(self): + if self.wrap: + chain = '%s-%s' % (binary_name, self.chain) + else: + chain = self.chain + return '-A %s %s' % (chain, self.rule) + + +class IptablesTable(object): + """An iptables table""" + + def __init__(self): + self.rules = [] + self.chains = set() + self.unwrapped_chains = set() + + def add_chain(self, name, wrap=True): + """Adds a named chain to the table + + The chain name is wrapped to be unique for the component creating + it, so different components of Nova can safely create identically + named chains without interfering with one another. + + At the moment, its wrapped name is <binary name>-<chain name>, + so if nova-compute creates a chain named "OUTPUT", it'll actually + end up named "nova-compute-OUTPUT". + """ + if wrap: + self.chains.add(name) + else: + self.unwrapped_chains.add(name) + + def remove_chain(self, name, wrap=True): + """Remove named chain + + This removal "cascades". All rule in the chain are removed, as are + all rules in other chains that jump to it. + + If the chain is not found, this is merely logged. + """ + if wrap: + chain_set = self.chains + else: + chain_set = self.unwrapped_chains + + if name not in chain_set: + LOG.debug(_("Attempted to remove chain %s which doesn't exist"), + name) + return + + chain_set.remove(name) + self.rules = filter(lambda r: r.chain != name, self.rules) + + if wrap: + jump_snippet = '-j %s-%s' % (binary_name, name) + else: + jump_snippet = '-j %s' % (name,) + + self.rules = filter(lambda r: jump_snippet not in r.rule, self.rules) + + def add_rule(self, chain, rule, wrap=True, top=False): + """Add a rule to the table + + This is just like what you'd feed to iptables, just without + the "-A <chain name>" bit at the start. + + However, if you need to jump to one of your wrapped chains, + prepend its name with a '$' which will ensure the wrapping + is applied correctly. + """ + if wrap and chain not in self.chains: + raise ValueError(_("Unknown chain: %r") % chain) + + if '$' in rule: + rule = ' '.join(map(self._wrap_target_chain, rule.split(' '))) + + self.rules.append(IptablesRule(chain, rule, wrap, top)) + + def _wrap_target_chain(self, s): + if s.startswith('$'): + return '%s-%s' % (binary_name, s[1:]) + return s + + def remove_rule(self, chain, rule, wrap=True, top=False): + """Remove a rule from a chain + + Note: The rule must be exactly identical to the one that was added. + You cannot switch arguments around like you can with the iptables + CLI tool. + """ + try: + self.rules.remove(IptablesRule(chain, rule, wrap, top)) + except ValueError: + LOG.debug(_("Tried to remove rule that wasn't there:" + " %(chain)r %(rule)r %(wrap)r %(top)r"), + {'chain': chain, 'rule': rule, + 'top': top, 'wrap': wrap}) + + +class IptablesManager(object): + """Wrapper for iptables + + See IptablesTable for some usage docs + + A number of chains are set up to begin with. + + First, nova-filter-top. It's added at the top of FORWARD and OUTPUT. Its + name is not wrapped, so it's shared between the various nova workers. It's + intended for rules that need to live at the top of the FORWARD and OUTPUT + chains. It's in both the ipv4 and ipv6 set of tables. + + For ipv4 and ipv6, the builtin INPUT, OUTPUT, and FORWARD filter chains are + wrapped, meaning that the "real" INPUT chain has a rule that jumps to the + wrapped INPUT chain, etc. Additionally, there's a wrapped chain named + "local" which is jumped to from nova-filter-top. + + For ipv4, the builtin PREROUTING, OUTPUT, and POSTROUTING nat chains are + wrapped in the same was as the builtin filter chains. Additionally, there's + a snat chain that is applied after the POSTROUTING chain. + """ + def __init__(self, execute=None): + if not execute: + if FLAGS.fake_network: + self.execute = lambda *args, **kwargs: ('', '') + else: + self.execute = utils.execute + else: + self.execute = execute + + self.ipv4 = {'filter': IptablesTable(), + 'nat': IptablesTable()} + self.ipv6 = {'filter': IptablesTable()} + + # Add a nova-filter-top chain. It's intended to be shared + # among the various nova components. It sits at the very top + # of FORWARD and OUTPUT. + for tables in [self.ipv4, self.ipv6]: + tables['filter'].add_chain('nova-filter-top', wrap=False) + tables['filter'].add_rule('FORWARD', '-j nova-filter-top', + wrap=False, top=True) + tables['filter'].add_rule('OUTPUT', '-j nova-filter-top', + wrap=False, top=True) + + tables['filter'].add_chain('local') + tables['filter'].add_rule('nova-filter-top', '-j $local', + wrap=False) + + # Wrap the builtin chains + builtin_chains = {4: {'filter': ['INPUT', 'OUTPUT', 'FORWARD'], + 'nat': ['PREROUTING', 'OUTPUT', 'POSTROUTING']}, + 6: {'filter': ['INPUT', 'OUTPUT', 'FORWARD']}} + + for ip_version in builtin_chains: + if ip_version == 4: + tables = self.ipv4 + elif ip_version == 6: + tables = self.ipv6 + + for table, chains in builtin_chains[ip_version].iteritems(): + for chain in chains: + tables[table].add_chain(chain) + tables[table].add_rule(chain, '-j $%s' % (chain,), + wrap=False) + + # Add a nova-postrouting-bottom chain. It's intended to be shared + # among the various nova components. We set it as the last chain + # of POSTROUTING chain. + self.ipv4['nat'].add_chain('nova-postrouting-bottom', wrap=False) + self.ipv4['nat'].add_rule('POSTROUTING', '-j nova-postrouting-bottom', + wrap=False) + + # We add a snat chain to the shared nova-postrouting-bottom chain + # so that it's applied last. + self.ipv4['nat'].add_chain('snat') + self.ipv4['nat'].add_rule('nova-postrouting-bottom', '-j $snat', + wrap=False) + + # And then we add a floating-snat chain and jump to first thing in + # the snat chain. + self.ipv4['nat'].add_chain('floating-snat') + self.ipv4['nat'].add_rule('snat', '-j $floating-snat') + + self.semaphore = semaphore.Semaphore() + + @utils.synchronized('iptables') + def apply(self): + """Apply the current in-memory set of iptables rules + + This will blow away any rules left over from previous runs of the + same component of Nova, and replace them with our current set of + rules. This happens atomically, thanks to iptables-restore. + + We wrap the call in a semaphore lock, so that we don't race with + ourselves. In the event of a race with another component running + an iptables-* command at the same time, we retry up to 5 times. + """ + with self.semaphore: + s = [('iptables', self.ipv4)] + if FLAGS.use_ipv6: + s += [('ip6tables', self.ipv6)] + + for cmd, tables in s: + for table in tables: + current_table, _ = self.execute('sudo', + '%s-save' % (cmd,), + '-t', '%s' % (table,), + attempts=5) + current_lines = current_table.split('\n') + new_filter = self._modify_rules(current_lines, + tables[table]) + self.execute('sudo', '%s-restore' % (cmd,), + process_input='\n'.join(new_filter), + attempts=5) + + def _modify_rules(self, current_lines, table, binary=None): + unwrapped_chains = table.unwrapped_chains + chains = table.chains + rules = table.rules + + # Remove any trace of our rules + new_filter = filter(lambda line: binary_name not in line, + current_lines) + + seen_chains = False + rules_index = 0 + for rules_index, rule in enumerate(new_filter): + if not seen_chains: + if rule.startswith(':'): + seen_chains = True + else: + if not rule.startswith(':'): + break + + our_rules = [] + for rule in rules: + rule_str = str(rule) + if rule.top: + # rule.top == True means we want this rule to be at the top. + # Further down, we weed out duplicates from the bottom of the + # list, so here we remove the dupes ahead of time. + new_filter = filter(lambda s: s.strip() != rule_str.strip(), + new_filter) + our_rules += [rule_str] + + new_filter[rules_index:rules_index] = our_rules + + new_filter[rules_index:rules_index] = [':%s - [0:0]' % \ + (name,) \ + for name in unwrapped_chains] + new_filter[rules_index:rules_index] = [':%s-%s - [0:0]' % \ + (binary_name, name,) \ + for name in chains] + + seen_lines = set() + + def _weed_out_duplicates(line): + line = line.strip() + if line in seen_lines: + return False + else: + seen_lines.add(line) + return True + + # We filter duplicates, letting the *last* occurrence take + # precendence. + new_filter.reverse() + new_filter = filter(_weed_out_duplicates, new_filter) + new_filter.reverse() + return new_filter + + +iptables_manager = IptablesManager() + + def metadata_forward(): """Create forwarding rule for metadata""" - _confirm_rule("PREROUTING", "-t nat -s 0.0.0.0/0 " - "-d 169.254.169.254/32 -p tcp -m tcp --dport 80 -j DNAT " - "--to-destination %s:%s" % (FLAGS.ec2_dmz_host, FLAGS.ec2_port)) + iptables_manager.ipv4['nat'].add_rule("PREROUTING", + "-s 0.0.0.0/0 -d 169.254.169.254/32 " + "-p tcp -m tcp --dport 80 -j DNAT " + "--to-destination %s:%s" % \ + (FLAGS.ec2_dmz_host, FLAGS.ec2_port)) + iptables_manager.apply() def init_host(): """Basic networking setup goes here""" - - if FLAGS.use_nova_chains: - _execute("sudo iptables -N nova_input", check_exit_code=False) - _execute("sudo iptables -D %s -j nova_input" % FLAGS.input_chain, - check_exit_code=False) - _execute("sudo iptables -A %s -j nova_input" % FLAGS.input_chain) - - _execute("sudo iptables -N nova_forward", check_exit_code=False) - _execute("sudo iptables -D FORWARD -j nova_forward", - check_exit_code=False) - _execute("sudo iptables -A FORWARD -j nova_forward") - - _execute("sudo iptables -N nova_output", check_exit_code=False) - _execute("sudo iptables -D OUTPUT -j nova_output", - check_exit_code=False) - _execute("sudo iptables -A OUTPUT -j nova_output") - - _execute("sudo iptables -t nat -N nova_prerouting", - check_exit_code=False) - _execute("sudo iptables -t nat -D PREROUTING -j nova_prerouting", - check_exit_code=False) - _execute("sudo iptables -t nat -A PREROUTING -j nova_prerouting") - - _execute("sudo iptables -t nat -N nova_postrouting", - check_exit_code=False) - _execute("sudo iptables -t nat -D POSTROUTING -j nova_postrouting", - check_exit_code=False) - _execute("sudo iptables -t nat -A POSTROUTING -j nova_postrouting") - - _execute("sudo iptables -t nat -N nova_snatting", - check_exit_code=False) - _execute("sudo iptables -t nat -D POSTROUTING -j nova_snatting", - check_exit_code=False) - _execute("sudo iptables -t nat -A POSTROUTING -j nova_snatting") - - _execute("sudo iptables -t nat -N nova_output", check_exit_code=False) - _execute("sudo iptables -t nat -D OUTPUT -j nova_output", - check_exit_code=False) - _execute("sudo iptables -t nat -A OUTPUT -j nova_output") - else: - # NOTE(vish): This makes it easy to ensure snatting rules always - # come after the accept rules in the postrouting chain - _execute("sudo iptables -t nat -N SNATTING", - check_exit_code=False) - _execute("sudo iptables -t nat -D POSTROUTING -j SNATTING", - check_exit_code=False) - _execute("sudo iptables -t nat -A POSTROUTING -j SNATTING") - # NOTE(devcamcar): Cloud public SNAT entries and the default # SNAT rule for outbound traffic. - _confirm_rule("SNATTING", "-t nat -s %s " - "-j SNAT --to-source %s" - % (FLAGS.fixed_range, FLAGS.routing_source_ip), append=True) + iptables_manager.ipv4['nat'].add_rule("snat", + "-s %s -j SNAT --to-source %s" % \ + (FLAGS.fixed_range, + FLAGS.routing_source_ip)) + + iptables_manager.ipv4['nat'].add_rule("POSTROUTING", + "-s %s -d %s -j ACCEPT" % \ + (FLAGS.fixed_range, FLAGS.dmz_cidr)) - _confirm_rule("POSTROUTING", "-t nat -s %s -d %s -j ACCEPT" % - (FLAGS.fixed_range, FLAGS.dmz_cidr)) - _confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" % - {'range': FLAGS.fixed_range}) + iptables_manager.ipv4['nat'].add_rule("POSTROUTING", + "-s %(range)s -d %(range)s " + "-j ACCEPT" % \ + {'range': FLAGS.fixed_range}) + iptables_manager.apply() def bind_floating_ip(floating_ip, check_exit_code=True): """Bind ip to public interface""" - _execute("sudo ip addr add %s dev %s" % (floating_ip, - FLAGS.public_interface), + _execute('sudo', 'ip', 'addr', 'add', floating_ip, + 'dev', FLAGS.public_interface, check_exit_code=check_exit_code) def unbind_floating_ip(floating_ip): """Unbind a public ip from public interface""" - _execute("sudo ip addr del %s dev %s" % (floating_ip, - FLAGS.public_interface)) + _execute('sudo', 'ip', 'addr', 'del', floating_ip, + 'dev', FLAGS.public_interface) def ensure_vlan_forward(public_ip, port, private_ip): """Sets up forwarding rules for vlan""" - _confirm_rule("FORWARD", "-d %s -p udp --dport 1194 -j ACCEPT" % - private_ip) - _confirm_rule("PREROUTING", - "-t nat -d %s -p udp --dport %s -j DNAT --to %s:1194" - % (public_ip, port, private_ip)) + iptables_manager.ipv4['filter'].add_rule("FORWARD", + "-d %s -p udp " + "--dport 1194 " + "-j ACCEPT" % private_ip) + iptables_manager.ipv4['nat'].add_rule("PREROUTING", + "-d %s -p udp " + "--dport %s -j DNAT --to %s:1194" % + (public_ip, port, private_ip)) + iptables_manager.apply() def ensure_floating_forward(floating_ip, fixed_ip): """Ensure floating ip forwarding rule""" - _confirm_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s" - % (floating_ip, fixed_ip)) - _confirm_rule("SNATTING", "-t nat -s %s -j SNAT --to %s" - % (fixed_ip, floating_ip)) + for chain, rule in floating_forward_rules(floating_ip, fixed_ip): + iptables_manager.ipv4['nat'].add_rule(chain, rule) + iptables_manager.apply() def remove_floating_forward(floating_ip, fixed_ip): """Remove forwarding for floating ip""" - _remove_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s" - % (floating_ip, fixed_ip)) - _remove_rule("SNATTING", "-t nat -s %s -j SNAT --to %s" - % (fixed_ip, floating_ip)) + for chain, rule in floating_forward_rules(floating_ip, fixed_ip): + iptables_manager.ipv4['nat'].remove_rule(chain, rule) + iptables_manager.apply() + + +def floating_forward_rules(floating_ip, fixed_ip): + return [("PREROUTING", "-d %s -j DNAT --to %s" % (floating_ip, fixed_ip)), + ("OUTPUT", "-d %s -j DNAT --to %s" % (floating_ip, fixed_ip)), + ("floating-snat", + "-s %s -j SNAT --to %s" % (fixed_ip, floating_ip))] def ensure_vlan_bridge(vlan_num, bridge, net_attrs=None): @@ -178,47 +449,90 @@ def ensure_vlan(vlan_num): interface = "vlan%s" % vlan_num if not _device_exists(interface): LOG.debug(_("Starting VLAN inteface %s"), interface) - _execute("sudo vconfig set_name_type VLAN_PLUS_VID_NO_PAD") - _execute("sudo vconfig add %s %s" % (FLAGS.vlan_interface, vlan_num)) - _execute("sudo ifconfig %s up" % interface) + _execute('sudo', 'vconfig', 'set_name_type', 'VLAN_PLUS_VID_NO_PAD') + _execute('sudo', 'vconfig', 'add', FLAGS.vlan_interface, vlan_num) + _execute('sudo', 'ip', 'link', 'set', interface, 'up') return interface def ensure_bridge(bridge, interface, net_attrs=None): - """Create a bridge unless it already exists""" + """Create a bridge unless it already exists. + + :param interface: the interface to create the bridge on. + :param net_attrs: dictionary with attributes used to create the bridge. + + If net_attrs is set, it will add the net_attrs['gateway'] to the bridge + using net_attrs['broadcast'] and net_attrs['cidr']. It will also add + the ip_v6 address specified in net_attrs['cidr_v6'] if use_ipv6 is set. + + The code will attempt to move any ips that already exist on the interface + onto the bridge and reset the default gateway if necessary. + """ if not _device_exists(bridge): LOG.debug(_("Starting Bridge interface for %s"), interface) - _execute("sudo brctl addbr %s" % bridge) - _execute("sudo brctl setfd %s 0" % bridge) + _execute('sudo', 'brctl', 'addbr', bridge) + _execute('sudo', 'brctl', 'setfd', bridge, 0) # _execute("sudo brctl setageing %s 10" % bridge) - _execute("sudo brctl stp %s off" % bridge) - if interface: - _execute("sudo brctl addif %s %s" % (bridge, interface)) + _execute('sudo', 'brctl', 'stp', bridge, 'off') + _execute('sudo', 'ip', 'link', 'set', bridge, 'up') if net_attrs: - _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ - (bridge, - net_attrs['gateway'], - net_attrs['broadcast'], - net_attrs['netmask'])) + # NOTE(vish): The ip for dnsmasq has to be the first address on the + # bridge for it to respond to reqests properly + suffix = net_attrs['cidr'].rpartition('/')[2] + out, err = _execute('sudo', 'ip', 'addr', 'add', + "%s/%s" % + (net_attrs['gateway'], suffix), + 'brd', + net_attrs['broadcast'], + 'dev', + bridge, + check_exit_code=False) + if err and err != "RTNETLINK answers: File exists\n": + raise exception.Error("Failed to add ip: %s" % err) if(FLAGS.use_ipv6): - _execute("sudo ip -f inet6 addr change %s dev %s" % - (net_attrs['cidr_v6'], bridge)) - _execute("sudo ifconfig %s up" % bridge) - else: - _execute("sudo ifconfig %s up" % bridge) - if FLAGS.use_nova_chains: - (out, err) = _execute("sudo iptables -N nova_forward", - check_exit_code=False) - if err != 'iptables: Chain already exists.\n': - # NOTE(vish): chain didn't exist link chain - _execute("sudo iptables -D FORWARD -j nova_forward", - check_exit_code=False) - _execute("sudo iptables -A FORWARD -j nova_forward") - - _confirm_rule("FORWARD", "--in-interface %s -j ACCEPT" % bridge) - _confirm_rule("FORWARD", "--out-interface %s -j ACCEPT" % bridge) - _execute("sudo iptables -N nova-local", check_exit_code=False) - _confirm_rule("FORWARD", "-j nova-local") + _execute('sudo', 'ip', '-f', 'inet6', 'addr', + 'change', net_attrs['cidr_v6'], + 'dev', bridge) + # NOTE(vish): If the public interface is the same as the + # bridge, then the bridge has to be in promiscuous + # to forward packets properly. + if(FLAGS.public_interface == bridge): + _execute('sudo', 'ip', 'link', 'set', + 'dev', bridge, 'promisc', 'on') + if interface: + # NOTE(vish): This will break if there is already an ip on the + # interface, so we move any ips to the bridge + gateway = None + out, err = _execute('sudo', 'route', '-n') + for line in out.split("\n"): + fields = line.split() + if fields and fields[0] == "0.0.0.0" and fields[-1] == interface: + gateway = fields[1] + out, err = _execute('sudo', 'ip', 'addr', 'show', 'dev', interface, + 'scope', 'global') + for line in out.split("\n"): + fields = line.split() + if fields and fields[0] == "inet": + params = ' '.join(fields[1:-1]) + _execute('sudo', 'ip', 'addr', + 'del', params, 'dev', fields[-1]) + _execute('sudo', 'ip', 'addr', + 'add', params, 'dev', bridge) + if gateway: + _execute('sudo', 'route', 'add', '0.0.0.0', 'gw', gateway) + out, err = _execute('sudo', 'brctl', 'addif', bridge, interface, + check_exit_code=False) + + if (err and err != "device %s is already a member of a bridge; can't " + "enslave it to bridge %s.\n" % (interface, bridge)): + raise exception.Error("Failed to add interface: %s" % err) + + iptables_manager.ipv4['filter'].add_rule("FORWARD", + "--in-interface %s -j ACCEPT" % \ + bridge) + iptables_manager.ipv4['filter'].add_rule("FORWARD", + "--out-interface %s -j ACCEPT" % \ + bridge) def get_dhcp_hosts(context, network_id): @@ -252,11 +566,11 @@ def update_dhcp(context, network_id): # if dnsmasq is already running, then tell it to reload if pid: - out, _err = _execute('cat /proc/%d/cmdline' % pid, + out, _err = _execute('cat', "/proc/%d/cmdline" % pid, check_exit_code=False) if conffile in out: try: - _execute('sudo kill -HUP %d' % pid) + _execute('sudo', 'kill', '-HUP', pid) return except Exception as exc: # pylint: disable-msg=W0703 LOG.debug(_("Hupping dnsmasq threw %s"), exc) @@ -267,7 +581,7 @@ def update_dhcp(context, network_id): env = {'FLAGFILE': FLAGS.dhcpbridge_flagfile, 'DNSMASQ_INTERFACE': network_ref['bridge']} command = _dnsmasq_cmd(network_ref) - _execute(command, addl_env=env) + _execute(*command, addl_env=env) def update_ra(context, network_id): @@ -297,17 +611,17 @@ interface %s # if radvd is already running, then tell it to reload if pid: - out, _err = _execute('cat /proc/%d/cmdline' + out, _err = _execute('cat', '/proc/%d/cmdline' % pid, check_exit_code=False) if conffile in out: try: - _execute('sudo kill %d' % pid) + _execute('sudo', 'kill', pid) except Exception as exc: # pylint: disable-msg=W0703 LOG.debug(_("killing radvd threw %s"), exc) else: LOG.debug(_("Pid %d is stale, relaunching radvd"), pid) command = _ra_cmd(network_ref) - _execute(command) + _execute(*command) db.network_update(context, network_id, {"ra_server": utils.get_my_linklocal(network_ref['bridge'])}) @@ -322,67 +636,48 @@ def _host_dhcp(fixed_ip_ref): fixed_ip_ref['address']) -def _execute(cmd, *args, **kwargs): +def _execute(*cmd, **kwargs): """Wrapper around utils._execute for fake_network""" if FLAGS.fake_network: - LOG.debug("FAKE NET: %s", cmd) + LOG.debug("FAKE NET: %s", " ".join(map(str, cmd))) return "fake", 0 else: - return utils.execute(cmd, *args, **kwargs) + return utils.execute(*cmd, **kwargs) def _device_exists(device): """Check if ethernet device exists""" - (_out, err) = _execute("ifconfig %s" % device, check_exit_code=False) + (_out, err) = _execute('ip', 'link', 'show', 'dev', device, + check_exit_code=False) return not err -def _confirm_rule(chain, cmd, append=False): - """Delete and re-add iptables rule""" - if FLAGS.use_nova_chains: - chain = "nova_%s" % chain.lower() - if append: - loc = "-A" - else: - loc = "-I" - _execute("sudo iptables --delete %s %s" % (chain, cmd), - check_exit_code=False) - _execute("sudo iptables %s %s %s" % (loc, chain, cmd)) - - -def _remove_rule(chain, cmd): - """Remove iptables rule""" - if FLAGS.use_nova_chains: - chain = "%s" % chain.lower() - _execute("sudo iptables --delete %s %s" % (chain, cmd)) - - def _dnsmasq_cmd(net): """Builds dnsmasq command""" - cmd = ['sudo -E dnsmasq', - ' --strict-order', - ' --bind-interfaces', - ' --conf-file=', - ' --domain=%s' % FLAGS.dhcp_domain, - ' --pid-file=%s' % _dhcp_file(net['bridge'], 'pid'), - ' --listen-address=%s' % net['gateway'], - ' --except-interface=lo', - ' --dhcp-range=%s,static,120s' % net['dhcp_start'], - ' --dhcp-hostsfile=%s' % _dhcp_file(net['bridge'], 'conf'), - ' --dhcp-script=%s' % FLAGS.dhcpbridge, - ' --leasefile-ro'] + cmd = ['sudo', '-E', 'dnsmasq', + '--strict-order', + '--bind-interfaces', + '--conf-file=', + '--domain=%s' % FLAGS.dhcp_domain, + '--pid-file=%s' % _dhcp_file(net['bridge'], 'pid'), + '--listen-address=%s' % net['gateway'], + '--except-interface=lo', + '--dhcp-range=%s,static,120s' % net['dhcp_start'], + '--dhcp-hostsfile=%s' % _dhcp_file(net['bridge'], 'conf'), + '--dhcp-script=%s' % FLAGS.dhcpbridge, + '--leasefile-ro'] if FLAGS.dns_server: - cmd.append(' -h -R --server=%s' % FLAGS.dns_server) - return ''.join(cmd) + cmd += ['-h', '-R', '--server=%s' % FLAGS.dns_server] + return cmd def _ra_cmd(net): """Builds radvd command""" - cmd = ['sudo -E radvd', -# ' -u nobody', - ' -C %s' % _ra_file(net['bridge'], 'conf'), - ' -p %s' % _ra_file(net['bridge'], 'pid')] - return ''.join(cmd) + cmd = ['sudo', '-E', 'radvd', +# '-u', 'nobody', + '-C', '%s' % _ra_file(net['bridge'], 'conf'), + '-p', '%s' % _ra_file(net['bridge'], 'pid')] + return cmd def _stop_dnsmasq(network): @@ -391,7 +686,7 @@ def _stop_dnsmasq(network): if pid: try: - _execute('sudo kill -TERM %d' % pid) + _execute('sudo', 'kill', '-TERM', pid) except Exception as exc: # pylint: disable-msg=W0703 LOG.debug(_("Killing dnsmasq threw %s"), exc) diff --git a/nova/network/manager.py b/nova/network/manager.py index fbcbea131..3dfc48934 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -110,6 +110,7 @@ class NetworkManager(manager.Manager): This class must be subclassed to support specific topologies. """ + timeout_fixed_ips = True def __init__(self, network_driver=None, *args, **kwargs): if not network_driver: @@ -118,6 +119,10 @@ class NetworkManager(manager.Manager): super(NetworkManager, self).__init__(*args, **kwargs) def init_host(self): + """Do any initialization that needs to be run if this is a + standalone service. + """ + self.driver.init_host() # Set up networking for the projects for which we're already # the designated network host. ctxt = context.get_admin_context() @@ -134,6 +139,19 @@ class NetworkManager(manager.Manager): self.driver.ensure_floating_forward(floating_ip['address'], fixed_address) + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval.""" + super(NetworkManager, self).periodic_tasks(context) + if self.timeout_fixed_ips: + now = utils.utcnow() + timeout = FLAGS.fixed_ip_disassociate_timeout + time = now - datetime.timedelta(seconds=timeout) + num = self.db.fixed_ip_disassociate_all_by_timeout(context, + self.host, + time) + if num: + LOG.debug(_("Dissassociated %s stale fixed ip(s)"), num) + def set_network_host(self, context, network_id): """Safely sets the host of the network.""" LOG.debug(_("setting network host"), context=context) @@ -145,11 +163,22 @@ class NetworkManager(manager.Manager): def allocate_fixed_ip(self, context, instance_id, *args, **kwargs): """Gets a fixed ip from the pool.""" - raise NotImplementedError() + # TODO(vish): when this is called by compute, we can associate compute + # with a network, or a cluster of computes with a network + # and use that network here with a method like + # network_get_by_compute_host + network_ref = self.db.network_get_by_bridge(context, + FLAGS.flat_network_bridge) + address = self.db.fixed_ip_associate_pool(context.elevated(), + network_ref['id'], + instance_id) + self.db.fixed_ip_update(context, address, {'allocated': True}) + return address def deallocate_fixed_ip(self, context, address, *args, **kwargs): """Returns a fixed ip to the pool.""" - raise NotImplementedError() + self.db.fixed_ip_update(context, address, {'allocated': False}) + self.db.fixed_ip_disassociate(context.elevated(), address) def setup_fixed_ip(self, context, address): """Sets up rules for fixed ip.""" @@ -239,12 +268,58 @@ class NetworkManager(manager.Manager): def get_network_host(self, context): """Get the network host for the current context.""" - raise NotImplementedError() + network_ref = self.db.network_get_by_bridge(context, + FLAGS.flat_network_bridge) + # NOTE(vish): If the network has no host, use the network_host flag. + # This could eventually be a a db lookup of some sort, but + # a flag is easy to handle for now. + host = network_ref['host'] + if not host: + topic = self.db.queue_get_for(context, + FLAGS.network_topic, + FLAGS.network_host) + if FLAGS.fake_call: + return self.set_network_host(context, network_ref['id']) + host = rpc.call(context, + FLAGS.network_topic, + {"method": "set_network_host", + "args": {"network_id": network_ref['id']}}) + return host def create_networks(self, context, cidr, num_networks, network_size, - cidr_v6, *args, **kwargs): + cidr_v6, label, *args, **kwargs): """Create networks based on parameters.""" - raise NotImplementedError() + fixed_net = IPy.IP(cidr) + fixed_net_v6 = IPy.IP(cidr_v6) + significant_bits_v6 = 64 + count = 1 + for index in range(num_networks): + start = index * network_size + significant_bits = 32 - int(math.log(network_size, 2)) + cidr = "%s/%s" % (fixed_net[start], significant_bits) + project_net = IPy.IP(cidr) + net = {} + net['bridge'] = FLAGS.flat_network_bridge + net['dns'] = FLAGS.flat_network_dns + net['cidr'] = cidr + net['netmask'] = str(project_net.netmask()) + net['gateway'] = str(project_net[1]) + net['broadcast'] = str(project_net.broadcast()) + net['dhcp_start'] = str(project_net[2]) + if num_networks > 1: + net['label'] = "%s_%d" % (label, count) + else: + net['label'] = label + count += 1 + + if(FLAGS.use_ipv6): + cidr_v6 = "%s/%s" % (fixed_net_v6[0], significant_bits_v6) + net['cidr_v6'] = cidr_v6 + + network_ref = self.db.network_create_safe(context, net) + + if network_ref: + self._create_fixed_ips(context, network_ref['id']) @property def _bottom_reserved_ips(self): # pylint: disable-msg=R0201 @@ -302,78 +377,22 @@ class FlatManager(NetworkManager): not do any setup in this mode, it must be done manually. Requests to 169.254.169.254 port 80 will need to be forwarded to the api server. """ + timeout_fixed_ips = False - def allocate_fixed_ip(self, context, instance_id, *args, **kwargs): - """Gets a fixed ip from the pool.""" - # TODO(vish): when this is called by compute, we can associate compute - # with a network, or a cluster of computes with a network - # and use that network here with a method like - # network_get_by_compute_host - network_ref = self.db.network_get_by_bridge(context, - FLAGS.flat_network_bridge) - address = self.db.fixed_ip_associate_pool(context.elevated(), - network_ref['id'], - instance_id) - self.db.fixed_ip_update(context, address, {'allocated': True}) - return address - - def deallocate_fixed_ip(self, context, address, *args, **kwargs): - """Returns a fixed ip to the pool.""" - self.db.fixed_ip_update(context, address, {'allocated': False}) - self.db.fixed_ip_disassociate(context.elevated(), address) + def init_host(self): + """Do any initialization that needs to be run if this is a + standalone service. + """ + #Fix for bug 723298 - do not call init_host on superclass + #Following code has been copied for NetworkManager.init_host + ctxt = context.get_admin_context() + for network in self.db.host_get_networks(ctxt, self.host): + self._on_set_network_host(ctxt, network['id']) def setup_compute_network(self, context, instance_id): """Network is created manually.""" pass - def create_networks(self, context, cidr, num_networks, network_size, - cidr_v6, *args, **kwargs): - """Create networks based on parameters.""" - fixed_net = IPy.IP(cidr) - fixed_net_v6 = IPy.IP(cidr_v6) - significant_bits_v6 = 64 - for index in range(num_networks): - start = index * network_size - significant_bits = 32 - int(math.log(network_size, 2)) - cidr = "%s/%s" % (fixed_net[start], significant_bits) - project_net = IPy.IP(cidr) - net = {} - net['bridge'] = FLAGS.flat_network_bridge - net['cidr'] = cidr - net['netmask'] = str(project_net.netmask()) - net['gateway'] = str(project_net[1]) - net['broadcast'] = str(project_net.broadcast()) - net['dhcp_start'] = str(project_net[2]) - - if(FLAGS.use_ipv6): - cidr_v6 = "%s/%s" % (fixed_net_v6[0], significant_bits_v6) - net['cidr_v6'] = cidr_v6 - - network_ref = self.db.network_create_safe(context, net) - - if network_ref: - self._create_fixed_ips(context, network_ref['id']) - - def get_network_host(self, context): - """Get the network host for the current context.""" - network_ref = self.db.network_get_by_bridge(context, - FLAGS.flat_network_bridge) - # NOTE(vish): If the network has no host, use the network_host flag. - # This could eventually be a a db lookup of some sort, but - # a flag is easy to handle for now. - host = network_ref['host'] - if not host: - topic = self.db.queue_get_for(context, - FLAGS.network_topic, - FLAGS.network_host) - if FLAGS.fake_call: - return self.set_network_host(context, network_ref['id']) - host = rpc.call(context, - FLAGS.network_topic, - {"method": "set_network_host", - "args": {"network_id": network_ref['id']}}) - return host - def _on_set_network_host(self, context, network_id): """Called when this host becomes the host for a network.""" net = {} @@ -381,8 +400,24 @@ class FlatManager(NetworkManager): net['dns'] = FLAGS.flat_network_dns self.db.network_update(context, network_id, net) + def allocate_floating_ip(self, context, project_id): + #Fix for bug 723298 + raise NotImplementedError() + + def associate_floating_ip(self, context, floating_address, fixed_address): + #Fix for bug 723298 + raise NotImplementedError() + + def disassociate_floating_ip(self, context, floating_address): + #Fix for bug 723298 + raise NotImplementedError() + + def deallocate_floating_ip(self, context, floating_address): + #Fix for bug 723298 + raise NotImplementedError() + -class FlatDHCPManager(FlatManager): +class FlatDHCPManager(NetworkManager): """Flat networking with dhcp. FlatDHCPManager will start up one dhcp server to give out addresses. @@ -395,7 +430,6 @@ class FlatDHCPManager(FlatManager): standalone service. """ super(FlatDHCPManager, self).init_host() - self.driver.init_host() self.driver.metadata_forward() def setup_compute_network(self, context, instance_id): @@ -448,24 +482,11 @@ class VlanManager(NetworkManager): instances in its subnet. """ - def periodic_tasks(self, context=None): - """Tasks to be run at a periodic interval.""" - super(VlanManager, self).periodic_tasks(context) - now = datetime.datetime.utcnow() - timeout = FLAGS.fixed_ip_disassociate_timeout - time = now - datetime.timedelta(seconds=timeout) - num = self.db.fixed_ip_disassociate_all_by_timeout(context, - self.host, - time) - if num: - LOG.debug(_("Dissassociated %s stale fixed ip(s)"), num) - def init_host(self): """Do any initialization that needs to be run if this is a standalone service. """ super(VlanManager, self).init_host() - self.driver.init_host() self.driver.metadata_forward() def allocate_fixed_ip(self, context, instance_id, *args, **kwargs): @@ -501,9 +522,20 @@ class VlanManager(NetworkManager): network_ref['bridge']) def create_networks(self, context, cidr, num_networks, network_size, - cidr_v6, vlan_start, vpn_start): + cidr_v6, vlan_start, vpn_start, **kwargs): """Create networks based on parameters.""" + # Check that num_networks + vlan_start is not > 4094, fixes lp708025 + if num_networks + vlan_start > 4094: + raise ValueError(_('The sum between the number of networks and' + ' the vlan start cannot be greater' + ' than 4094')) + fixed_net = IPy.IP(cidr) + if fixed_net.len() < num_networks * network_size: + raise ValueError(_('The network range is not big enough to fit ' + '%(num_networks)s. Network size is %(network_size)s' % + locals())) + fixed_net_v6 = IPy.IP(cidr_v6) network_size_v6 = 1 << 64 significant_bits_v6 = 64 @@ -531,6 +563,16 @@ class VlanManager(NetworkManager): # NOTE(vish): This makes ports unique accross the cloud, a more # robust solution would be to make them unique per ip net['vpn_public_port'] = vpn_start + index + network_ref = None + try: + network_ref = db.network_get_by_cidr(context, cidr) + except exception.NotFound: + pass + + if network_ref is not None: + raise ValueError(_('Network with cidr %s already exists' % + cidr)) + network_ref = self.db.network_create_safe(context, net) if network_ref: self._create_fixed_ips(context, network_ref['id']) diff --git a/nova/objectstore/bucket.py b/nova/objectstore/bucket.py index 82767e52f..b213e18e8 100644 --- a/nova/objectstore/bucket.py +++ b/nova/objectstore/bucket.py @@ -107,7 +107,7 @@ class Bucket(object): def is_authorized(self, context): try: - return context.user.is_admin() or \ + return context.is_admin or \ self.owner_id == context.project_id except Exception, e: return False diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py index 41e0abd80..c90b5b54b 100644 --- a/nova/objectstore/image.py +++ b/nova/objectstore/image.py @@ -37,8 +37,7 @@ from nova.objectstore import bucket FLAGS = flags.FLAGS -flags.DEFINE_string('images_path', '$state_path/images', - 'path to decrypted images') +flags.DECLARE('images_path', 'nova.image.local') class Image(object): @@ -69,7 +68,7 @@ class Image(object): # but only modified by admin or owner. try: return (self.metadata['isPublic'] and readonly) or \ - context.user.is_admin() or \ + context.is_admin or \ self.metadata['imageOwnerId'] == context.project_id except: return False @@ -254,25 +253,34 @@ class Image(object): @staticmethod def decrypt_image(encrypted_filename, encrypted_key, encrypted_iv, cloud_private_key, decrypted_filename): - key, err = utils.execute( - 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, - process_input=encrypted_key, - check_exit_code=False) + key, err = utils.execute('openssl', + 'rsautl', + '-decrypt', + '-inkey', '%s' % cloud_private_key, + process_input=encrypted_key, + check_exit_code=False) if err: raise exception.Error(_("Failed to decrypt private key: %s") % err) - iv, err = utils.execute( - 'openssl rsautl -decrypt -inkey %s' % cloud_private_key, - process_input=encrypted_iv, - check_exit_code=False) + iv, err = utils.execute('openssl', + 'rsautl', + '-decrypt', + '-inkey', '%s' % cloud_private_key, + process_input=encrypted_iv, + check_exit_code=False) if err: raise exception.Error(_("Failed to decrypt initialization " "vector: %s") % err) - _out, err = utils.execute( - 'openssl enc -d -aes-128-cbc -in %s -K %s -iv %s -out %s' - % (encrypted_filename, key, iv, decrypted_filename), - check_exit_code=False) + _out, err = utils.execute('openssl', + 'enc', + '-d', + '-aes-128-cbc', + '-in', '%s' % (encrypted_filename,), + '-K', '%s' % (key,), + '-iv', '%s' % (iv,), + '-out', '%s' % (decrypted_filename,), + check_exit_code=False) if err: raise exception.Error(_("Failed to decrypt image file " "%(image_file)s: %(err)s") % diff --git a/nova/quota.py b/nova/quota.py index 3884eb308..6b52a97fa 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -35,6 +35,8 @@ flags.DEFINE_integer('quota_gigabytes', 1000, 'number of volume gigabytes allowed per project') flags.DEFINE_integer('quota_floating_ips', 10, 'number of floating ips allowed per project') +flags.DEFINE_integer('quota_metadata_items', 128, + 'number of metadata items allowed per instance') def get_quota(context, project_id): @@ -42,7 +44,8 @@ def get_quota(context, project_id): 'cores': FLAGS.quota_cores, 'volumes': FLAGS.quota_volumes, 'gigabytes': FLAGS.quota_gigabytes, - 'floating_ips': FLAGS.quota_floating_ips} + 'floating_ips': FLAGS.quota_floating_ips, + 'metadata_items': FLAGS.quota_metadata_items} try: quota = db.quota_get(context, project_id) for key in rval.keys(): @@ -94,6 +97,15 @@ def allowed_floating_ips(context, num_floating_ips): return min(num_floating_ips, allowed_floating_ips) +def allowed_metadata_items(context, num_metadata_items): + """Check quota; return min(num_metadata_items,allowed_metadata_items)""" + project_id = context.project_id + context = context.elevated() + quota = get_quota(context, project_id) + num_allowed_metadata_items = quota['metadata_items'] + return min(num_metadata_items, num_allowed_metadata_items) + + class QuotaError(exception.ApiError): """Quota Exceeeded""" pass diff --git a/nova/rpc.py b/nova/rpc.py index 01fc6d44b..fbb90299b 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -29,6 +29,7 @@ import uuid from carrot import connection as carrot_connection from carrot import messaging +from eventlet import greenpool from eventlet import greenthread from nova import context @@ -42,11 +43,13 @@ from nova import utils FLAGS = flags.FLAGS LOG = logging.getLogger('nova.rpc') +flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool') + class Connection(carrot_connection.BrokerConnection): """Connection instance object""" @classmethod - def instance(cls, new=False): + def instance(cls, new=True): """Returns the instance""" if new or not hasattr(cls, '_instance'): params = dict(hostname=FLAGS.rabbit_host, @@ -88,18 +91,19 @@ class Consumer(messaging.Consumer): super(Consumer, self).__init__(*args, **kwargs) self.failed_connection = False break - except: # Catching all because carrot sucks + except Exception as e: # Catching all because carrot sucks fl_host = FLAGS.rabbit_host fl_port = FLAGS.rabbit_port fl_intv = FLAGS.rabbit_retry_interval - LOG.exception(_("AMQP server on %(fl_host)s:%(fl_port)d is" - " unreachable. Trying again in %(fl_intv)d seconds.") + LOG.error(_("AMQP server on %(fl_host)s:%(fl_port)d is" + " unreachable: %(e)s. Trying again in %(fl_intv)d" + " seconds.") % locals()) self.failed_connection = True if self.failed_connection: - LOG.exception(_("Unable to connect to AMQP server " - "after %d tries. Shutting down."), - FLAGS.rabbit_max_retries) + LOG.error(_("Unable to connect to AMQP server " + "after %d tries. Shutting down."), + FLAGS.rabbit_max_retries) sys.exit(1) def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False): @@ -119,7 +123,7 @@ class Consumer(messaging.Consumer): LOG.error(_("Reconnected to queue")) self.failed_connection = False # NOTE(vish): This is catching all errors because we really don't - # exceptions to be logged 10 times a second if some + # want exceptions to be logged 10 times a second if some # persistent failure occurs. except Exception: # pylint: disable-msg=W0703 if not self.failed_connection: @@ -155,11 +159,15 @@ class AdapterConsumer(TopicConsumer): def __init__(self, connection=None, topic="broadcast", proxy=None): LOG.debug(_('Initing the Adapter Consumer for %s') % topic) self.proxy = proxy + self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size) super(AdapterConsumer, self).__init__(connection=connection, topic=topic) + def receive(self, *args, **kwargs): + self.pool.spawn_n(self._receive, *args, **kwargs) + @exception.wrap_exception - def receive(self, message_data, message): + def _receive(self, message_data, message): """Magically looks for a method on the proxy object and calls it Message data should be a dictionary with two keys: @@ -246,7 +254,7 @@ def msg_reply(msg_id, reply=None, failure=None): LOG.error(_("Returning exception %s to caller"), message) LOG.error(tb) failure = (failure[0].__name__, str(failure[1]), tb) - conn = Connection.instance(True) + conn = Connection.instance() publisher = DirectPublisher(connection=conn, msg_id=msg_id) try: publisher.send({'result': reply, 'failure': failure}) @@ -319,7 +327,7 @@ def call(context, topic, msg): self.result = data['result'] wait_msg = WaitMessage() - conn = Connection.instance(True) + conn = Connection.instance() consumer = DirectConsumer(connection=conn, msg_id=msg_id) consumer.register_callback(wait_msg) diff --git a/nova/scheduler/api.py b/nova/scheduler/api.py new file mode 100644 index 000000000..2405f1343 --- /dev/null +++ b/nova/scheduler/api.py @@ -0,0 +1,49 @@ +# Copyright (c) 2011 Openstack, LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Handles all requests relating to schedulers. +""" + +from nova import flags +from nova import log as logging +from nova import rpc + +FLAGS = flags.FLAGS +LOG = logging.getLogger('nova.scheduler.api') + + +class API(object): + """API for interacting with the scheduler.""" + + def _call_scheduler(self, method, context, params=None): + """Generic handler for RPC calls to the scheduler. + + :param params: Optional dictionary of arguments to be passed to the + scheduler worker + + :retval: Result returned by scheduler worker + """ + if not params: + params = {} + queue = FLAGS.scheduler_topic + kwargs = {'method': method, 'args': params} + return rpc.call(context, queue, kwargs) + + def get_zone_list(self, context): + items = self._call_scheduler('get_zone_list', context) + for item in items: + item['api_url'] = item['api_url'].replace('\\/', '/') + return items diff --git a/nova/scheduler/manager.py b/nova/scheduler/manager.py index e9b47512e..c94397210 100644 --- a/nova/scheduler/manager.py +++ b/nova/scheduler/manager.py @@ -29,6 +29,7 @@ from nova import log as logging from nova import manager from nova import rpc from nova import utils +from nova.scheduler import zone_manager LOG = logging.getLogger('nova.scheduler.manager') FLAGS = flags.FLAGS @@ -43,12 +44,21 @@ class SchedulerManager(manager.Manager): if not scheduler_driver: scheduler_driver = FLAGS.scheduler_driver self.driver = utils.import_object(scheduler_driver) + self.zone_manager = zone_manager.ZoneManager() super(SchedulerManager, self).__init__(*args, **kwargs) def __getattr__(self, key): """Converts all method calls to use the schedule method""" return functools.partial(self._schedule, key) + def periodic_tasks(self, context=None): + """Poll child zones periodically to get status.""" + self.zone_manager.ping(context) + + def get_zone_list(self, context=None): + """Get a list of zones from the ZoneManager.""" + return self.zone_manager.get_zone_list() + def _schedule(self, method, context, topic, *args, **kwargs): """Tries to call schedule_* method on the driver to retrieve host. diff --git a/nova/scheduler/zone_manager.py b/nova/scheduler/zone_manager.py new file mode 100644 index 000000000..edf9000cc --- /dev/null +++ b/nova/scheduler/zone_manager.py @@ -0,0 +1,143 @@ +# Copyright (c) 2011 Openstack, LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +ZoneManager oversees all communications with child Zones. +""" + +import novaclient +import thread +import traceback + +from datetime import datetime +from eventlet import greenpool + +from nova import db +from nova import flags +from nova import log as logging + +FLAGS = flags.FLAGS +flags.DEFINE_integer('zone_db_check_interval', 60, + 'Seconds between getting fresh zone info from db.') +flags.DEFINE_integer('zone_failures_to_offline', 3, + 'Number of consecutive errors before marking zone offline') + + +class ZoneState(object): + """Holds the state of all connected child zones.""" + def __init__(self): + self.is_active = True + self.name = None + self.capabilities = None + self.attempt = 0 + self.last_seen = datetime.min + self.last_exception = None + self.last_exception_time = None + + def update_credentials(self, zone): + """Update zone credentials from db""" + self.zone_id = zone.id + self.api_url = zone.api_url + self.username = zone.username + self.password = zone.password + + def update_metadata(self, zone_metadata): + """Update zone metadata after successful communications with + child zone.""" + self.last_seen = datetime.now() + self.attempt = 0 + self.name = zone_metadata["name"] + self.capabilities = zone_metadata["capabilities"] + self.is_active = True + + def to_dict(self): + return dict(name=self.name, capabilities=self.capabilities, + is_active=self.is_active, api_url=self.api_url, + id=self.zone_id) + + def log_error(self, exception): + """Something went wrong. Check to see if zone should be + marked as offline.""" + self.last_exception = exception + self.last_exception_time = datetime.now() + api_url = self.api_url + logging.warning(_("'%(exception)s' error talking to " + "zone %(api_url)s") % locals()) + + max_errors = FLAGS.zone_failures_to_offline + self.attempt += 1 + if self.attempt >= max_errors: + self.is_active = False + logging.error(_("No answer from zone %(api_url)s " + "after %(max_errors)d " + "attempts. Marking inactive.") % locals()) + + +def _call_novaclient(zone): + """Call novaclient. Broken out for testing purposes.""" + client = novaclient.OpenStack(zone.username, zone.password, zone.api_url) + return client.zones.info()._info + + +def _poll_zone(zone): + """Eventlet worker to poll a zone.""" + logging.debug(_("Polling zone: %s") % zone.api_url) + try: + zone.update_metadata(_call_novaclient(zone)) + except Exception, e: + zone.log_error(traceback.format_exc()) + + +class ZoneManager(object): + """Keeps the zone states updated.""" + def __init__(self): + self.last_zone_db_check = datetime.min + self.zone_states = {} + self.green_pool = greenpool.GreenPool() + + def get_zone_list(self): + """Return the list of zones we know about.""" + return [zone.to_dict() for zone in self.zone_states.values()] + + def _refresh_from_db(self, context): + """Make our zone state map match the db.""" + # Add/update existing zones ... + zones = db.zone_get_all(context) + existing = self.zone_states.keys() + db_keys = [] + for zone in zones: + db_keys.append(zone.id) + if zone.id not in existing: + self.zone_states[zone.id] = ZoneState() + self.zone_states[zone.id].update_credentials(zone) + + # Cleanup zones removed from db ... + keys = self.zone_states.keys() # since we're deleting + for zone_id in keys: + if zone_id not in db_keys: + del self.zone_states[zone_id] + + def _poll_zones(self, context): + """Try to connect to each child zone and get update.""" + self.green_pool.imap(_poll_zone, self.zone_states.values()) + + def ping(self, context=None): + """Ping should be called periodically to update zone status.""" + diff = datetime.now() - self.last_zone_db_check + if diff.seconds >= FLAGS.zone_db_check_interval: + logging.debug(_("Updating zone cache from db.")) + self.last_zone_db_check = datetime.now() + self._refresh_from_db(context) + self._poll_zones(context) diff --git a/nova/service.py b/nova/service.py index 59648adf2..af20db01c 100644 --- a/nova/service.py +++ b/nova/service.py @@ -2,6 +2,7 @@ # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -39,24 +40,24 @@ from nova import flags from nova import rpc from nova import utils from nova import version +from nova import wsgi FLAGS = flags.FLAGS flags.DEFINE_integer('report_interval', 10, 'seconds between nodes reporting state to datastore', lower_bound=1) - flags.DEFINE_integer('periodic_interval', 60, 'seconds between running periodic tasks', lower_bound=1) - -flags.DEFINE_string('pidfile', None, - 'pidfile to use for this service') - - -flags.DEFINE_flag(flags.HelpFlag()) -flags.DEFINE_flag(flags.HelpshortFlag()) -flags.DEFINE_flag(flags.HelpXMLFlag()) +flags.DEFINE_string('ec2_listen', "0.0.0.0", + 'IP address for EC2 API to listen') +flags.DEFINE_integer('ec2_listen_port', 8773, 'port for ec2 api to listen') +flags.DEFINE_string('osapi_listen', "0.0.0.0", + 'IP address for OpenStack API to listen') +flags.DEFINE_integer('osapi_listen_port', 8774, 'port for os api to listen') +flags.DEFINE_string('api_paste_config', "api-paste.ini", + 'File name for the paste.deploy config for nova-api') class Service(object): @@ -68,6 +69,8 @@ class Service(object): self.binary = binary self.topic = topic self.manager_class_name = manager + manager_class = utils.import_class(self.manager_class_name) + self.manager = manager_class(host=self.host, *args, **kwargs) self.report_interval = report_interval self.periodic_interval = periodic_interval super(Service, self).__init__(*args, **kwargs) @@ -75,9 +78,9 @@ class Service(object): self.timers = [] def start(self): - manager_class = utils.import_class(self.manager_class_name) - self.manager = manager_class(host=self.host, *self.saved_args, - **self.saved_kwargs) + vcs_string = version.version_string_with_vcs() + logging.audit(_("Starting %(topic)s node (version %(vcs_string)s)"), + {'topic': self.topic, 'vcs_string': vcs_string}) self.manager.init_host() self.model_disconnected = False ctxt = context.get_admin_context() @@ -157,9 +160,6 @@ class Service(object): report_interval = FLAGS.report_interval if not periodic_interval: periodic_interval = FLAGS.periodic_interval - vcs_string = version.version_string_with_vcs() - logging.audit(_("Starting %(topic)s node (version %(vcs_string)s)") - % locals()) service_obj = cls(host, binary, topic, manager, report_interval, periodic_interval) @@ -181,6 +181,13 @@ class Service(object): pass self.timers = [] + def wait(self): + for x in self.timers: + try: + x.wait() + except Exception: + pass + def periodic_tasks(self): """Tasks to be run at a periodic interval""" self.manager.periodic_tasks(context.get_admin_context()) @@ -213,12 +220,55 @@ class Service(object): logging.exception(_("model server went away")) -def serve(*services): - FLAGS(sys.argv) - logging.basicConfig() +class WsgiService(object): + """Base class for WSGI based services. + + For each api you define, you must also define these flags: + :<api>_listen: The address on which to listen + :<api>_listen_port: The port on which to listen + """ - if not services: - services = [Service.create()] + def __init__(self, conf, apis): + self.conf = conf + self.apis = apis + self.wsgi_app = None + + def start(self): + self.wsgi_app = _run_wsgi(self.conf, self.apis) + + def wait(self): + self.wsgi_app.wait() + + +class ApiService(WsgiService): + """Class for our nova-api service""" + @classmethod + def create(cls, conf=None): + if not conf: + conf = wsgi.paste_config_file(FLAGS.api_paste_config) + if not conf: + message = (_("No paste configuration found for: %s"), + FLAGS.api_paste_config) + raise exception.Error(message) + api_endpoints = ['ec2', 'osapi'] + service = cls(conf, api_endpoints) + return service + + +def serve(*services): + try: + if not services: + services = [Service.create()] + except Exception: + logging.exception('in Service.create()') + raise + finally: + # After we've loaded up all our dynamic bits, check + # whether we should print help + flags.DEFINE_flag(flags.HelpFlag()) + flags.DEFINE_flag(flags.HelpshortFlag()) + flags.DEFINE_flag(flags.HelpXMLFlag()) + FLAGS.ParseNewFlags() name = '_'.join(x.binary for x in services) logging.debug(_("Serving %s"), name) @@ -234,3 +284,46 @@ def serve(*services): def wait(): while True: greenthread.sleep(5) + + +def serve_wsgi(cls, conf=None): + try: + service = cls.create(conf) + except Exception: + logging.exception('in WsgiService.create()') + raise + finally: + # After we've loaded up all our dynamic bits, check + # whether we should print help + flags.DEFINE_flag(flags.HelpFlag()) + flags.DEFINE_flag(flags.HelpshortFlag()) + flags.DEFINE_flag(flags.HelpXMLFlag()) + FLAGS.ParseNewFlags() + + service.start() + + return service + + +def _run_wsgi(paste_config_file, apis): + logging.debug(_("Using paste.deploy config at: %s"), paste_config_file) + apps = [] + for api in apis: + config = wsgi.load_paste_configuration(paste_config_file, api) + if config is None: + logging.debug(_("No paste configuration for app: %s"), api) + continue + logging.debug(_("App Config: %(api)s\n%(config)r") % locals()) + logging.info(_("Running %s API"), api) + app = wsgi.load_paste_app(paste_config_file, api) + apps.append((app, getattr(FLAGS, "%s_listen_port" % api), + getattr(FLAGS, "%s_listen" % api))) + if len(apps) == 0: + logging.error(_("No known API applications configured in %s."), + paste_config_file) + return + + server = wsgi.Server() + for app in apps: + server.start(*app) + return server diff --git a/nova/test.py b/nova/test.py index a12cf9d32..d8a47464f 100644 --- a/nova/test.py +++ b/nova/test.py @@ -22,10 +22,15 @@ Allows overriding of flags for use of fakes, and some black magic for inline callbacks. """ + import datetime +import os +import shutil +import uuid import unittest import mox +import shutil import stubout from nova import context @@ -33,13 +38,12 @@ from nova import db from nova import fakerabbit from nova import flags from nova import rpc -from nova.network import manager as network_manager -from nova.tests import fake_flags +from nova import service FLAGS = flags.FLAGS -flags.DEFINE_bool('flush_db', True, - 'Flush the database before running fake tests') +flags.DEFINE_string('sqlite_clean_db', 'clean.sqlite', + 'File name of clean sqlite db') flags.DEFINE_bool('fake_tests', True, 'should we use everything for testing') @@ -64,15 +68,8 @@ class TestCase(unittest.TestCase): # now that we have some required db setup for the system # to work properly. self.start = datetime.datetime.utcnow() - ctxt = context.get_admin_context() - if db.network_count(ctxt) != 5: - network_manager.VlanManager().create_networks(ctxt, - FLAGS.fixed_range, - 5, 16, - FLAGS.fixed_range_v6, - FLAGS.vlan_start, - FLAGS.vpn_start, - ) + shutil.copyfile(os.path.join(FLAGS.state_path, FLAGS.sqlite_clean_db), + os.path.join(FLAGS.state_path, FLAGS.sqlite_db)) # emulate some of the mox stuff, we can't use the metaclass # because it screws with our generators @@ -80,6 +77,7 @@ class TestCase(unittest.TestCase): self.stubs = stubout.StubOutForTesting() self.flag_overrides = {} self.injected = [] + self._services = [] self._monkey_patch_attach() self._original_flags = FLAGS.FlagValuesDict() @@ -91,25 +89,31 @@ class TestCase(unittest.TestCase): self.stubs.UnsetAll() self.stubs.SmartUnsetAll() self.mox.VerifyAll() - # NOTE(vish): Clean up any ips associated during the test. - ctxt = context.get_admin_context() - db.fixed_ip_disassociate_all_by_timeout(ctxt, FLAGS.host, - self.start) - db.network_disassociate_all(ctxt) + super(TestCase, self).tearDown() + finally: + # Clean out fake_rabbit's queue if we used it + if FLAGS.fake_rabbit: + fakerabbit.reset_all() + + # Reset any overriden flags + self.reset_flags() + + # Reset our monkey-patches rpc.Consumer.attach_to_eventlet = self.originalAttach + + # Stop any timers for x in self.injected: try: x.stop() except AssertionError: pass - if FLAGS.fake_rabbit: - fakerabbit.reset_all() - - db.security_group_destroy_all(ctxt) - super(TestCase, self).tearDown() - finally: - self.reset_flags() + # Kill any services + for x in self._services: + try: + x.kill() + except Exception: + pass def flags(self, **kw): """Override flag variables for a test""" @@ -127,6 +131,15 @@ class TestCase(unittest.TestCase): for k, v in self._original_flags.iteritems(): setattr(FLAGS, k, v) + def start_service(self, name, host=None, **kwargs): + host = host and host or uuid.uuid4().hex + kwargs.setdefault('host', host) + kwargs.setdefault('binary', 'nova-%s' % name) + svc = service.Service.create(**kwargs) + svc.start() + self._services.append(svc) + return svc + def _monkey_patch_attach(self): self.originalAttach = rpc.Consumer.attach_to_eventlet diff --git a/nova/tests/__init__.py b/nova/tests/__init__.py index 592d5bea9..7fba02a93 100644 --- a/nova/tests/__init__.py +++ b/nova/tests/__init__.py @@ -37,5 +37,30 @@ setattr(__builtin__, '_', lambda x: x) def setup(): + import os + import shutil + + from nova import context + from nova import flags from nova.db import migration + from nova.network import manager as network_manager + from nova.tests import fake_flags + + FLAGS = flags.FLAGS + + testdb = os.path.join(FLAGS.state_path, FLAGS.sqlite_db) + if os.path.exists(testdb): + os.unlink(testdb) migration.db_sync() + ctxt = context.get_admin_context() + network_manager.VlanManager().create_networks(ctxt, + FLAGS.fixed_range, + FLAGS.num_networks, + FLAGS.network_size, + FLAGS.fixed_range_v6, + FLAGS.vlan_start, + FLAGS.vpn_start, + ) + + cleandb = os.path.join(FLAGS.state_path, FLAGS.sqlite_clean_db) + shutil.copyfile(testdb, cleandb) diff --git a/nova/tests/api/openstack/__init__.py b/nova/tests/api/openstack/__init__.py index 14eaaa62c..e18120285 100644 --- a/nova/tests/api/openstack/__init__.py +++ b/nova/tests/api/openstack/__init__.py @@ -16,7 +16,7 @@ # under the License. import webob.dec -import unittest +from nova import test from nova import context from nova import flags @@ -33,7 +33,7 @@ def simple_wsgi(req): return "" -class RateLimitingMiddlewareTest(unittest.TestCase): +class RateLimitingMiddlewareTest(test.TestCase): def test_get_action_name(self): middleware = RateLimitingMiddleware(simple_wsgi) @@ -92,31 +92,3 @@ class RateLimitingMiddlewareTest(unittest.TestCase): self.assertEqual(middleware.limiter.__class__.__name__, "Limiter") middleware = RateLimitingMiddleware(simple_wsgi, service_host='foobar') self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy") - - -class LimiterTest(unittest.TestCase): - - def test_limiter(self): - items = range(2000) - req = Request.blank('/') - self.assertEqual(limited(items, req), items[:1000]) - req = Request.blank('/?offset=0') - self.assertEqual(limited(items, req), items[:1000]) - req = Request.blank('/?offset=3') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=2005') - self.assertEqual(limited(items, req), []) - req = Request.blank('/?limit=10') - self.assertEqual(limited(items, req), items[:10]) - req = Request.blank('/?limit=0') - self.assertEqual(limited(items, req), items[:1000]) - req = Request.blank('/?limit=3000') - self.assertEqual(limited(items, req), items[:1000]) - req = Request.blank('/?offset=1&limit=3') - self.assertEqual(limited(items, req), items[1:4]) - req = Request.blank('/?offset=3&limit=0') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=3&limit=1500') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=3000&limit=10') - self.assertEqual(limited(items, req), []) diff --git a/nova/tests/api/openstack/common.py b/nova/tests/api/openstack/common.py new file mode 100644 index 000000000..74bb8729a --- /dev/null +++ b/nova/tests/api/openstack/common.py @@ -0,0 +1,36 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import json + +import webob + + +def webob_factory(url): + """Factory for removing duplicate webob code from tests""" + + base_url = url + + def web_request(url, method=None, body=None): + req = webob.Request.blank("%s%s" % (base_url, url)) + if method: + req.content_type = "application/json" + req.method = method + if body: + req.body = json.dumps(body) + return req + return web_request diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py index fb282f1c9..2c4e57246 100644 --- a/nova/tests/api/openstack/fakes.py +++ b/nova/tests/api/openstack/fakes.py @@ -25,6 +25,7 @@ import webob.dec from paste import urlmap from glance import client as glance_client +from glance.common import exception as glance_exc from nova import auth from nova import context @@ -149,25 +150,26 @@ def stub_out_glance(stubs, initial_fixtures=None): for f in self.fixtures: if f['id'] == image_id: return f - return None + raise glance_exc.NotFound - def fake_add_image(self, image_meta): + def fake_add_image(self, image_meta, data=None): id = ''.join(random.choice(string.letters) for _ in range(20)) image_meta['id'] = id self.fixtures.append(image_meta) - return id + return image_meta - def fake_update_image(self, image_id, image_meta): + def fake_update_image(self, image_id, image_meta, data=None): f = self.fake_get_image_meta(image_id) if not f: - raise exc.NotFound + raise glance_exc.NotFound f.update(image_meta) + return f def fake_delete_image(self, image_id): f = self.fake_get_image_meta(image_id) if not f: - raise exc.NotFound + raise glance_exc.NotFound self.fixtures.remove(f) @@ -188,7 +190,11 @@ def stub_out_glance(stubs, initial_fixtures=None): class FakeToken(object): + id = 0 + def __init__(self, **kwargs): + FakeToken.id += 1 + self.id = FakeToken.id for k, v in kwargs.iteritems(): setattr(self, k, v) @@ -203,19 +209,22 @@ class FakeAuthDatabase(object): data = {} @staticmethod - def auth_get_token(context, token_hash): + def auth_token_get(context, token_hash): return FakeAuthDatabase.data.get(token_hash, None) @staticmethod - def auth_create_token(context, token): + def auth_token_create(context, token): fake_token = FakeToken(created_at=datetime.datetime.now(), **token) FakeAuthDatabase.data[fake_token.token_hash] = fake_token + FakeAuthDatabase.data['id_%i' % fake_token.id] = fake_token return fake_token @staticmethod - def auth_destroy_token(context, token): - if token.token_hash in FakeAuthDatabase.data: - del FakeAuthDatabase.data['token_hash'] + def auth_token_destroy(context, token_id): + token = FakeAuthDatabase.data.get('id_%i' % token_id) + if token and token.token_hash in FakeAuthDatabase.data: + del FakeAuthDatabase.data[token.token_hash] + del FakeAuthDatabase.data['id_%i' % token_id] class FakeAuthManager(object): diff --git a/nova/tests/api/openstack/test_adminapi.py b/nova/tests/api/openstack/test_adminapi.py index 73120c31d..dfce1b127 100644 --- a/nova/tests/api/openstack/test_adminapi.py +++ b/nova/tests/api/openstack/test_adminapi.py @@ -15,13 +15,13 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest import stubout import webob from paste import urlmap from nova import flags +from nova import test from nova.api import openstack from nova.api.openstack import ratelimiting from nova.api.openstack import auth @@ -30,9 +30,10 @@ from nova.tests.api.openstack import fakes FLAGS = flags.FLAGS -class AdminAPITest(unittest.TestCase): +class AdminAPITest(test.TestCase): def setUp(self): + super(AdminAPITest, self).setUp() self.stubs = stubout.StubOutForTesting() fakes.FakeAuthManager.auth_data = {} fakes.FakeAuthDatabase.data = {} @@ -44,6 +45,7 @@ class AdminAPITest(unittest.TestCase): def tearDown(self): self.stubs.UnsetAll() FLAGS.allow_admin_api = self.allow_admin + super(AdminAPITest, self).tearDown() def test_admin_enabled(self): FLAGS.allow_admin_api = True @@ -58,8 +60,5 @@ class AdminAPITest(unittest.TestCase): # We should still be able to access public operations. req = webob.Request.blank('/v1.0/flavors') res = req.get_response(fakes.wsgi_app()) - self.assertEqual(res.status_int, 200) # TODO: Confirm admin operations are unavailable. - -if __name__ == '__main__': - unittest.main() + self.assertEqual(res.status_int, 200) diff --git a/nova/tests/api/openstack/test_api.py b/nova/tests/api/openstack/test_api.py index db0fe1060..5112c486f 100644 --- a/nova/tests/api/openstack/test_api.py +++ b/nova/tests/api/openstack/test_api.py @@ -15,17 +15,17 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest import webob.exc import webob.dec from webob import Request +from nova import test from nova.api import openstack from nova.api.openstack import faults -class APITest(unittest.TestCase): +class APITest(test.TestCase): def _wsgi_app(self, inner_app): # simpler version of the app than fakes.wsgi_app diff --git a/nova/tests/api/openstack/test_auth.py b/nova/tests/api/openstack/test_auth.py index 0dd65d321..ff8d42a14 100644 --- a/nova/tests/api/openstack/test_auth.py +++ b/nova/tests/api/openstack/test_auth.py @@ -16,7 +16,6 @@ # under the License. import datetime -import unittest import stubout import webob @@ -27,12 +26,15 @@ import nova.api.openstack.auth import nova.auth.manager from nova import auth from nova import context +from nova import db +from nova import test from nova.tests.api.openstack import fakes -class Test(unittest.TestCase): +class Test(test.TestCase): def setUp(self): + super(Test, self).setUp() self.stubs = stubout.StubOutForTesting() self.stubs.Set(nova.api.openstack.auth.AuthMiddleware, '__init__', fakes.fake_auth_init) @@ -45,6 +47,7 @@ class Test(unittest.TestCase): def tearDown(self): self.stubs.UnsetAll() fakes.fake_data_store = {} + super(Test, self).tearDown() def test_authorize_user(self): f = fakes.FakeAuthManager() @@ -97,10 +100,10 @@ class Test(unittest.TestCase): token_hash=token_hash, created_at=datetime.datetime(1990, 1, 1)) - self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token', + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_token_destroy', destroy_token_mock) - self.stubs.Set(fakes.FakeAuthDatabase, 'auth_get_token', + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_token_get', bad_token) req = webob.Request.blank('/v1.0/') @@ -128,8 +131,36 @@ class Test(unittest.TestCase): self.assertEqual(result.status, '401 Unauthorized') -class TestLimiter(unittest.TestCase): +class TestFunctional(test.TestCase): + def test_token_expiry(self): + ctx = context.get_admin_context() + tok = db.auth_token_create(ctx, dict( + token_hash='bacon', + cdn_management_url='', + server_management_url='', + storage_url='', + user_id='ham', + )) + + db.auth_token_update(ctx, tok.token_hash, dict( + created_at=datetime.datetime(2000, 1, 1, 12, 0, 0), + )) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'bacon' + result = req.get_response(fakes.wsgi_app()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_token_doesnotexist(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'ham' + result = req.get_response(fakes.wsgi_app()) + self.assertEqual(result.status, '401 Unauthorized') + + +class TestLimiter(test.TestCase): def setUp(self): + super(TestLimiter, self).setUp() self.stubs = stubout.StubOutForTesting() self.stubs.Set(nova.api.openstack.auth.AuthMiddleware, '__init__', fakes.fake_auth_init) @@ -141,6 +172,7 @@ class TestLimiter(unittest.TestCase): def tearDown(self): self.stubs.UnsetAll() fakes.fake_data_store = {} + super(TestLimiter, self).tearDown() def test_authorize_token(self): f = fakes.FakeAuthManager() @@ -161,7 +193,3 @@ class TestLimiter(unittest.TestCase): result = req.get_response(fakes.wsgi_app()) self.assertEqual(result.status, '200 OK') self.assertEqual(result.headers['X-Test-Success'], 'True') - - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/openstack/test_common.py b/nova/tests/api/openstack/test_common.py new file mode 100644 index 000000000..8f57c5b67 --- /dev/null +++ b/nova/tests/api/openstack/test_common.py @@ -0,0 +1,171 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Test suites for 'common' code used throughout the OpenStack HTTP API. +""" + +import webob.exc + +from webob import Request + +from nova import test +from nova.api.openstack.common import limited + + +class LimiterTest(test.TestCase): + """ + Unit tests for the `nova.api.openstack.common.limited` method which takes + in a list of items and, depending on the 'offset' and 'limit' GET params, + returns a subset or complete set of the given items. + """ + + def setUp(self): + """ + Run before each test. + """ + super(LimiterTest, self).setUp() + self.tiny = range(1) + self.small = range(10) + self.medium = range(1000) + self.large = range(10000) + + def test_limiter_offset_zero(self): + """ + Test offset key works with 0. + """ + req = Request.blank('/?offset=0') + self.assertEqual(limited(self.tiny, req), self.tiny) + self.assertEqual(limited(self.small, req), self.small) + self.assertEqual(limited(self.medium, req), self.medium) + self.assertEqual(limited(self.large, req), self.large[:1000]) + + def test_limiter_offset_medium(self): + """ + Test offset key works with a medium sized number. + """ + req = Request.blank('/?offset=10') + self.assertEqual(limited(self.tiny, req), []) + self.assertEqual(limited(self.small, req), self.small[10:]) + self.assertEqual(limited(self.medium, req), self.medium[10:]) + self.assertEqual(limited(self.large, req), self.large[10:1010]) + + def test_limiter_offset_over_max(self): + """ + Test offset key works with a number over 1000 (max_limit). + """ + req = Request.blank('/?offset=1001') + self.assertEqual(limited(self.tiny, req), []) + self.assertEqual(limited(self.small, req), []) + self.assertEqual(limited(self.medium, req), []) + self.assertEqual(limited(self.large, req), self.large[1001:2001]) + + def test_limiter_offset_blank(self): + """ + Test offset key works with a blank offset. + """ + req = Request.blank('/?offset=') + self.assertRaises(webob.exc.HTTPBadRequest, limited, self.tiny, req) + + def test_limiter_offset_bad(self): + """ + Test offset key works with a BAD offset. + """ + req = Request.blank(u'/?offset=\u0020aa') + self.assertRaises(webob.exc.HTTPBadRequest, limited, self.tiny, req) + + def test_limiter_nothing(self): + """ + Test request with no offset or limit + """ + req = Request.blank('/') + self.assertEqual(limited(self.tiny, req), self.tiny) + self.assertEqual(limited(self.small, req), self.small) + self.assertEqual(limited(self.medium, req), self.medium) + self.assertEqual(limited(self.large, req), self.large[:1000]) + + def test_limiter_limit_zero(self): + """ + Test limit of zero. + """ + req = Request.blank('/?limit=0') + self.assertEqual(limited(self.tiny, req), self.tiny) + self.assertEqual(limited(self.small, req), self.small) + self.assertEqual(limited(self.medium, req), self.medium) + self.assertEqual(limited(self.large, req), self.large[:1000]) + + def test_limiter_limit_medium(self): + """ + Test limit of 10. + """ + req = Request.blank('/?limit=10') + self.assertEqual(limited(self.tiny, req), self.tiny) + self.assertEqual(limited(self.small, req), self.small) + self.assertEqual(limited(self.medium, req), self.medium[:10]) + self.assertEqual(limited(self.large, req), self.large[:10]) + + def test_limiter_limit_over_max(self): + """ + Test limit of 3000. + """ + req = Request.blank('/?limit=3000') + self.assertEqual(limited(self.tiny, req), self.tiny) + self.assertEqual(limited(self.small, req), self.small) + self.assertEqual(limited(self.medium, req), self.medium) + self.assertEqual(limited(self.large, req), self.large[:1000]) + + def test_limiter_limit_and_offset(self): + """ + Test request with both limit and offset. + """ + items = range(2000) + req = Request.blank('/?offset=1&limit=3') + self.assertEqual(limited(items, req), items[1:4]) + req = Request.blank('/?offset=3&limit=0') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3&limit=1500') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3000&limit=10') + self.assertEqual(limited(items, req), []) + + def test_limiter_custom_max_limit(self): + """ + Test a max_limit other than 1000. + """ + items = range(2000) + req = Request.blank('/?offset=1&limit=3') + self.assertEqual(limited(items, req, max_limit=2000), items[1:4]) + req = Request.blank('/?offset=3&limit=0') + self.assertEqual(limited(items, req, max_limit=2000), items[3:]) + req = Request.blank('/?offset=3&limit=2500') + self.assertEqual(limited(items, req, max_limit=2000), items[3:]) + req = Request.blank('/?offset=3000&limit=10') + self.assertEqual(limited(items, req, max_limit=2000), []) + + def test_limiter_negative_limit(self): + """ + Test a negative limit. + """ + req = Request.blank('/?limit=-3000') + self.assertRaises(webob.exc.HTTPBadRequest, limited, self.tiny, req) + + def test_limiter_negative_offset(self): + """ + Test a negative offset. + """ + req = Request.blank('/?offset=-30') + self.assertRaises(webob.exc.HTTPBadRequest, limited, self.tiny, req) diff --git a/nova/tests/api/openstack/test_faults.py b/nova/tests/api/openstack/test_faults.py index fda2b5ede..7667753f4 100644 --- a/nova/tests/api/openstack/test_faults.py +++ b/nova/tests/api/openstack/test_faults.py @@ -15,15 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest import webob import webob.dec import webob.exc +from nova import test from nova.api.openstack import faults -class TestFaults(unittest.TestCase): +class TestFaults(test.TestCase): def test_fault_parts(self): req = webob.Request.blank('/.xml') diff --git a/nova/tests/api/openstack/test_flavors.py b/nova/tests/api/openstack/test_flavors.py index 1bdaea161..319767bb5 100644 --- a/nova/tests/api/openstack/test_flavors.py +++ b/nova/tests/api/openstack/test_flavors.py @@ -15,34 +15,38 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest - import stubout import webob +from nova import test import nova.api +from nova import context +from nova import db from nova.api.openstack import flavors from nova.tests.api.openstack import fakes -class FlavorsTest(unittest.TestCase): +class FlavorsTest(test.TestCase): def setUp(self): + super(FlavorsTest, self).setUp() self.stubs = stubout.StubOutForTesting() fakes.FakeAuthManager.auth_data = {} fakes.FakeAuthDatabase.data = {} fakes.stub_out_networking(self.stubs) fakes.stub_out_rate_limiting(self.stubs) fakes.stub_out_auth(self.stubs) + self.context = context.get_admin_context() def tearDown(self): self.stubs.UnsetAll() + super(FlavorsTest, self).tearDown() def test_get_flavor_list(self): req = webob.Request.blank('/v1.0/flavors') res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 200) def test_get_flavor_by_id(self): - pass - -if __name__ == '__main__': - unittest.main() + req = webob.Request.blank('/v1.0/flavors/1') + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 200) diff --git a/nova/tests/api/openstack/test_images.py b/nova/tests/api/openstack/test_images.py index 8ab4d7569..eb5039bdb 100644 --- a/nova/tests/api/openstack/test_images.py +++ b/nova/tests/api/openstack/test_images.py @@ -22,7 +22,8 @@ and as a WSGI layer import json import datetime -import unittest +import shutil +import tempfile import stubout import webob @@ -30,6 +31,7 @@ import webob from nova import context from nova import exception from nova import flags +from nova import test from nova import utils import nova.api.openstack from nova.api.openstack import images @@ -54,7 +56,7 @@ class BaseImageServiceTests(object): num_images = len(self.service.index(self.context)) - id = self.service.create(self.context, fixture) + id = self.service.create(self.context, fixture)['id'] self.assertNotEquals(None, id) self.assertEquals(num_images + 1, @@ -71,7 +73,7 @@ class BaseImageServiceTests(object): num_images = len(self.service.index(self.context)) - id = self.service.create(self.context, fixture) + id = self.service.create(self.context, fixture)['id'] self.assertNotEquals(None, id) @@ -89,7 +91,7 @@ class BaseImageServiceTests(object): 'instance_id': None, 'progress': None} - id = self.service.create(self.context, fixture) + id = self.service.create(self.context, fixture)['id'] fixture['status'] = 'in progress' @@ -118,7 +120,7 @@ class BaseImageServiceTests(object): ids = [] for fixture in fixtures: - new_id = self.service.create(self.context, fixture) + new_id = self.service.create(self.context, fixture)['id'] ids.append(new_id) num_images = len(self.service.index(self.context)) @@ -130,29 +132,33 @@ class BaseImageServiceTests(object): self.assertEquals(1, num_images) -class LocalImageServiceTest(unittest.TestCase, +class LocalImageServiceTest(test.TestCase, BaseImageServiceTests): """Tests the local image service""" def setUp(self): + super(LocalImageServiceTest, self).setUp() + self.tempdir = tempfile.mkdtemp() + self.flags(images_path=self.tempdir) self.stubs = stubout.StubOutForTesting() service_class = 'nova.image.local.LocalImageService' self.service = utils.import_object(service_class) self.context = context.RequestContext(None, None) def tearDown(self): - self.service.delete_all() - self.service.delete_imagedir() + shutil.rmtree(self.tempdir) self.stubs.UnsetAll() + super(LocalImageServiceTest, self).tearDown() -class GlanceImageServiceTest(unittest.TestCase, +class GlanceImageServiceTest(test.TestCase, BaseImageServiceTests): """Tests the local image service""" def setUp(self): + super(GlanceImageServiceTest, self).setUp() self.stubs = stubout.StubOutForTesting() fakes.stub_out_glance(self.stubs) fakes.stub_out_compute_api_snapshot(self.stubs) @@ -163,9 +169,10 @@ class GlanceImageServiceTest(unittest.TestCase, def tearDown(self): self.stubs.UnsetAll() + super(GlanceImageServiceTest, self).tearDown() -class ImageControllerWithGlanceServiceTest(unittest.TestCase): +class ImageControllerWithGlanceServiceTest(test.TestCase): """Test of the OpenStack API /images application controller""" @@ -194,6 +201,7 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase): 'image_type': 'ramdisk'}] def setUp(self): + super(ImageControllerWithGlanceServiceTest, self).setUp() self.orig_image_service = FLAGS.image_service FLAGS.image_service = 'nova.image.glance.GlanceImageService' self.stubs = stubout.StubOutForTesting() @@ -208,6 +216,7 @@ class ImageControllerWithGlanceServiceTest(unittest.TestCase): def tearDown(self): self.stubs.UnsetAll() FLAGS.image_service = self.orig_image_service + super(ImageControllerWithGlanceServiceTest, self).tearDown() def test_get_image_index(self): req = webob.Request.blank('/v1.0/images') diff --git a/nova/tests/api/openstack/test_ratelimiting.py b/nova/tests/api/openstack/test_ratelimiting.py index 4c9d6bc23..9ae90ee20 100644 --- a/nova/tests/api/openstack/test_ratelimiting.py +++ b/nova/tests/api/openstack/test_ratelimiting.py @@ -1,15 +1,16 @@ import httplib import StringIO import time -import unittest import webob +from nova import test import nova.api.openstack.ratelimiting as ratelimiting -class LimiterTest(unittest.TestCase): +class LimiterTest(test.TestCase): def setUp(self): + super(LimiterTest, self).setUp() self.limits = { 'a': (5, ratelimiting.PER_SECOND), 'b': (5, ratelimiting.PER_MINUTE), @@ -83,9 +84,10 @@ class FakeLimiter(object): return self._delay -class WSGIAppTest(unittest.TestCase): +class WSGIAppTest(test.TestCase): def setUp(self): + super(WSGIAppTest, self).setUp() self.limiter = FakeLimiter(self) self.app = ratelimiting.WSGIApp(self.limiter) @@ -206,7 +208,7 @@ def wire_HTTPConnection_to_WSGI(host, app): httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) -class WSGIAppProxyTest(unittest.TestCase): +class WSGIAppProxyTest(test.TestCase): def setUp(self): """Our WSGIAppProxy is going to call across an HTTPConnection to a @@ -218,6 +220,7 @@ class WSGIAppProxyTest(unittest.TestCase): at the WSGIApp. And the limiter isn't real -- it's a fake that behaves the way we tell it to. """ + super(WSGIAppProxyTest, self).setUp() self.limiter = FakeLimiter(self) app = ratelimiting.WSGIApp(self.limiter) wire_HTTPConnection_to_WSGI('100.100.100.100:80', app) @@ -238,7 +241,3 @@ class WSGIAppProxyTest(unittest.TestCase): self.limiter.mock('murder', 'brutus', None) self.proxy.perform('stab', 'brutus') self.assertRaises(AssertionError, shouldRaise) - - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/openstack/test_servers.py b/nova/tests/api/openstack/test_servers.py index 724f14f19..c1e05b18a 100644 --- a/nova/tests/api/openstack/test_servers.py +++ b/nova/tests/api/openstack/test_servers.py @@ -1,6 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright 2010 OpenStack LLC. +# Copyright 2010-2011 OpenStack LLC. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -15,19 +15,23 @@ # License for the specific language governing permissions and limitations # under the License. +import datetime import json -import unittest import stubout import webob from nova import db from nova import flags +from nova import test import nova.api.openstack from nova.api.openstack import servers +import nova.compute.api import nova.db.api from nova.db.sqlalchemy.models import Instance +from nova.db.sqlalchemy.models import InstanceMetadata import nova.rpc +from nova.tests.api.openstack import common from nova.tests.api.openstack import fakes @@ -39,6 +43,13 @@ def return_server(context, id): return stub_instance(id) +def return_server_with_addresses(private, public): + def _return_server(context, id): + return stub_instance(id, private_address=private, + public_addresses=public) + return _return_server + + def return_servers(context, user_id=1): return [stub_instance(i, user_id) for i in xrange(5)] @@ -55,18 +66,59 @@ def instance_address(context, instance_id): return None -def stub_instance(id, user_id=1): - return Instance(id=id, state=0, image_id=10, user_id=user_id, - display_name='server%s' % id) +def stub_instance(id, user_id=1, private_address=None, public_addresses=None): + metadata = [] + metadata.append(InstanceMetadata(key='seq', value=id)) + + if public_addresses == None: + public_addresses = list() + + instance = { + "id": id, + "admin_pass": "", + "user_id": user_id, + "project_id": "", + "image_id": 10, + "kernel_id": "", + "ramdisk_id": "", + "launch_index": 0, + "key_name": "", + "key_data": "", + "state": 0, + "state_description": "", + "memory_mb": 0, + "vcpus": 0, + "local_gb": 0, + "hostname": "", + "host": None, + "instance_type": "", + "user_data": "", + "reservation_id": "", + "mac_address": "", + "scheduled_at": datetime.datetime.now(), + "launched_at": datetime.datetime.now(), + "terminated_at": datetime.datetime.now(), + "availability_zone": "", + "display_name": "server%s" % id, + "display_description": "", + "locked": False, + "metadata": metadata} + + instance["fixed_ip"] = { + "address": private_address, + "floating_ips": [{"address":ip} for ip in public_addresses]} + + return instance def fake_compute_api(cls, req, id): return True -class ServersTest(unittest.TestCase): +class ServersTest(test.TestCase): def setUp(self): + super(ServersTest, self).setUp() self.stubs = stubout.StubOutForTesting() fakes.FakeAuthManager.auth_data = {} fakes.FakeAuthDatabase.data = {} @@ -94,9 +146,12 @@ class ServersTest(unittest.TestCase): self.stubs.Set(nova.compute.API, "get_actions", fake_compute_api) self.allow_admin = FLAGS.allow_admin_api + self.webreq = common.webob_factory('/v1.0/servers') + def tearDown(self): self.stubs.UnsetAll() FLAGS.allow_admin_api = self.allow_admin + super(ServersTest, self).tearDown() def test_get_server_by_id(self): req = webob.Request.blank('/v1.0/servers/1') @@ -105,6 +160,22 @@ class ServersTest(unittest.TestCase): self.assertEqual(res_dict['server']['id'], '1') self.assertEqual(res_dict['server']['name'], 'server1') + def test_get_server_by_id_with_addresses(self): + private = "192.168.0.3" + public = ["1.2.3.4"] + new_return_server = return_server_with_addresses(private, public) + self.stubs.Set(nova.db.api, 'instance_get', new_return_server) + req = webob.Request.blank('/v1.0/servers/1') + res = req.get_response(fakes.wsgi_app()) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['server']['id'], '1') + self.assertEqual(res_dict['server']['name'], 'server1') + addresses = res_dict['server']['addresses'] + self.assertEqual(len(addresses["public"]), len(public)) + self.assertEqual(addresses["public"][0], public[0]) + self.assertEqual(len(addresses["private"]), 1) + self.assertEqual(addresses["private"][0], private) + def test_get_server_list(self): req = webob.Request.blank('/v1.0/servers') res = req.get_response(fakes.wsgi_app()) @@ -117,9 +188,37 @@ class ServersTest(unittest.TestCase): self.assertEqual(s.get('imageId', None), None) i += 1 + def test_get_servers_with_limit(self): + req = webob.Request.blank('/v1.0/servers?limit=3') + res = req.get_response(fakes.wsgi_app()) + servers = json.loads(res.body)['servers'] + self.assertEqual([s['id'] for s in servers], [0, 1, 2]) + + req = webob.Request.blank('/v1.0/servers?limit=aaa') + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 400) + self.assertTrue('limit' in res.body) + + def test_get_servers_with_offset(self): + req = webob.Request.blank('/v1.0/servers?offset=2') + res = req.get_response(fakes.wsgi_app()) + servers = json.loads(res.body)['servers'] + self.assertEqual([s['id'] for s in servers], [2, 3, 4]) + + req = webob.Request.blank('/v1.0/servers?offset=aaa') + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 400) + self.assertTrue('offset' in res.body) + + def test_get_servers_with_limit_and_offset(self): + req = webob.Request.blank('/v1.0/servers?limit=2&offset=1') + res = req.get_response(fakes.wsgi_app()) + servers = json.loads(res.body)['servers'] + self.assertEqual([s['id'] for s in servers], [1, 2]) + def test_create_instance(self): def instance_create(context, inst): - return {'id': '1', 'display_name': ''} + return {'id': '1', 'display_name': 'server_test'} def server_update(context, id, params): return instance_create(context, id) @@ -154,14 +253,22 @@ class ServersTest(unittest.TestCase): "get_image_id_from_image_hash", image_id_from_hash) body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, + name='server_test', imageId=2, flavorId=2, + metadata={'hello': 'world', 'open': 'stack'}, personality={})) req = webob.Request.blank('/v1.0/servers') req.method = 'POST' req.body = json.dumps(body) + req.headers["Content-Type"] = "application/json" res = req.get_response(fakes.wsgi_app()) + server = json.loads(res.body)['server'] + self.assertEqual('serv', server['adminPass'][:4]) + self.assertEqual(16, len(server['adminPass'])) + self.assertEqual('server_test', server['name']) + self.assertEqual('1', server['id']) + self.assertEqual(res.status_int, 200) def test_update_no_body(self): @@ -229,10 +336,45 @@ class ServersTest(unittest.TestCase): i = 0 for s in res_dict['servers']: self.assertEqual(s['id'], i) + self.assertEqual(s['hostId'], '') self.assertEqual(s['name'], 'server%d' % i) self.assertEqual(s['imageId'], 10) + self.assertEqual(s['metadata']['seq'], i) i += 1 + def test_get_all_server_details_with_host(self): + ''' + We want to make sure that if two instances are on the same host, then + they return the same hostId. If two instances are on different hosts, + they should return different hostId's. In this test, there are 5 + instances - 2 on one host and 3 on another. + ''' + + def stub_instance(id, user_id=1): + return Instance(id=id, state=0, image_id=10, user_id=user_id, + display_name='server%s' % id, host='host%s' % (id % 2)) + + def return_servers_with_host(context, user_id=1): + return [stub_instance(i) for i in xrange(5)] + + self.stubs.Set(nova.db.api, 'instance_get_all_by_user', + return_servers_with_host) + + req = webob.Request.blank('/v1.0/servers/detail') + res = req.get_response(fakes.wsgi_app()) + res_dict = json.loads(res.body) + + server_list = res_dict['servers'] + host_ids = [server_list[0]['hostId'], server_list[1]['hostId']] + self.assertTrue(host_ids[0] and host_ids[1]) + self.assertNotEqual(host_ids[0], host_ids[1]) + + for i, s in enumerate(res_dict['servers']): + self.assertEqual(s['id'], i) + self.assertEqual(s['hostId'], host_ids[i % 2]) + self.assertEqual(s['name'], 'server%d' % i) + self.assertEqual(s['imageId'], 10) + def test_server_pause(self): FLAGS.allow_admin_api = True body = dict(server=dict( @@ -281,6 +423,30 @@ class ServersTest(unittest.TestCase): res = req.get_response(fakes.wsgi_app()) self.assertEqual(res.status_int, 202) + def test_server_reset_network(self): + FLAGS.allow_admin_api = True + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality={})) + req = webob.Request.blank('/v1.0/servers/1/reset_network') + req.method = 'POST' + req.content_type = 'application/json' + req.body = json.dumps(body) + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 202) + + def test_server_inject_network_info(self): + FLAGS.allow_admin_api = True + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality={})) + req = webob.Request.blank('/v1.0/servers/1/inject_network_info') + req.method = 'POST' + req.content_type = 'application/json' + req.body = json.dumps(body) + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 202) + def test_server_diagnostics(self): req = webob.Request.blank("/v1.0/servers/1/diagnostics") req.method = "GET" @@ -339,6 +505,98 @@ class ServersTest(unittest.TestCase): self.assertEqual(res.status, '202 Accepted') self.assertEqual(self.server_delete_called, True) + def test_resize_server(self): + req = self.webreq('/1/action', 'POST', dict(resize=dict(flavorId=3))) + + self.resize_called = False + + def resize_mock(*args): + self.resize_called = True + + self.stubs.Set(nova.compute.api.API, 'resize', resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 202) + self.assertEqual(self.resize_called, True) + + def test_resize_bad_flavor_fails(self): + req = self.webreq('/1/action', 'POST', dict(resize=dict(derp=3))) + + self.resize_called = False + + def resize_mock(*args): + self.resize_called = True + + self.stubs.Set(nova.compute.api.API, 'resize', resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 422) + self.assertEqual(self.resize_called, False) + + def test_resize_raises_fails(self): + req = self.webreq('/1/action', 'POST', dict(resize=dict(flavorId=3))) + + def resize_mock(*args): + raise Exception('hurr durr') + + self.stubs.Set(nova.compute.api.API, 'resize', resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 400) + + def test_confirm_resize_server(self): + req = self.webreq('/1/action', 'POST', dict(confirmResize=None)) + + self.resize_called = False + + def confirm_resize_mock(*args): + self.resize_called = True + + self.stubs.Set(nova.compute.api.API, 'confirm_resize', + confirm_resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 204) + self.assertEqual(self.resize_called, True) + + def test_confirm_resize_server_fails(self): + req = self.webreq('/1/action', 'POST', dict(confirmResize=None)) + + def confirm_resize_mock(*args): + raise Exception('hurr durr') + + self.stubs.Set(nova.compute.api.API, 'confirm_resize', + confirm_resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 400) + + def test_revert_resize_server(self): + req = self.webreq('/1/action', 'POST', dict(revertResize=None)) + + self.resize_called = False + + def revert_resize_mock(*args): + self.resize_called = True + + self.stubs.Set(nova.compute.api.API, 'revert_resize', + revert_resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 202) + self.assertEqual(self.resize_called, True) + + def test_revert_resize_server_fails(self): + req = self.webreq('/1/action', 'POST', dict(revertResize=None)) + + def revert_resize_mock(*args): + raise Exception('hurr durr') + + self.stubs.Set(nova.compute.api.API, 'revert_resize', + revert_resize_mock) + + res = req.get_response(fakes.wsgi_app()) + self.assertEqual(res.status_int, 400) if __name__ == "__main__": unittest.main() diff --git a/nova/tests/api/openstack/test_shared_ip_groups.py b/nova/tests/api/openstack/test_shared_ip_groups.py index c2fc3a203..b4de2ef41 100644 --- a/nova/tests/api/openstack/test_shared_ip_groups.py +++ b/nova/tests/api/openstack/test_shared_ip_groups.py @@ -15,19 +15,20 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest - import stubout +from nova import test from nova.api.openstack import shared_ip_groups -class SharedIpGroupsTest(unittest.TestCase): +class SharedIpGroupsTest(test.TestCase): def setUp(self): + super(SharedIpGroupsTest, self).setUp() self.stubs = stubout.StubOutForTesting() def tearDown(self): self.stubs.UnsetAll() + super(SharedIpGroupsTest, self).tearDown() def test_get_shared_ip_groups(self): pass diff --git a/nova/tests/api/openstack/test_zones.py b/nova/tests/api/openstack/test_zones.py new file mode 100644 index 000000000..4f4fabf12 --- /dev/null +++ b/nova/tests/api/openstack/test_zones.py @@ -0,0 +1,169 @@ +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import stubout +import webob +import json + +import nova.db +from nova import context +from nova import flags +from nova import test +from nova.api.openstack import zones +from nova.tests.api.openstack import fakes +from nova.scheduler import api + + +FLAGS = flags.FLAGS +FLAGS.verbose = True + + +def zone_get(context, zone_id): + return dict(id=1, api_url='http://example.com', username='bob', + password='xxx') + + +def zone_create(context, values): + zone = dict(id=1) + zone.update(values) + return zone + + +def zone_update(context, zone_id, values): + zone = dict(id=zone_id, api_url='http://example.com', username='bob', + password='xxx') + zone.update(values) + return zone + + +def zone_delete(context, zone_id): + pass + + +def zone_get_all_scheduler(*args): + return [ + dict(id=1, api_url='http://example.com', username='bob', + password='xxx'), + dict(id=2, api_url='http://example.org', username='alice', + password='qwerty') + ] + + +def zone_get_all_scheduler_empty(*args): + return [] + + +def zone_get_all_db(context): + return [ + dict(id=1, api_url='http://example.com', username='bob', + password='xxx'), + dict(id=2, api_url='http://example.org', username='alice', + password='qwerty') + ] + + +class ZonesTest(test.TestCase): + def setUp(self): + super(ZonesTest, self).setUp() + self.stubs = stubout.StubOutForTesting() + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_networking(self.stubs) + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_auth(self.stubs) + + self.allow_admin = FLAGS.allow_admin_api + FLAGS.allow_admin_api = True + + self.stubs.Set(nova.db, 'zone_get', zone_get) + self.stubs.Set(nova.db, 'zone_update', zone_update) + self.stubs.Set(nova.db, 'zone_create', zone_create) + self.stubs.Set(nova.db, 'zone_delete', zone_delete) + + def tearDown(self): + self.stubs.UnsetAll() + FLAGS.allow_admin_api = self.allow_admin + super(ZonesTest, self).tearDown() + + def test_get_zone_list_scheduler(self): + self.stubs.Set(api.API, '_call_scheduler', zone_get_all_scheduler) + req = webob.Request.blank('/v1.0/zones') + res = req.get_response(fakes.wsgi_app()) + res_dict = json.loads(res.body) + + self.assertEqual(res.status_int, 200) + self.assertEqual(len(res_dict['zones']), 2) + + def test_get_zone_list_db(self): + self.stubs.Set(api.API, '_call_scheduler', + zone_get_all_scheduler_empty) + self.stubs.Set(nova.db, 'zone_get_all', zone_get_all_db) + req = webob.Request.blank('/v1.0/zones') + req.headers["Content-Type"] = "application/json" + res = req.get_response(fakes.wsgi_app()) + + self.assertEqual(res.status_int, 200) + res_dict = json.loads(res.body) + self.assertEqual(len(res_dict['zones']), 2) + + def test_get_zone_by_id(self): + req = webob.Request.blank('/v1.0/zones/1') + req.headers["Content-Type"] = "application/json" + res = req.get_response(fakes.wsgi_app()) + + self.assertEqual(res.status_int, 200) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['zone']['id'], 1) + self.assertEqual(res_dict['zone']['api_url'], 'http://example.com') + self.assertFalse('password' in res_dict['zone']) + + def test_zone_delete(self): + req = webob.Request.blank('/v1.0/zones/1') + req.headers["Content-Type"] = "application/json" + res = req.get_response(fakes.wsgi_app()) + + self.assertEqual(res.status_int, 200) + + def test_zone_create(self): + body = dict(zone=dict(api_url='http://example.com', username='fred', + password='fubar')) + req = webob.Request.blank('/v1.0/zones') + req.headers["Content-Type"] = "application/json" + req.method = 'POST' + req.body = json.dumps(body) + + res = req.get_response(fakes.wsgi_app()) + + self.assertEqual(res.status_int, 200) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['zone']['id'], 1) + self.assertEqual(res_dict['zone']['api_url'], 'http://example.com') + self.assertFalse('username' in res_dict['zone']) + + def test_zone_update(self): + body = dict(zone=dict(username='zeb', password='sneaky')) + req = webob.Request.blank('/v1.0/zones/1') + req.headers["Content-Type"] = "application/json" + req.method = 'PUT' + req.body = json.dumps(body) + + res = req.get_response(fakes.wsgi_app()) + + self.assertEqual(res.status_int, 200) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['zone']['id'], 1) + self.assertEqual(res_dict['zone']['api_url'], 'http://example.com') + self.assertFalse('username' in res_dict['zone']) diff --git a/nova/tests/api/test_wsgi.py b/nova/tests/api/test_wsgi.py index 44e2d615c..b1a849cf9 100644 --- a/nova/tests/api/test_wsgi.py +++ b/nova/tests/api/test_wsgi.py @@ -21,15 +21,17 @@ Test WSGI basics and provide some helper functions for other WSGI tests. """ -import unittest +import json +from nova import test import routes import webob +from nova import exception from nova import wsgi -class Test(unittest.TestCase): +class Test(test.TestCase): def test_debug(self): @@ -66,63 +68,164 @@ class Test(unittest.TestCase): result = webob.Request.blank('/bad').get_response(Router()) self.assertNotEqual(result.body, "Router result") - def test_controller(self): - class Controller(wsgi.Controller): - """Test controller to call from router.""" - test = self +class ControllerTest(test.TestCase): - def show(self, req, id): # pylint: disable-msg=W0622,C0103 - """Default action called for requests with an ID.""" - self.test.assertEqual(req.path_info, '/tests/123') - self.test.assertEqual(id, '123') - return id - - class Router(wsgi.Router): - """Test router.""" - - def __init__(self): - mapper = routes.Mapper() - mapper.resource("test", "tests", controller=Controller()) - super(Router, self).__init__(mapper) + class TestRouter(wsgi.Router): - result = webob.Request.blank('/tests/123').get_response(Router()) - self.assertEqual(result.body, "123") - result = webob.Request.blank('/test/123').get_response(Router()) - self.assertNotEqual(result.body, "123") + class TestController(wsgi.Controller): + _serialization_metadata = { + 'application/xml': { + "attributes": { + "test": ["id"]}}} -class SerializerTest(unittest.TestCase): - - def match(self, url, accept, expect): + def show(self, req, id): # pylint: disable-msg=W0622,C0103 + return {"test": {"id": id}} + + def __init__(self): + mapper = routes.Mapper() + mapper.resource("test", "tests", controller=self.TestController()) + wsgi.Router.__init__(self, mapper) + + def test_show(self): + request = wsgi.Request.blank('/tests/123') + result = request.get_response(self.TestRouter()) + self.assertEqual(json.loads(result.body), {"test": {"id": "123"}}) + + def test_response_content_type_from_accept_xml(self): + request = webob.Request.blank('/tests/123') + request.headers["Accept"] = "application/xml" + result = request.get_response(self.TestRouter()) + self.assertEqual(result.headers["Content-Type"], "application/xml") + + def test_response_content_type_from_accept_json(self): + request = wsgi.Request.blank('/tests/123') + request.headers["Accept"] = "application/json" + result = request.get_response(self.TestRouter()) + self.assertEqual(result.headers["Content-Type"], "application/json") + + def test_response_content_type_from_query_extension_xml(self): + request = wsgi.Request.blank('/tests/123.xml') + result = request.get_response(self.TestRouter()) + self.assertEqual(result.headers["Content-Type"], "application/xml") + + def test_response_content_type_from_query_extension_json(self): + request = wsgi.Request.blank('/tests/123.json') + result = request.get_response(self.TestRouter()) + self.assertEqual(result.headers["Content-Type"], "application/json") + + def test_response_content_type_default_when_unsupported(self): + request = wsgi.Request.blank('/tests/123.unsupported') + request.headers["Accept"] = "application/unsupported1" + result = request.get_response(self.TestRouter()) + self.assertEqual(result.status_int, 200) + self.assertEqual(result.headers["Content-Type"], "application/json") + + +class RequestTest(test.TestCase): + + def test_request_content_type_missing(self): + request = wsgi.Request.blank('/tests/123') + request.body = "<body />" + self.assertRaises(webob.exc.HTTPBadRequest, request.get_content_type) + + def test_request_content_type_unsupported(self): + request = wsgi.Request.blank('/tests/123') + request.headers["Content-Type"] = "text/html" + request.body = "asdf<br />" + self.assertRaises(webob.exc.HTTPBadRequest, request.get_content_type) + + def test_content_type_from_accept_xml(self): + request = wsgi.Request.blank('/tests/123') + request.headers["Accept"] = "application/xml" + result = request.best_match_content_type() + self.assertEqual(result, "application/xml") + + request = wsgi.Request.blank('/tests/123') + request.headers["Accept"] = "application/json" + result = request.best_match_content_type() + self.assertEqual(result, "application/json") + + request = wsgi.Request.blank('/tests/123') + request.headers["Accept"] = "application/xml, application/json" + result = request.best_match_content_type() + self.assertEqual(result, "application/json") + + request = wsgi.Request.blank('/tests/123') + request.headers["Accept"] = \ + "application/json; q=0.3, application/xml; q=0.9" + result = request.best_match_content_type() + self.assertEqual(result, "application/xml") + + def test_content_type_from_query_extension(self): + request = wsgi.Request.blank('/tests/123.xml') + result = request.best_match_content_type() + self.assertEqual(result, "application/xml") + + request = wsgi.Request.blank('/tests/123.json') + result = request.best_match_content_type() + self.assertEqual(result, "application/json") + + request = wsgi.Request.blank('/tests/123.invalid') + result = request.best_match_content_type() + self.assertEqual(result, "application/json") + + def test_content_type_accept_and_query_extension(self): + request = wsgi.Request.blank('/tests/123.xml') + request.headers["Accept"] = "application/json" + result = request.best_match_content_type() + self.assertEqual(result, "application/xml") + + def test_content_type_accept_default(self): + request = wsgi.Request.blank('/tests/123.unsupported') + request.headers["Accept"] = "application/unsupported1" + result = request.best_match_content_type() + self.assertEqual(result, "application/json") + + +class SerializerTest(test.TestCase): + + def test_xml(self): input_dict = dict(servers=dict(a=(2, 3))) expected_xml = '<servers><a>(2,3)</a></servers>' + serializer = wsgi.Serializer() + result = serializer.serialize(input_dict, "application/xml") + result = result.replace('\n', '').replace(' ', '') + self.assertEqual(result, expected_xml) + + def test_json(self): + input_dict = dict(servers=dict(a=(2, 3))) expected_json = '{"servers":{"a":[2,3]}}' - req = webob.Request.blank(url, headers=dict(Accept=accept)) - result = wsgi.Serializer(req.environ).to_content_type(input_dict) + serializer = wsgi.Serializer() + result = serializer.serialize(input_dict, "application/json") result = result.replace('\n', '').replace(' ', '') - if expect == 'xml': - self.assertEqual(result, expected_xml) - elif expect == 'json': - self.assertEqual(result, expected_json) - else: - raise "Bad expect value" - - def test_basic(self): - self.match('/servers/4.json', None, expect='json') - self.match('/servers/4', 'application/json', expect='json') - self.match('/servers/4', 'application/xml', expect='xml') - self.match('/servers/4.xml', None, expect='xml') - - def test_defaults_to_json(self): - self.match('/servers/4', None, expect='json') - self.match('/servers/4', 'text/html', expect='json') - - def test_suffix_takes_precedence_over_accept_header(self): - self.match('/servers/4.xml', 'application/json', expect='xml') - self.match('/servers/4.xml.', 'application/json', expect='json') - - def test_deserialize(self): + self.assertEqual(result, expected_json) + + def test_unsupported_content_type(self): + serializer = wsgi.Serializer() + self.assertRaises(exception.InvalidContentType, serializer.serialize, + {}, "text/null") + + def test_deserialize_json(self): + data = """{"a": { + "a1": "1", + "a2": "2", + "bs": ["1", "2", "3", {"c": {"c1": "1"}}], + "d": {"e": "1"}, + "f": "1"}}""" + as_dict = dict(a={ + 'a1': '1', + 'a2': '2', + 'bs': ['1', '2', '3', {'c': dict(c1='1')}], + 'd': {'e': '1'}, + 'f': '1'}) + metadata = {} + serializer = wsgi.Serializer(metadata) + self.assertEqual(serializer.deserialize(data, "application/json"), + as_dict) + + def test_deserialize_xml(self): xml = """ <a a1="1" a2="2"> <bs><b>1</b><b>2</b><b>3</b><b><c c1="1"/></b></bs> @@ -137,11 +240,13 @@ class SerializerTest(unittest.TestCase): 'd': {'e': '1'}, 'f': '1'}) metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})} - serializer = wsgi.Serializer({}, metadata) - self.assertEqual(serializer.deserialize(xml), as_dict) + serializer = wsgi.Serializer(metadata) + self.assertEqual(serializer.deserialize(xml, "application/xml"), + as_dict) def test_deserialize_empty_xml(self): xml = """<a></a>""" as_dict = {"a": {}} - serializer = wsgi.Serializer({}) - self.assertEqual(serializer.deserialize(xml), as_dict) + serializer = wsgi.Serializer() + self.assertEqual(serializer.deserialize(xml, "application/xml"), + as_dict) diff --git a/nova/tests/db/fakes.py b/nova/tests/db/fakes.py index 05bdd172e..d760dc456 100644 --- a/nova/tests/db/fakes.py +++ b/nova/tests/db/fakes.py @@ -20,13 +20,22 @@ import time from nova import db +from nova import test from nova import utils -from nova.compute import instance_types def stub_out_db_instance_api(stubs): """ Stubs out the db API for creating Instances """ + INSTANCE_TYPES = { + 'm1.tiny': dict(memory_mb=512, vcpus=1, local_gb=0, flavorid=1), + 'm1.small': dict(memory_mb=2048, vcpus=1, local_gb=20, flavorid=2), + 'm1.medium': + dict(memory_mb=4096, vcpus=2, local_gb=40, flavorid=3), + 'm1.large': dict(memory_mb=8192, vcpus=4, local_gb=80, flavorid=4), + 'm1.xlarge': + dict(memory_mb=16384, vcpus=8, local_gb=160, flavorid=5)} + class FakeModel(object): """ Stubs out for model """ def __init__(self, values): @@ -41,10 +50,16 @@ def stub_out_db_instance_api(stubs): else: raise NotImplementedError() + def fake_instance_type_get_all(context, inactive=0): + return INSTANCE_TYPES + + def fake_instance_type_get_by_name(context, name): + return INSTANCE_TYPES[name] + def fake_instance_create(values): """ Stubs out the db.instance_create method """ - type_data = instance_types.INSTANCE_TYPES[values['instance_type']] + type_data = INSTANCE_TYPES[values['instance_type']] base_options = { 'name': values['name'], @@ -73,3 +88,5 @@ def stub_out_db_instance_api(stubs): stubs.Set(db, 'instance_create', fake_instance_create) stubs.Set(db, 'network_get_by_instance', fake_network_get_by_instance) + stubs.Set(db, 'instance_type_get_all', fake_instance_type_get_all) + stubs.Set(db, 'instance_type_get_by_name', fake_instance_type_get_by_name) diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py index 1097488ec..5d7ca98b5 100644 --- a/nova/tests/fake_flags.py +++ b/nova/tests/fake_flags.py @@ -29,9 +29,10 @@ FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver' flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('num_networks', 'nova.network.manager') flags.DECLARE('fake_network', 'nova.network.manager') -FLAGS.network_size = 16 -FLAGS.num_networks = 5 +FLAGS.network_size = 8 +FLAGS.num_networks = 2 FLAGS.fake_network = True +FLAGS.image_service = 'nova.image.local.LocalImageService' flags.DECLARE('num_shelves', 'nova.volume.driver') flags.DECLARE('blades_per_shelf', 'nova.volume.driver') flags.DECLARE('iscsi_num_targets', 'nova.volume.driver') @@ -39,5 +40,5 @@ FLAGS.num_shelves = 2 FLAGS.blades_per_shelf = 4 FLAGS.iscsi_num_targets = 8 FLAGS.verbose = True -FLAGS.sql_connection = 'sqlite:///nova.sqlite' +FLAGS.sqlite_db = "tests.sqlite" FLAGS.use_ipv6 = True diff --git a/nova/tests/glance/stubs.py b/nova/tests/glance/stubs.py index f182b857a..5872552ec 100644 --- a/nova/tests/glance/stubs.py +++ b/nova/tests/glance/stubs.py @@ -26,12 +26,45 @@ def stubout_glance_client(stubs, cls): class FakeGlance(object): + IMAGE_MACHINE = 1 + IMAGE_KERNEL = 2 + IMAGE_RAMDISK = 3 + IMAGE_RAW = 4 + IMAGE_VHD = 5 + + IMAGE_FIXTURES = { + IMAGE_MACHINE: { + 'image_meta': {'name': 'fakemachine', 'size': 0, + 'disk_format': 'ami', + 'container_format': 'ami'}, + 'image_data': StringIO.StringIO('')}, + IMAGE_KERNEL: { + 'image_meta': {'name': 'fakekernel', 'size': 0, + 'disk_format': 'aki', + 'container_format': 'aki'}, + 'image_data': StringIO.StringIO('')}, + IMAGE_RAMDISK: { + 'image_meta': {'name': 'fakeramdisk', 'size': 0, + 'disk_format': 'ari', + 'container_format': 'ari'}, + 'image_data': StringIO.StringIO('')}, + IMAGE_RAW: { + 'image_meta': {'name': 'fakeraw', 'size': 0, + 'disk_format': 'raw', + 'container_format': 'bare'}, + 'image_data': StringIO.StringIO('')}, + IMAGE_VHD: { + 'image_meta': {'name': 'fakevhd', 'size': 0, + 'disk_format': 'vhd', + 'container_format': 'ovf'}, + 'image_data': StringIO.StringIO('')}} + def __init__(self, host, port=None, use_ssl=False): pass - def get_image(self, image): - meta = { - 'size': 0, - } - image_file = StringIO.StringIO('') - return meta, image_file + def get_image_meta(self, image_id): + return self.IMAGE_FIXTURES[image_id]['image_meta'] + + def get_image(self, image_id): + image = self.IMAGE_FIXTURES[image_id] + return image['image_meta'], image['image_data'] diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index da86e6e11..5a1be08eb 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -311,4 +311,5 @@ class S3APITestCase(test.TestCase): self.auth_manager.delete_user('admin') self.auth_manager.delete_project('admin') stop_listening = defer.maybeDeferred(self.listening_port.stopListening) + super(S3APITestCase, self).tearDown() return defer.DeferredList([stop_listening]) diff --git a/nova/tests/test_api.py b/nova/tests/test_api.py index 2569e262b..d5c54a1c3 100644 --- a/nova/tests/test_api.py +++ b/nova/tests/test_api.py @@ -20,6 +20,7 @@ import boto from boto.ec2 import regioninfo +import datetime import httplib import random import StringIO @@ -127,6 +128,28 @@ class ApiEc2TestCase(test.TestCase): self.ec2.new_http_connection(host, is_secure).AndReturn(self.http) return self.http + def test_return_valid_isoformat(self): + """ + Ensure that the ec2 api returns datetime in xs:dateTime + (which apparently isn't datetime.isoformat()) + NOTE(ken-pepple): https://bugs.launchpad.net/nova/+bug/721297 + """ + conv = apirequest._database_to_isoformat + # sqlite database representation with microseconds + time_to_convert = datetime.datetime.strptime( + "2011-02-21 20:14:10.634276", + "%Y-%m-%d %H:%M:%S.%f") + self.assertEqual( + conv(time_to_convert), + '2011-02-21T20:14:10Z') + # mysqlite database representation + time_to_convert = datetime.datetime.strptime( + "2011-02-21 19:56:18", + "%Y-%m-%d %H:%M:%S") + self.assertEqual( + conv(time_to_convert), + '2011-02-21T19:56:18Z') + def test_xmlns_version_matches_request_version(self): self.expect_http(api_version='2010-10-30') self.mox.ReplayAll() @@ -248,16 +271,14 @@ class ApiEc2TestCase(test.TestCase): self.mox.ReplayAll() rv = self.ec2.get_all_security_groups() - # I don't bother checkng that we actually find it here, - # because the create/delete unit test further up should - # be good enough for that. - for group in rv: - if group.name == security_group_name: - self.assertEquals(len(group.rules), 1) - self.assertEquals(int(group.rules[0].from_port), 80) - self.assertEquals(int(group.rules[0].to_port), 81) - self.assertEquals(len(group.rules[0].grants), 1) - self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') + + group = [grp for grp in rv if grp.name == security_group_name][0] + + self.assertEquals(len(group.rules), 1) + self.assertEquals(int(group.rules[0].from_port), 80) + self.assertEquals(int(group.rules[0].to_port), 81) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') self.expect_http() self.mox.ReplayAll() @@ -314,16 +335,13 @@ class ApiEc2TestCase(test.TestCase): self.mox.ReplayAll() rv = self.ec2.get_all_security_groups() - # I don't bother checkng that we actually find it here, - # because the create/delete unit test further up should - # be good enough for that. - for group in rv: - if group.name == security_group_name: - self.assertEquals(len(group.rules), 1) - self.assertEquals(int(group.rules[0].from_port), 80) - self.assertEquals(int(group.rules[0].to_port), 81) - self.assertEquals(len(group.rules[0].grants), 1) - self.assertEquals(str(group.rules[0].grants[0]), '::/0') + + group = [grp for grp in rv if grp.name == security_group_name][0] + self.assertEquals(len(group.rules), 1) + self.assertEquals(int(group.rules[0].from_port), 80) + self.assertEquals(int(group.rules[0].to_port), 81) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), '::/0') self.expect_http() self.mox.ReplayAll() diff --git a/nova/tests/test_auth.py b/nova/tests/test_auth.py index 35ffffb67..2a7817032 100644 --- a/nova/tests/test_auth.py +++ b/nova/tests/test_auth.py @@ -327,15 +327,6 @@ class AuthManagerTestCase(object): class AuthManagerLdapTestCase(AuthManagerTestCase, test.TestCase): auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' - def __init__(self, *args, **kwargs): - AuthManagerTestCase.__init__(self) - test.TestCase.__init__(self, *args, **kwargs) - import nova.auth.fakeldap as fakeldap - if FLAGS.flush_db: - LOG.info("Flushing datastore") - r = fakeldap.Store.instance() - r.flushdb() - class AuthManagerDbTestCase(AuthManagerTestCase, test.TestCase): auth_driver = 'nova.auth.dbdriver.DbDriver' diff --git a/nova/tests/test_cloud.py b/nova/tests/test_cloud.py index 445cc6e8b..cf8ee7eff 100644 --- a/nova/tests/test_cloud.py +++ b/nova/tests/test_cloud.py @@ -38,6 +38,8 @@ from nova import test from nova.auth import manager from nova.compute import power_state from nova.api.ec2 import cloud +from nova.api.ec2 import ec2utils +from nova.image import local from nova.objectstore import image @@ -65,18 +67,27 @@ class CloudTestCase(test.TestCase): self.cloud = cloud.CloudController() # set up services - self.compute = service.Service.create(binary='nova-compute') - self.compute.start() - self.network = service.Service.create(binary='nova-network') - self.network.start() + self.compute = self.start_service('compute') + self.scheduter = self.start_service('scheduler') + self.network = self.start_service('network') self.manager = manager.AuthManager() self.user = self.manager.create_user('admin', 'admin', 'admin', True) self.project = self.manager.create_project('proj', 'admin', 'proj') self.context = context.RequestContext(user=self.user, project=self.project) + host = self.network.get_network_host(self.context.elevated()) + + def fake_show(meh, context, id): + return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1}} + + self.stubs.Set(local.LocalImageService, 'show', fake_show) + self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show) def tearDown(self): + network_ref = db.project_get_network(self.context, + self.project.id) + db.network_disassociate(self.context, network_ref['id']) self.manager.delete_project(self.project) self.manager.delete_user(self.user) self.compute.kill() @@ -102,7 +113,7 @@ class CloudTestCase(test.TestCase): address = "10.10.10.10" db.floating_ip_create(self.context, {'address': address, - 'host': FLAGS.host}) + 'host': self.network.host}) self.cloud.allocate_address(self.context) self.cloud.describe_addresses(self.context) self.cloud.release_address(self.context, @@ -115,11 +126,11 @@ class CloudTestCase(test.TestCase): address = "10.10.10.10" db.floating_ip_create(self.context, {'address': address, - 'host': FLAGS.host}) + 'host': self.network.host}) self.cloud.allocate_address(self.context) - inst = db.instance_create(self.context, {'host': FLAGS.host}) + inst = db.instance_create(self.context, {'host': self.compute.host}) fixed = self.network.allocate_fixed_ip(self.context, inst['id']) - ec2_id = cloud.id_to_ec2_id(inst['id']) + ec2_id = ec2utils.id_to_ec2_id(inst['id']) self.cloud.associate_address(self.context, instance_id=ec2_id, public_ip=address) @@ -133,18 +144,34 @@ class CloudTestCase(test.TestCase): db.instance_destroy(self.context, inst['id']) db.floating_ip_destroy(self.context, address) + def test_describe_security_groups(self): + """Makes sure describe_security_groups works and filters results.""" + sec = db.security_group_create(self.context, + {'project_id': self.context.project_id, + 'name': 'test'}) + result = self.cloud.describe_security_groups(self.context) + # NOTE(vish): should have the default group as well + self.assertEqual(len(result['securityGroupInfo']), 2) + result = self.cloud.describe_security_groups(self.context, + group_name=[sec['name']]) + self.assertEqual(len(result['securityGroupInfo']), 1) + self.assertEqual( + result['securityGroupInfo'][0]['groupName'], + sec['name']) + db.security_group_destroy(self.context, sec['id']) + def test_describe_volumes(self): """Makes sure describe_volumes works and filters results.""" vol1 = db.volume_create(self.context, {}) vol2 = db.volume_create(self.context, {}) result = self.cloud.describe_volumes(self.context) self.assertEqual(len(result['volumeSet']), 2) - volume_id = cloud.id_to_ec2_id(vol2['id'], 'vol-%08x') + volume_id = ec2utils.id_to_ec2_id(vol2['id'], 'vol-%08x') result = self.cloud.describe_volumes(self.context, volume_id=[volume_id]) self.assertEqual(len(result['volumeSet']), 1) self.assertEqual( - cloud.ec2_id_to_id(result['volumeSet'][0]['volumeId']), + ec2utils.ec2_id_to_id(result['volumeSet'][0]['volumeId']), vol2['id']) db.volume_destroy(self.context, vol1['id']) db.volume_destroy(self.context, vol2['id']) @@ -169,8 +196,10 @@ class CloudTestCase(test.TestCase): def test_describe_instances(self): """Makes sure describe_instances works and filters results.""" inst1 = db.instance_create(self.context, {'reservation_id': 'a', + 'image_id': 1, 'host': 'host1'}) inst2 = db.instance_create(self.context, {'reservation_id': 'a', + 'image_id': 1, 'host': 'host2'}) comp1 = db.service_create(self.context, {'host': 'host1', 'availability_zone': 'zone1', @@ -181,7 +210,7 @@ class CloudTestCase(test.TestCase): result = self.cloud.describe_instances(self.context) result = result['reservationSet'][0] self.assertEqual(len(result['instancesSet']), 2) - instance_id = cloud.id_to_ec2_id(inst2['id']) + instance_id = ec2utils.id_to_ec2_id(inst2['id']) result = self.cloud.describe_instances(self.context, instance_id=[instance_id]) result = result['reservationSet'][0] @@ -196,34 +225,37 @@ class CloudTestCase(test.TestCase): db.service_destroy(self.context, comp2['id']) def test_console_output(self): - image_id = FLAGS.default_image instance_type = FLAGS.default_instance_type max_count = 1 - kwargs = {'image_id': image_id, + kwargs = {'image_id': 'ami-1', 'instance_type': instance_type, 'max_count': max_count} rv = self.cloud.run_instances(self.context, **kwargs) + greenthread.sleep(0.3) instance_id = rv['instancesSet'][0]['instanceId'] output = self.cloud.get_console_output(context=self.context, - instance_id=[instance_id]) + instance_id=[instance_id]) self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT') # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. greenthread.sleep(0.3) rv = self.cloud.terminate_instances(self.context, [instance_id]) + greenthread.sleep(0.3) def test_ajax_console(self): - kwargs = {'image_id': image_id} - rv = yield self.cloud.run_instances(self.context, **kwargs) + kwargs = {'image_id': 'ami-1'} + rv = self.cloud.run_instances(self.context, **kwargs) instance_id = rv['instancesSet'][0]['instanceId'] - output = yield self.cloud.get_console_output(context=self.context, - instance_id=[instance_id]) - self.assertEquals(b64decode(output['output']), - 'http://fakeajaxconsole.com/?token=FAKETOKEN') + greenthread.sleep(0.3) + output = self.cloud.get_ajax_console(context=self.context, + instance_id=[instance_id]) + self.assertEquals(output['url'], + '%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url) # TODO(soren): We need this until we can stop polling in the rpc code # for unit tests. greenthread.sleep(0.3) - rv = yield self.cloud.terminate_instances(self.context, [instance_id]) + rv = self.cloud.terminate_instances(self.context, [instance_id]) + greenthread.sleep(0.3) def test_key_generation(self): result = self._create_key('test') @@ -243,7 +275,7 @@ class CloudTestCase(test.TestCase): self._create_key('test1') self._create_key('test2') result = self.cloud.describe_key_pairs(self.context) - keys = result["keypairsSet"] + keys = result["keySet"] self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys)) self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys)) @@ -286,70 +318,6 @@ class CloudTestCase(test.TestCase): LOG.debug(_("Terminating instance %s"), instance_id) rv = self.compute.terminate_instance(instance_id) - def test_describe_instances(self): - """Makes sure describe_instances works.""" - instance1 = db.instance_create(self.context, {'host': 'host2'}) - comp1 = db.service_create(self.context, {'host': 'host2', - 'availability_zone': 'zone1', - 'topic': "compute"}) - result = self.cloud.describe_instances(self.context) - self.assertEqual(result['reservationSet'][0] - ['instancesSet'][0] - ['placement']['availabilityZone'], 'zone1') - db.instance_destroy(self.context, instance1['id']) - db.service_destroy(self.context, comp1['id']) - - def test_instance_update_state(self): - # TODO(termie): what is this code even testing? - def instance(num): - return { - 'reservation_id': 'r-1', - 'instance_id': 'i-%s' % num, - 'image_id': 'ami-%s' % num, - 'private_dns_name': '10.0.0.%s' % num, - 'dns_name': '10.0.0%s' % num, - 'ami_launch_index': str(num), - 'instance_type': 'fake', - 'availability_zone': 'fake', - 'key_name': None, - 'kernel_id': 'fake', - 'ramdisk_id': 'fake', - 'groups': ['default'], - 'product_codes': None, - 'state': 0x01, - 'user_data': ''} - rv = self.cloud._format_describe_instances(self.context) - logging.error(str(rv)) - self.assertEqual(len(rv['reservationSet']), 0) - - # simulate launch of 5 instances - # self.cloud.instances['pending'] = {} - #for i in xrange(5): - # inst = instance(i) - # self.cloud.instances['pending'][inst['instance_id']] = inst - - #rv = self.cloud._format_instances(self.admin) - #self.assert_(len(rv['reservationSet']) == 1) - #self.assert_(len(rv['reservationSet'][0]['instances_set']) == 5) - # report 4 nodes each having 1 of the instances - #for i in xrange(4): - # self.cloud.update_state('instances', - # {('node-%s' % i): {('i-%s' % i): - # instance(i)}}) - - # one instance should be pending still - #self.assert_(len(self.cloud.instances['pending'].keys()) == 1) - - # check that the reservations collapse - #rv = self.cloud._format_instances(self.admin) - #self.assert_(len(rv['reservationSet']) == 1) - #self.assert_(len(rv['reservationSet'][0]['instances_set']) == 5) - - # check that we can get metadata for each instance - #for i in xrange(4): - # data = self.cloud.get_metadata(instance(i)['private_dns_name']) - # self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i) - @staticmethod def _fake_set_image_description(ctxt, image_id, description): from nova.objectstore import handler @@ -387,7 +355,7 @@ class CloudTestCase(test.TestCase): def test_update_of_instance_display_fields(self): inst = db.instance_create(self.context, {}) - ec2_id = cloud.id_to_ec2_id(inst['id']) + ec2_id = ec2utils.id_to_ec2_id(inst['id']) self.cloud.update_instance(self.context, ec2_id, display_name='c00l 1m4g3') inst = db.instance_get(self.context, inst['id']) @@ -405,7 +373,7 @@ class CloudTestCase(test.TestCase): def test_update_of_volume_display_fields(self): vol = db.volume_create(self.context, {}) self.cloud.update_volume(self.context, - cloud.id_to_ec2_id(vol['id'], 'vol-%08x'), + ec2utils.id_to_ec2_id(vol['id'], 'vol-%08x'), display_name='c00l v0lum3') vol = db.volume_get(self.context, vol['id']) self.assertEqual('c00l v0lum3', vol['display_name']) @@ -414,7 +382,7 @@ class CloudTestCase(test.TestCase): def test_update_of_volume_wont_update_private_fields(self): vol = db.volume_create(self.context, {}) self.cloud.update_volume(self.context, - cloud.id_to_ec2_id(vol['id'], 'vol-%08x'), + ec2utils.id_to_ec2_id(vol['id'], 'vol-%08x'), mountpoint='/not/here') vol = db.volume_get(self.context, vol['id']) self.assertEqual(None, vol['mountpoint']) diff --git a/nova/tests/test_compute.py b/nova/tests/test_compute.py index 2aa0690e7..643b2e93a 100644 --- a/nova/tests/test_compute.py +++ b/nova/tests/test_compute.py @@ -30,7 +30,8 @@ from nova import log as logging from nova import test from nova import utils from nova.auth import manager - +from nova.compute import instance_types +from nova.image import local LOG = logging.getLogger('nova.tests.compute') FLAGS = flags.FLAGS @@ -51,15 +52,20 @@ class ComputeTestCase(test.TestCase): self.project = self.manager.create_project('fake', 'fake', 'fake') self.context = context.RequestContext('fake', 'fake', False) + def fake_show(meh, context, id): + return {'id': 1, 'properties': {'kernel_id': 1, 'ramdisk_id': 1}} + + self.stubs.Set(local.LocalImageService, 'show', fake_show) + def tearDown(self): self.manager.delete_user(self.user) self.manager.delete_project(self.project) super(ComputeTestCase, self).tearDown() - def _create_instance(self): + def _create_instance(self, params={}): """Create a test instance""" inst = {} - inst['image_id'] = 'ami-test' + inst['image_id'] = 1 inst['reservation_id'] = 'r-fakeres' inst['launch_time'] = '10' inst['user_id'] = self.user.id @@ -67,6 +73,7 @@ class ComputeTestCase(test.TestCase): inst['instance_type'] = 'm1.tiny' inst['mac_address'] = utils.generate_mac() inst['ami_launch_index'] = 0 + inst.update(params) return db.instance_create(self.context, inst)['id'] def _create_group(self): @@ -202,6 +209,14 @@ class ComputeTestCase(test.TestCase): self.compute.set_admin_password(self.context, instance_id) self.compute.terminate_instance(self.context, instance_id) + def test_inject_file(self): + """Ensure we can write a file to an instance""" + instance_id = self._create_instance() + self.compute.run_instance(self.context, instance_id) + self.compute.inject_file(self.context, instance_id, "/tmp/test", + "File Contents") + self.compute.terminate_instance(self.context, instance_id) + def test_snapshot(self): """Ensure instance can be snapshotted""" instance_id = self._create_instance() @@ -258,3 +273,31 @@ class ComputeTestCase(test.TestCase): self.assertEqual(ret_val, None) self.compute.terminate_instance(self.context, instance_id) + + def test_resize_instance(self): + """Ensure instance can be migrated/resized""" + instance_id = self._create_instance() + context = self.context.elevated() + self.compute.run_instance(self.context, instance_id) + db.instance_update(self.context, instance_id, {'host': 'foo'}) + self.compute.prep_resize(context, instance_id) + migration_ref = db.migration_get_by_instance_and_status(context, + instance_id, 'pre-migrating') + self.compute.resize_instance(context, instance_id, + migration_ref['id']) + self.compute.terminate_instance(context, instance_id) + + def test_get_by_flavor_id(self): + type = instance_types.get_by_flavor_id(1) + self.assertEqual(type, 'm1.tiny') + + def test_resize_same_source_fails(self): + """Ensure instance fails to migrate when source and destination are + the same host""" + instance_id = self._create_instance() + self.compute.run_instance(self.context, instance_id) + self.assertRaises(exception.Error, self.compute.prep_resize, + self.context, instance_id) + self.compute.terminate_instance(self.context, instance_id) + type = instance_types.get_by_flavor_id("1") + self.assertEqual(type, 'm1.tiny') diff --git a/nova/tests/test_console.py b/nova/tests/test_console.py index 85bf94458..d47c70d88 100644 --- a/nova/tests/test_console.py +++ b/nova/tests/test_console.py @@ -21,7 +21,6 @@ Tests For Console proxy. """ import datetime -import logging from nova import context from nova import db @@ -38,7 +37,6 @@ FLAGS = flags.FLAGS class ConsoleTestCase(test.TestCase): """Test case for console proxy""" def setUp(self): - logging.getLogger().setLevel(logging.DEBUG) super(ConsoleTestCase, self).setUp() self.flags(console_driver='nova.console.fake.FakeConsoleProxy', stub_compute=True) @@ -59,7 +57,7 @@ class ConsoleTestCase(test.TestCase): inst = {} #inst['host'] = self.host #inst['name'] = 'instance-1234' - inst['image_id'] = 'ami-test' + inst['image_id'] = 1 inst['reservation_id'] = 'r-fakeres' inst['launch_time'] = '10' inst['user_id'] = self.user.id diff --git a/nova/tests/test_direct.py b/nova/tests/test_direct.py index 8a74b2296..80e4d2e1f 100644 --- a/nova/tests/test_direct.py +++ b/nova/tests/test_direct.py @@ -19,7 +19,6 @@ """Tests for Direct API.""" import json -import logging import webob @@ -53,12 +52,14 @@ class DirectTestCase(test.TestCase): def tearDown(self): direct.ROUTES = {} + super(DirectTestCase, self).tearDown() def test_delegated_auth(self): req = webob.Request.blank('/fake/context') req.headers['X-OpenStack-User'] = 'user1' req.headers['X-OpenStack-Project'] = 'proj1' resp = req.get_response(self.auth_router) + self.assertEqual(resp.status_int, 200) data = json.loads(resp.body) self.assertEqual(data['user'], 'user1') self.assertEqual(data['project'], 'proj1') @@ -69,6 +70,7 @@ class DirectTestCase(test.TestCase): req.method = 'POST' req.body = 'json=%s' % json.dumps({'data': 'foo'}) resp = req.get_response(self.router) + self.assertEqual(resp.status_int, 200) resp_parsed = json.loads(resp.body) self.assertEqual(resp_parsed['data'], 'foo') @@ -78,6 +80,7 @@ class DirectTestCase(test.TestCase): req.method = 'POST' req.body = 'data=foo' resp = req.get_response(self.router) + self.assertEqual(resp.status_int, 200) resp_parsed = json.loads(resp.body) self.assertEqual(resp_parsed['data'], 'foo') @@ -90,8 +93,7 @@ class DirectTestCase(test.TestCase): class DirectCloudTestCase(test_cloud.CloudTestCase): def setUp(self): super(DirectCloudTestCase, self).setUp() - compute_handle = compute.API(image_service=self.cloud.image_service, - network_api=self.cloud.network_api, + compute_handle = compute.API(network_api=self.cloud.network_api, volume_api=self.cloud.volume_api) direct.register_service('compute', compute_handle) self.router = direct.JsonParamsMiddleware(direct.Router()) diff --git a/nova/tests/test_instance_types.py b/nova/tests/test_instance_types.py new file mode 100644 index 000000000..edc538879 --- /dev/null +++ b/nova/tests/test_instance_types.py @@ -0,0 +1,86 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Ken Pepple +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Unit Tests for instance types code +""" +import time + +from nova import context +from nova import db +from nova import exception +from nova import flags +from nova import log as logging +from nova import test +from nova import utils +from nova.compute import instance_types +from nova.db.sqlalchemy.session import get_session +from nova.db.sqlalchemy import models + +FLAGS = flags.FLAGS +LOG = logging.getLogger('nova.tests.compute') + + +class InstanceTypeTestCase(test.TestCase): + """Test cases for instance type code""" + def setUp(self): + super(InstanceTypeTestCase, self).setUp() + session = get_session() + max_flavorid = session.query(models.InstanceTypes).\ + order_by("flavorid desc").\ + first() + self.flavorid = max_flavorid["flavorid"] + 1 + self.name = str(int(time.time())) + + def test_instance_type_create_then_delete(self): + """Ensure instance types can be created""" + starting_inst_list = instance_types.get_all_types() + instance_types.create(self.name, 256, 1, 120, self.flavorid) + new = instance_types.get_all_types() + self.assertNotEqual(len(starting_inst_list), + len(new), + 'instance type was not created') + instance_types.destroy(self.name) + self.assertEqual(1, + instance_types.get_instance_type(self.name)["deleted"]) + self.assertEqual(starting_inst_list, instance_types.get_all_types()) + instance_types.purge(self.name) + self.assertEqual(len(starting_inst_list), + len(instance_types.get_all_types()), + 'instance type not purged') + + def test_get_all_instance_types(self): + """Ensures that all instance types can be retrieved""" + session = get_session() + total_instance_types = session.query(models.InstanceTypes).\ + count() + inst_types = instance_types.get_all_types() + self.assertEqual(total_instance_types, len(inst_types)) + + def test_invalid_create_args_should_fail(self): + """Ensures that instance type creation fails with invalid args""" + self.assertRaises( + exception.InvalidInputException, + instance_types.create, self.name, 0, 1, 120, self.flavorid) + self.assertRaises( + exception.InvalidInputException, + instance_types.create, self.name, 256, -1, 120, self.flavorid) + self.assertRaises( + exception.InvalidInputException, + instance_types.create, self.name, 256, 1, "aa", self.flavorid) + + def test_non_existant_inst_type_shouldnt_delete(self): + """Ensures that instance type creation fails with invalid args""" + self.assertRaises(exception.ApiError, + instance_types.destroy, "sfsfsdfdfs") diff --git a/nova/tests/test_localization.py b/nova/tests/test_localization.py index 6992773f5..393d71038 100644 --- a/nova/tests/test_localization.py +++ b/nova/tests/test_localization.py @@ -15,7 +15,6 @@ # under the License. import glob -import logging import os import re import sys diff --git a/nova/tests/test_log.py b/nova/tests/test_log.py index 868a5ead3..122351ff6 100644 --- a/nova/tests/test_log.py +++ b/nova/tests/test_log.py @@ -1,9 +1,12 @@ import cStringIO from nova import context +from nova import flags from nova import log from nova import test +FLAGS = flags.FLAGS + def _fake_context(): return context.RequestContext(1, 1) @@ -14,15 +17,11 @@ class RootLoggerTestCase(test.TestCase): super(RootLoggerTestCase, self).setUp() self.log = log.logging.root - def tearDown(self): - super(RootLoggerTestCase, self).tearDown() - log.NovaLogger.manager.loggerDict = {} - def test_is_nova_instance(self): self.assert_(isinstance(self.log, log.NovaLogger)) - def test_name_is_nova_root(self): - self.assertEqual("nova.root", self.log.name) + def test_name_is_nova(self): + self.assertEqual("nova", self.log.name) def test_handlers_have_nova_formatter(self): formatters = [] @@ -45,6 +44,38 @@ class RootLoggerTestCase(test.TestCase): log.audit("foo", context=_fake_context()) self.assert_(True) # didn't raise exception + def test_will_be_verbose_if_verbose_flag_set(self): + self.flags(verbose=True) + log.reset() + self.assertEqual(log.DEBUG, self.log.level) + + def test_will_not_be_verbose_if_verbose_flag_not_set(self): + self.flags(verbose=False) + log.reset() + self.assertEqual(log.INFO, self.log.level) + + +class LogHandlerTestCase(test.TestCase): + def test_log_path_logdir(self): + self.flags(logdir='/some/path', logfile=None) + self.assertEquals(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') + + def test_log_path_logfile(self): + self.flags(logfile='/some/path/foo-bar.log') + self.assertEquals(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') + + def test_log_path_none(self): + self.flags(logdir=None, logfile=None) + self.assertTrue(log._get_log_file_path(binary='foo-bar') is None) + + def test_log_path_logfile_overrides_logdir(self): + self.flags(logdir='/some/other/path', + logfile='/some/path/foo-bar.log') + self.assertEquals(log._get_log_file_path(binary='foo-bar'), + '/some/path/foo-bar.log') + class NovaFormatterTestCase(test.TestCase): def setUp(self): @@ -55,13 +86,15 @@ class NovaFormatterTestCase(test.TestCase): logging_debug_format_suffix="--DBG") self.log = log.logging.root self.stream = cStringIO.StringIO() - handler = log.StreamHandler(self.stream) - self.log.addHandler(handler) + self.handler = log.StreamHandler(self.stream) + self.log.addHandler(self.handler) + self.level = self.log.level self.log.setLevel(log.DEBUG) def tearDown(self): + self.log.setLevel(self.level) + self.log.removeHandler(self.handler) super(NovaFormatterTestCase, self).tearDown() - log.NovaLogger.manager.loggerDict = {} def test_uncontextualized_log(self): self.log.info("foo") @@ -81,30 +114,15 @@ class NovaFormatterTestCase(test.TestCase): class NovaLoggerTestCase(test.TestCase): def setUp(self): super(NovaLoggerTestCase, self).setUp() - self.flags(default_log_levels=["nova-test=AUDIT"], verbose=False) + levels = FLAGS.default_log_levels + levels.append("nova-test=AUDIT") + self.flags(default_log_levels=levels, + verbose=True) self.log = log.getLogger('nova-test') - def tearDown(self): - super(NovaLoggerTestCase, self).tearDown() - log.NovaLogger.manager.loggerDict = {} - def test_has_level_from_flags(self): self.assertEqual(log.AUDIT, self.log.level) def test_child_log_has_level_of_parent_flag(self): l = log.getLogger('nova-test.foo') self.assertEqual(log.AUDIT, l.level) - - -class VerboseLoggerTestCase(test.TestCase): - def setUp(self): - super(VerboseLoggerTestCase, self).setUp() - self.flags(default_log_levels=["nova.test=AUDIT"], verbose=True) - self.log = log.getLogger('nova.test') - - def tearDown(self): - super(VerboseLoggerTestCase, self).tearDown() - log.NovaLogger.manager.loggerDict = {} - - def test_will_be_verbose_if_named_nova_and_verbose_flag_set(self): - self.assertEqual(log.DEBUG, self.log.level) diff --git a/nova/tests/test_misc.py b/nova/tests/test_misc.py index 33c1777d5..a658e4978 100644 --- a/nova/tests/test_misc.py +++ b/nova/tests/test_misc.py @@ -14,10 +14,12 @@ # License for the specific language governing permissions and limitations # under the License. +import errno import os +import select from nova import test -from nova.utils import parse_mailmap, str_dict_replace +from nova.utils import parse_mailmap, str_dict_replace, synchronized class ProjectTestCase(test.TestCase): @@ -46,6 +48,8 @@ class ProjectTestCase(test.TestCase): missing = set() for contributor in contributors: + if contributor == 'nova-core': + continue if not contributor in authors_file: missing.add(contributor) @@ -53,3 +57,47 @@ class ProjectTestCase(test.TestCase): '%r not listed in Authors' % missing) finally: tree.unlock() + + +class LockTestCase(test.TestCase): + def test_synchronized_wrapped_function_metadata(self): + @synchronized('whatever') + def foo(): + """Bar""" + pass + self.assertEquals(foo.__doc__, 'Bar', "Wrapped function's docstring " + "got lost") + self.assertEquals(foo.__name__, 'foo', "Wrapped function's name " + "got mangled") + + def test_synchronized(self): + rpipe1, wpipe1 = os.pipe() + rpipe2, wpipe2 = os.pipe() + + @synchronized('testlock') + def f(rpipe, wpipe): + try: + os.write(wpipe, "foo") + except OSError, e: + self.assertEquals(e.errno, errno.EPIPE) + return + + rfds, _, __ = select.select([rpipe], [], [], 1) + self.assertEquals(len(rfds), 0, "The other process, which was" + " supposed to be locked, " + "wrote on its end of the " + "pipe") + os.close(rpipe) + + pid = os.fork() + if pid > 0: + os.close(wpipe1) + os.close(rpipe2) + + f(rpipe1, wpipe2) + else: + os.close(rpipe1) + os.close(wpipe2) + + f(rpipe2, wpipe1) + os._exit(0) diff --git a/nova/tests/test_network.py b/nova/tests/test_network.py index 00f9323f3..53e35ce7e 100644 --- a/nova/tests/test_network.py +++ b/nova/tests/test_network.py @@ -29,11 +29,153 @@ from nova import log as logging from nova import test from nova import utils from nova.auth import manager +from nova.network import linux_net FLAGS = flags.FLAGS LOG = logging.getLogger('nova.tests.network') +class IptablesManagerTestCase(test.TestCase): + sample_filter = ['#Generated by iptables-save on Fri Feb 18 15:17:05 2011', + '*filter', + ':INPUT ACCEPT [2223527:305688874]', + ':FORWARD ACCEPT [0:0]', + ':OUTPUT ACCEPT [2172501:140856656]', + ':nova-compute-FORWARD - [0:0]', + ':nova-compute-INPUT - [0:0]', + ':nova-compute-local - [0:0]', + ':nova-compute-OUTPUT - [0:0]', + ':nova-filter-top - [0:0]', + '-A FORWARD -j nova-filter-top ', + '-A OUTPUT -j nova-filter-top ', + '-A nova-filter-top -j nova-compute-local ', + '-A INPUT -j nova-compute-INPUT ', + '-A OUTPUT -j nova-compute-OUTPUT ', + '-A FORWARD -j nova-compute-FORWARD ', + '-A INPUT -i virbr0 -p udp -m udp --dport 53 -j ACCEPT ', + '-A INPUT -i virbr0 -p tcp -m tcp --dport 53 -j ACCEPT ', + '-A INPUT -i virbr0 -p udp -m udp --dport 67 -j ACCEPT ', + '-A INPUT -i virbr0 -p tcp -m tcp --dport 67 -j ACCEPT ', + '-A FORWARD -s 192.168.122.0/24 -i virbr0 -j ACCEPT ', + '-A FORWARD -i virbr0 -o virbr0 -j ACCEPT ', + '-A FORWARD -o virbr0 -j REJECT --reject-with ' + 'icmp-port-unreachable ', + '-A FORWARD -i virbr0 -j REJECT --reject-with ' + 'icmp-port-unreachable ', + 'COMMIT', + '# Completed on Fri Feb 18 15:17:05 2011'] + + sample_nat = ['# Generated by iptables-save on Fri Feb 18 15:17:05 2011', + '*nat', + ':PREROUTING ACCEPT [3936:762355]', + ':INPUT ACCEPT [2447:225266]', + ':OUTPUT ACCEPT [63491:4191863]', + ':POSTROUTING ACCEPT [63112:4108641]', + ':nova-compute-OUTPUT - [0:0]', + ':nova-compute-floating-ip-snat - [0:0]', + ':nova-compute-SNATTING - [0:0]', + ':nova-compute-PREROUTING - [0:0]', + ':nova-compute-POSTROUTING - [0:0]', + ':nova-postrouting-bottom - [0:0]', + '-A PREROUTING -j nova-compute-PREROUTING ', + '-A OUTPUT -j nova-compute-OUTPUT ', + '-A POSTROUTING -j nova-compute-POSTROUTING ', + '-A POSTROUTING -j nova-postrouting-bottom ', + '-A nova-postrouting-bottom -j nova-compute-SNATTING ', + '-A nova-compute-SNATTING -j nova-compute-floating-ip-snat ', + 'COMMIT', + '# Completed on Fri Feb 18 15:17:05 2011'] + + def setUp(self): + super(IptablesManagerTestCase, self).setUp() + self.manager = linux_net.IptablesManager() + + def test_filter_rules_are_wrapped(self): + current_lines = self.sample_filter + + table = self.manager.ipv4['filter'] + table.add_rule('FORWARD', '-s 1.2.3.4/5 -j DROP') + new_lines = self.manager._modify_rules(current_lines, table) + self.assertTrue('-A run_tests.py-FORWARD ' + '-s 1.2.3.4/5 -j DROP' in new_lines) + + table.remove_rule('FORWARD', '-s 1.2.3.4/5 -j DROP') + new_lines = self.manager._modify_rules(current_lines, table) + self.assertTrue('-A run_tests.py-FORWARD ' + '-s 1.2.3.4/5 -j DROP' not in new_lines) + + def test_nat_rules(self): + current_lines = self.sample_nat + new_lines = self.manager._modify_rules(current_lines, + self.manager.ipv4['nat']) + + for line in [':nova-compute-OUTPUT - [0:0]', + ':nova-compute-floating-ip-snat - [0:0]', + ':nova-compute-SNATTING - [0:0]', + ':nova-compute-PREROUTING - [0:0]', + ':nova-compute-POSTROUTING - [0:0]']: + self.assertTrue(line in new_lines, "One of nova-compute's chains " + "went missing.") + + seen_lines = set() + for line in new_lines: + line = line.strip() + self.assertTrue(line not in seen_lines, + "Duplicate line: %s" % line) + seen_lines.add(line) + + last_postrouting_line = '' + + for line in new_lines: + if line.startswith('-A POSTROUTING'): + last_postrouting_line = line + + self.assertTrue('-j nova-postrouting-bottom' in last_postrouting_line, + "Last POSTROUTING rule does not jump to " + "nova-postouting-bottom: %s" % last_postrouting_line) + + for chain in ['POSTROUTING', 'PREROUTING', 'OUTPUT']: + self.assertTrue('-A %s -j run_tests.py-%s' \ + % (chain, chain) in new_lines, + "Built-in chain %s not wrapped" % (chain,)) + + def test_filter_rules(self): + current_lines = self.sample_filter + new_lines = self.manager._modify_rules(current_lines, + self.manager.ipv4['filter']) + + for line in [':nova-compute-FORWARD - [0:0]', + ':nova-compute-INPUT - [0:0]', + ':nova-compute-local - [0:0]', + ':nova-compute-OUTPUT - [0:0]']: + self.assertTrue(line in new_lines, "One of nova-compute's chains" + " went missing.") + + seen_lines = set() + for line in new_lines: + line = line.strip() + self.assertTrue(line not in seen_lines, + "Duplicate line: %s" % line) + seen_lines.add(line) + + for chain in ['FORWARD', 'OUTPUT']: + for line in new_lines: + if line.startswith('-A %s' % chain): + self.assertTrue('-j nova-filter-top' in line, + "First %s rule does not " + "jump to nova-filter-top" % chain) + break + + self.assertTrue('-A nova-filter-top ' + '-j run_tests.py-local' in new_lines, + "nova-filter-top does not jump to wrapped local chain") + + for chain in ['INPUT', 'OUTPUT', 'FORWARD']: + self.assertTrue('-A %s -j run_tests.py-%s' \ + % (chain, chain) in new_lines, + "Built-in chain %s not wrapped" % (chain,)) + + class NetworkTestCase(test.TestCase): """Test cases for network code""" def setUp(self): @@ -42,15 +184,13 @@ class NetworkTestCase(test.TestCase): # flags in the corresponding section in nova-dhcpbridge self.flags(connection_type='fake', fake_call=True, - fake_network=True, - network_size=16, - num_networks=5) + fake_network=True) self.manager = manager.AuthManager() self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) self.context = context.RequestContext(project=None, user=self.user) - for i in range(5): + for i in range(FLAGS.num_networks): name = 'project%s' % i project = self.manager.create_project(name, 'netuser', name) self.projects.append(project) @@ -117,6 +257,9 @@ class NetworkTestCase(test.TestCase): utils.to_global_ipv6( network_ref['cidr_v6'], instance_ref['mac_address'])) + self._deallocate_address(0, address) + db.instance_destroy(context.get_admin_context(), + instance_ref['id']) def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" @@ -192,7 +335,7 @@ class NetworkTestCase(test.TestCase): first = self._create_address(0) lease_ip(first) instance_ids = [] - for i in range(1, 5): + for i in range(1, FLAGS.num_networks): instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) @@ -342,13 +485,13 @@ def lease_ip(private_ip): private_ip) instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) - cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'), - instance_ref['mac_address'], - private_ip) + cmd = (binpath('nova-dhcpbridge'), 'add', + instance_ref['mac_address'], + private_ip, 'fake') env = {'DNSMASQ_INTERFACE': network_ref['bridge'], 'TESTING': '1', 'FLAGFILE': FLAGS.dhcpbridge_flagfile} - (out, err) = utils.execute(cmd, addl_env=env) + (out, err) = utils.execute(*cmd, addl_env=env) LOG.debug("ISSUE_IP: %s, %s ", out, err) @@ -358,11 +501,11 @@ def release_ip(private_ip): private_ip) instance_ref = db.fixed_ip_get_instance(context.get_admin_context(), private_ip) - cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'), - instance_ref['mac_address'], - private_ip) + cmd = (binpath('nova-dhcpbridge'), 'del', + instance_ref['mac_address'], + private_ip, 'fake') env = {'DNSMASQ_INTERFACE': network_ref['bridge'], 'TESTING': '1', 'FLAGFILE': FLAGS.dhcpbridge_flagfile} - (out, err) = utils.execute(cmd, addl_env=env) + (out, err) = utils.execute(*cmd, addl_env=env) LOG.debug("RELEASE_IP: %s, %s ", out, err) diff --git a/nova/tests/test_quota.py b/nova/tests/test_quota.py index 9548a8c13..45b544753 100644 --- a/nova/tests/test_quota.py +++ b/nova/tests/test_quota.py @@ -16,14 +16,16 @@ # License for the specific language governing permissions and limitations # under the License. +from nova import compute from nova import context from nova import db from nova import flags +from nova import network from nova import quota from nova import test from nova import utils +from nova import volume from nova.auth import manager -from nova.api.ec2 import cloud from nova.compute import instance_types @@ -40,7 +42,6 @@ class QuotaTestCase(test.TestCase): quota_gigabytes=20, quota_floating_ips=1) - self.cloud = cloud.CloudController() self.manager = manager.AuthManager() self.user = self.manager.create_user('admin', 'admin', 'admin', True) self.project = self.manager.create_project('admin', 'admin', 'admin') @@ -56,7 +57,7 @@ class QuotaTestCase(test.TestCase): def _create_instance(self, cores=2): """Create a test instance""" inst = {} - inst['image_id'] = 'ami-test' + inst['image_id'] = 1 inst['reservation_id'] = 'r-fakeres' inst['user_id'] = self.user.id inst['project_id'] = self.project.id @@ -73,20 +74,43 @@ class QuotaTestCase(test.TestCase): vol['size'] = size return db.volume_create(self.context, vol)['id'] + def _get_instance_type(self, name): + instance_types = { + 'm1.tiny': dict(memory_mb=512, vcpus=1, local_gb=0, flavorid=1), + 'm1.small': dict(memory_mb=2048, vcpus=1, local_gb=20, flavorid=2), + 'm1.medium': + dict(memory_mb=4096, vcpus=2, local_gb=40, flavorid=3), + 'm1.large': dict(memory_mb=8192, vcpus=4, local_gb=80, flavorid=4), + 'm1.xlarge': + dict(memory_mb=16384, vcpus=8, local_gb=160, flavorid=5)} + return instance_types[name] + def test_quota_overrides(self): """Make sure overriding a projects quotas works""" num_instances = quota.allowed_instances(self.context, 100, - instance_types.INSTANCE_TYPES['m1.small']) + self._get_instance_type('m1.small')) self.assertEqual(num_instances, 2) db.quota_create(self.context, {'project_id': self.project.id, 'instances': 10}) num_instances = quota.allowed_instances(self.context, 100, - instance_types.INSTANCE_TYPES['m1.small']) + self._get_instance_type('m1.small')) self.assertEqual(num_instances, 4) db.quota_update(self.context, self.project.id, {'cores': 100}) num_instances = quota.allowed_instances(self.context, 100, - instance_types.INSTANCE_TYPES['m1.small']) + self._get_instance_type('m1.small')) self.assertEqual(num_instances, 10) + + # metadata_items + too_many_items = FLAGS.quota_metadata_items + 1000 + num_metadata_items = quota.allowed_metadata_items(self.context, + too_many_items) + self.assertEqual(num_metadata_items, FLAGS.quota_metadata_items) + db.quota_update(self.context, self.project.id, {'metadata_items': 5}) + num_metadata_items = quota.allowed_metadata_items(self.context, + too_many_items) + self.assertEqual(num_metadata_items, 5) + + # Cleanup db.quota_destroy(self.context, self.project.id) def test_too_many_instances(self): @@ -94,12 +118,12 @@ class QuotaTestCase(test.TestCase): for i in range(FLAGS.quota_instances): instance_id = self._create_instance() instance_ids.append(instance_id) - self.assertRaises(quota.QuotaError, self.cloud.run_instances, + self.assertRaises(quota.QuotaError, compute.API().create, self.context, min_count=1, max_count=1, instance_type='m1.small', - image_id='fake') + image_id=1) for instance_id in instance_ids: db.instance_destroy(self.context, instance_id) @@ -107,12 +131,12 @@ class QuotaTestCase(test.TestCase): instance_ids = [] instance_id = self._create_instance(cores=4) instance_ids.append(instance_id) - self.assertRaises(quota.QuotaError, self.cloud.run_instances, + self.assertRaises(quota.QuotaError, compute.API().create, self.context, min_count=1, max_count=1, instance_type='m1.small', - image_id='fake') + image_id=1) for instance_id in instance_ids: db.instance_destroy(self.context, instance_id) @@ -121,9 +145,12 @@ class QuotaTestCase(test.TestCase): for i in range(FLAGS.quota_volumes): volume_id = self._create_volume() volume_ids.append(volume_id) - self.assertRaises(quota.QuotaError, self.cloud.create_volume, - self.context, - size=10) + self.assertRaises(quota.QuotaError, + volume.API().create, + self.context, + size=10, + name='', + description='') for volume_id in volume_ids: db.volume_destroy(self.context, volume_id) @@ -132,9 +159,11 @@ class QuotaTestCase(test.TestCase): volume_id = self._create_volume(size=20) volume_ids.append(volume_id) self.assertRaises(quota.QuotaError, - self.cloud.create_volume, + volume.API().create, self.context, - size=10) + size=10, + name='', + description='') for volume_id in volume_ids: db.volume_destroy(self.context, volume_id) @@ -148,6 +177,19 @@ class QuotaTestCase(test.TestCase): # make an rpc.call, the test just finishes with OK. It # appears to be something in the magic inline callbacks # that is breaking. - self.assertRaises(quota.QuotaError, self.cloud.allocate_address, + self.assertRaises(quota.QuotaError, + network.API().allocate_floating_ip, self.context) db.floating_ip_destroy(context.get_admin_context(), address) + + def test_too_many_metadata_items(self): + metadata = {} + for i in range(FLAGS.quota_metadata_items + 1): + metadata['key%s' % i] = 'value%s' % i + self.assertRaises(quota.QuotaError, compute.API().create, + self.context, + min_count=1, + max_count=1, + instance_type='m1.small', + image_id='fake', + metadata=metadata) diff --git a/nova/tests/test_scheduler.py b/nova/tests/test_scheduler.py index 9d458244b..bb279ac4b 100644 --- a/nova/tests/test_scheduler.py +++ b/nova/tests/test_scheduler.py @@ -150,11 +150,12 @@ class SimpleDriverTestCase(test.TestCase): def tearDown(self): self.manager.delete_user(self.user) self.manager.delete_project(self.project) + super(SimpleDriverTestCase, self).tearDown() def _create_instance(self, **kwargs): """Create a test instance""" inst = {} - inst['image_id'] = 'ami-test' + inst['image_id'] = 1 inst['reservation_id'] = 'r-fakeres' inst['user_id'] = self.user.id inst['project_id'] = self.project.id @@ -168,26 +169,14 @@ class SimpleDriverTestCase(test.TestCase): def _create_volume(self): """Create a test volume""" vol = {} - vol['image_id'] = 'ami-test' - vol['reservation_id'] = 'r-fakeres' vol['size'] = 1 vol['availability_zone'] = 'test' return db.volume_create(self.context, vol)['id'] def test_doesnt_report_disabled_hosts_as_up(self): """Ensures driver doesn't find hosts before they are enabled""" - # NOTE(vish): constructing service without create method - # because we are going to use it without queue - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() - compute2 = service.Service('host2', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute2.start() + compute1 = self.start_service('compute', host='host1') + compute2 = self.start_service('compute', host='host2') s1 = db.service_get_by_args(self.context, 'host1', 'nova-compute') s2 = db.service_get_by_args(self.context, 'host2', 'nova-compute') db.service_update(self.context, s1['id'], {'disabled': True}) @@ -199,18 +188,8 @@ class SimpleDriverTestCase(test.TestCase): def test_reports_enabled_hosts_as_up(self): """Ensures driver can find the hosts that are up""" - # NOTE(vish): constructing service without create method - # because we are going to use it without queue - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() - compute2 = service.Service('host2', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute2.start() + compute1 = self.start_service('compute', host='host1') + compute2 = self.start_service('compute', host='host2') hosts = self.scheduler.driver.hosts_up(self.context, 'compute') self.assertEqual(2, len(hosts)) compute1.kill() @@ -218,16 +197,8 @@ class SimpleDriverTestCase(test.TestCase): def test_least_busy_host_gets_instance(self): """Ensures the host with less cores gets the next one""" - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() - compute2 = service.Service('host2', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute2.start() + compute1 = self.start_service('compute', host='host1') + compute2 = self.start_service('compute', host='host2') instance_id1 = self._create_instance() compute1.run_instance(self.context, instance_id1) instance_id2 = self._create_instance() @@ -241,16 +212,8 @@ class SimpleDriverTestCase(test.TestCase): def test_specific_host_gets_instance(self): """Ensures if you set availability_zone it launches on that zone""" - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() - compute2 = service.Service('host2', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute2.start() + compute1 = self.start_service('compute', host='host1') + compute2 = self.start_service('compute', host='host2') instance_id1 = self._create_instance() compute1.run_instance(self.context, instance_id1) instance_id2 = self._create_instance(availability_zone='nova:host1') @@ -263,11 +226,7 @@ class SimpleDriverTestCase(test.TestCase): compute2.kill() def test_wont_sechedule_if_specified_host_is_down(self): - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() + compute1 = self.start_service('compute', host='host1') s1 = db.service_get_by_args(self.context, 'host1', 'nova-compute') now = datetime.datetime.utcnow() delta = datetime.timedelta(seconds=FLAGS.service_down_time * 2) @@ -282,11 +241,7 @@ class SimpleDriverTestCase(test.TestCase): compute1.kill() def test_will_schedule_on_disabled_host_if_specified(self): - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() + compute1 = self.start_service('compute', host='host1') s1 = db.service_get_by_args(self.context, 'host1', 'nova-compute') db.service_update(self.context, s1['id'], {'disabled': True}) instance_id2 = self._create_instance(availability_zone='nova:host1') @@ -298,16 +253,8 @@ class SimpleDriverTestCase(test.TestCase): def test_too_many_cores(self): """Ensures we don't go over max cores""" - compute1 = service.Service('host1', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute1.start() - compute2 = service.Service('host2', - 'nova-compute', - 'compute', - FLAGS.compute_manager) - compute2.start() + compute1 = self.start_service('compute', host='host1') + compute2 = self.start_service('compute', host='host2') instance_ids1 = [] instance_ids2 = [] for index in xrange(FLAGS.max_cores): @@ -322,6 +269,7 @@ class SimpleDriverTestCase(test.TestCase): self.scheduler.driver.schedule_run_instance, self.context, instance_id) + db.instance_destroy(self.context, instance_id) for instance_id in instance_ids1: compute1.terminate_instance(self.context, instance_id) for instance_id in instance_ids2: @@ -331,16 +279,8 @@ class SimpleDriverTestCase(test.TestCase): def test_least_busy_host_gets_volume(self): """Ensures the host with less gigabytes gets the next one""" - volume1 = service.Service('host1', - 'nova-volume', - 'volume', - FLAGS.volume_manager) - volume1.start() - volume2 = service.Service('host2', - 'nova-volume', - 'volume', - FLAGS.volume_manager) - volume2.start() + volume1 = self.start_service('volume', host='host1') + volume2 = self.start_service('volume', host='host2') volume_id1 = self._create_volume() volume1.create_volume(self.context, volume_id1) volume_id2 = self._create_volume() @@ -354,16 +294,8 @@ class SimpleDriverTestCase(test.TestCase): def test_too_many_gigabytes(self): """Ensures we don't go over max gigabytes""" - volume1 = service.Service('host1', - 'nova-volume', - 'volume', - FLAGS.volume_manager) - volume1.start() - volume2 = service.Service('host2', - 'nova-volume', - 'volume', - FLAGS.volume_manager) - volume2.start() + volume1 = self.start_service('volume', host='host1') + volume2 = self.start_service('volume', host='host2') volume_ids1 = [] volume_ids2 = [] for index in xrange(FLAGS.max_gigabytes): diff --git a/nova/tests/test_service.py b/nova/tests/test_service.py index a67c8d1e8..45d9afa6c 100644 --- a/nova/tests/test_service.py +++ b/nova/tests/test_service.py @@ -50,13 +50,6 @@ class ExtendedService(service.Service): class ServiceManagerTestCase(test.TestCase): """Test cases for Services""" - def test_attribute_error_for_no_manager(self): - serv = service.Service('test', - 'test', - 'test', - 'nova.tests.test_service.FakeManager') - self.assertRaises(AttributeError, getattr, serv, 'test_method') - def test_message_gets_to_manager(self): serv = service.Service('test', 'test', diff --git a/nova/tests/test_test.py b/nova/tests/test_test.py new file mode 100644 index 000000000..e237674e6 --- /dev/null +++ b/nova/tests/test_test.py @@ -0,0 +1,40 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Tests for the testing base code.""" + +from nova import rpc +from nova import test + + +class IsolationTestCase(test.TestCase): + """Ensure that things are cleaned up after failed tests. + + These tests don't really do much here, but if isolation fails a bunch + of other tests should fail. + + """ + def test_service_isolation(self): + self.start_service('compute') + + def test_rpc_consumer_isolation(self): + connection = rpc.Connection.instance(new=True) + consumer = rpc.TopicConsumer(connection, topic='compute') + consumer.register_callback( + lambda x, y: self.fail('I should never be called')) + consumer.attach_to_eventlet() diff --git a/nova/tests/test_utils.py b/nova/tests/test_utils.py new file mode 100644 index 000000000..34a407f1a --- /dev/null +++ b/nova/tests/test_utils.py @@ -0,0 +1,174 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Justin Santa Barbara +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nova import test +from nova import utils +from nova import exception + + +class GetFromPathTestCase(test.TestCase): + def test_tolerates_nones(self): + f = utils.get_from_path + + input = [] + self.assertEquals([], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [None] + self.assertEquals([], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': None}] + self.assertEquals([], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': None}}] + self.assertEquals([{'b': None}], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': None}}}] + self.assertEquals([{'b': {'c': None}}], f(input, "a")) + self.assertEquals([{'c': None}], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': None}}}, {'a': None}] + self.assertEquals([{'b': {'c': None}}], f(input, "a")) + self.assertEquals([{'c': None}], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': None}}}, {'a': {'b': None}}] + self.assertEquals([{'b': {'c': None}}, {'b': None}], f(input, "a")) + self.assertEquals([{'c': None}], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + def test_does_select(self): + f = utils.get_from_path + + input = [{'a': 'a_1'}] + self.assertEquals(['a_1'], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': 'b_1'}}] + self.assertEquals([{'b': 'b_1'}], f(input, "a")) + self.assertEquals(['b_1'], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': 'c_1'}}}] + self.assertEquals([{'b': {'c': 'c_1'}}], f(input, "a")) + self.assertEquals([{'c': 'c_1'}], f(input, "a/b")) + self.assertEquals(['c_1'], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': 'c_1'}}}, {'a': None}] + self.assertEquals([{'b': {'c': 'c_1'}}], f(input, "a")) + self.assertEquals([{'c': 'c_1'}], f(input, "a/b")) + self.assertEquals(['c_1'], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': 'c_1'}}}, + {'a': {'b': None}}] + self.assertEquals([{'b': {'c': 'c_1'}}, {'b': None}], f(input, "a")) + self.assertEquals([{'c': 'c_1'}], f(input, "a/b")) + self.assertEquals(['c_1'], f(input, "a/b/c")) + + input = [{'a': {'b': {'c': 'c_1'}}}, + {'a': {'b': {'c': 'c_2'}}}] + self.assertEquals([{'b': {'c': 'c_1'}}, {'b': {'c': 'c_2'}}], + f(input, "a")) + self.assertEquals([{'c': 'c_1'}, {'c': 'c_2'}], f(input, "a/b")) + self.assertEquals(['c_1', 'c_2'], f(input, "a/b/c")) + + self.assertEquals([], f(input, "a/b/c/d")) + self.assertEquals([], f(input, "c/a/b/d")) + self.assertEquals([], f(input, "i/r/t")) + + def test_flattens_lists(self): + f = utils.get_from_path + + input = [{'a': [1, 2, 3]}] + self.assertEquals([1, 2, 3], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': [1, 2, 3]}}] + self.assertEquals([{'b': [1, 2, 3]}], f(input, "a")) + self.assertEquals([1, 2, 3], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': {'b': [1, 2, 3]}}, {'a': {'b': [4, 5, 6]}}] + self.assertEquals([1, 2, 3, 4, 5, 6], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': [{'b': [1, 2, 3]}, {'b': [4, 5, 6]}]}] + self.assertEquals([1, 2, 3, 4, 5, 6], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = [{'a': [1, 2, {'b': 'b_1'}]}] + self.assertEquals([1, 2, {'b': 'b_1'}], f(input, "a")) + self.assertEquals(['b_1'], f(input, "a/b")) + + def test_bad_xpath(self): + f = utils.get_from_path + + self.assertRaises(exception.Error, f, [], None) + self.assertRaises(exception.Error, f, [], "") + self.assertRaises(exception.Error, f, [], "/") + self.assertRaises(exception.Error, f, [], "/a") + self.assertRaises(exception.Error, f, [], "/a/") + self.assertRaises(exception.Error, f, [], "//") + self.assertRaises(exception.Error, f, [], "//a") + self.assertRaises(exception.Error, f, [], "a//a") + self.assertRaises(exception.Error, f, [], "a//a/") + self.assertRaises(exception.Error, f, [], "a/a/") + + def test_real_failure1(self): + # Real world failure case... + # We weren't coping when the input was a Dictionary instead of a List + # This led to test_accepts_dictionaries + f = utils.get_from_path + + inst = {'fixed_ip': {'floating_ips': [{'address': '1.2.3.4'}], + 'address': '192.168.0.3'}, + 'hostname': ''} + + private_ips = f(inst, 'fixed_ip/address') + public_ips = f(inst, 'fixed_ip/floating_ips/address') + self.assertEquals(['192.168.0.3'], private_ips) + self.assertEquals(['1.2.3.4'], public_ips) + + def test_accepts_dictionaries(self): + f = utils.get_from_path + + input = {'a': [1, 2, 3]} + self.assertEquals([1, 2, 3], f(input, "a")) + self.assertEquals([], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = {'a': {'b': [1, 2, 3]}} + self.assertEquals([{'b': [1, 2, 3]}], f(input, "a")) + self.assertEquals([1, 2, 3], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = {'a': [{'b': [1, 2, 3]}, {'b': [4, 5, 6]}]} + self.assertEquals([1, 2, 3, 4, 5, 6], f(input, "a/b")) + self.assertEquals([], f(input, "a/b/c")) + + input = {'a': [1, 2, {'b': 'b_1'}]} + self.assertEquals([1, 2, {'b': 'b_1'}], f(input, "a")) + self.assertEquals(['b_1'], f(input, "a/b")) diff --git a/nova/tests/test_virt.py b/nova/tests/test_virt.py index 6e5a0114b..648de3b77 100644 --- a/nova/tests/test_virt.py +++ b/nova/tests/test_virt.py @@ -14,6 +14,10 @@ # License for the specific language governing permissions and limitations # under the License. +import re +import os + +import eventlet from xml.etree.ElementTree import fromstring as xml_to_tree from xml.dom.minidom import parseString as xml_to_dom @@ -30,6 +34,70 @@ FLAGS = flags.FLAGS flags.DECLARE('instances_path', 'nova.compute.manager') +def _concurrency(wait, done, target): + wait.wait() + done.send() + + +class CacheConcurrencyTestCase(test.TestCase): + def setUp(self): + super(CacheConcurrencyTestCase, self).setUp() + + def fake_exists(fname): + basedir = os.path.join(FLAGS.instances_path, '_base') + if fname == basedir: + return True + return False + + def fake_execute(*args, **kwargs): + pass + + self.stubs.Set(os.path, 'exists', fake_exists) + self.stubs.Set(utils, 'execute', fake_execute) + + def test_same_fname_concurrency(self): + """Ensures that the same fname cache runs at a sequentially""" + conn = libvirt_conn.LibvirtConnection + wait1 = eventlet.event.Event() + done1 = eventlet.event.Event() + eventlet.spawn(conn._cache_image, _concurrency, + 'target', 'fname', False, wait1, done1) + wait2 = eventlet.event.Event() + done2 = eventlet.event.Event() + eventlet.spawn(conn._cache_image, _concurrency, + 'target', 'fname', False, wait2, done2) + wait2.send() + eventlet.sleep(0) + try: + self.assertFalse(done2.ready()) + self.assertTrue('fname' in conn._image_sems) + finally: + wait1.send() + done1.wait() + eventlet.sleep(0) + self.assertTrue(done2.ready()) + self.assertFalse('fname' in conn._image_sems) + + def test_different_fname_concurrency(self): + """Ensures that two different fname caches are concurrent""" + conn = libvirt_conn.LibvirtConnection + wait1 = eventlet.event.Event() + done1 = eventlet.event.Event() + eventlet.spawn(conn._cache_image, _concurrency, + 'target', 'fname2', False, wait1, done1) + wait2 = eventlet.event.Event() + done2 = eventlet.event.Event() + eventlet.spawn(conn._cache_image, _concurrency, + 'target', 'fname1', False, wait2, done2) + wait2.send() + eventlet.sleep(0) + try: + self.assertTrue(done2.ready()) + finally: + wait1.send() + eventlet.sleep(0) + + class LibvirtConnTestCase(test.TestCase): def setUp(self): super(LibvirtConnTestCase, self).setUp() @@ -204,11 +272,12 @@ class LibvirtConnTestCase(test.TestCase): conn = libvirt_conn.LibvirtConnection(True) uri = conn.get_uri() self.assertEquals(uri, testuri) + db.instance_destroy(user_context, instance_ref['id']) def tearDown(self): - super(LibvirtConnTestCase, self).tearDown() self.manager.delete_project(self.project) self.manager.delete_user(self.user) + super(LibvirtConnTestCase, self).tearDown() class IptablesFirewallTestCase(test.TestCase): @@ -233,16 +302,22 @@ class IptablesFirewallTestCase(test.TestCase): self.manager.delete_user(self.user) super(IptablesFirewallTestCase, self).tearDown() - in_rules = [ + in_nat_rules = [ + '# Generated by iptables-save v1.4.10 on Sat Feb 19 00:03:19 2011', + '*nat', + ':PREROUTING ACCEPT [1170:189210]', + ':INPUT ACCEPT [844:71028]', + ':OUTPUT ACCEPT [5149:405186]', + ':POSTROUTING ACCEPT [5063:386098]' + ] + + in_filter_rules = [ '# Generated by iptables-save v1.4.4 on Mon Dec 6 11:54:13 2010', '*filter', ':INPUT ACCEPT [969615:281627771]', ':FORWARD ACCEPT [0:0]', ':OUTPUT ACCEPT [915599:63811649]', ':nova-block-ipv4 - [0:0]', - '-A INPUT -i virbr0 -p udp -m udp --dport 53 -j ACCEPT ', - '-A INPUT -i virbr0 -p tcp -m tcp --dport 53 -j ACCEPT ', - '-A INPUT -i virbr0 -p udp -m udp --dport 67 -j ACCEPT ', '-A INPUT -i virbr0 -p tcp -m tcp --dport 67 -j ACCEPT ', '-A FORWARD -d 192.168.122.0/24 -o virbr0 -m state --state RELATED' ',ESTABLISHED -j ACCEPT ', @@ -254,7 +329,7 @@ class IptablesFirewallTestCase(test.TestCase): '# Completed on Mon Dec 6 11:54:13 2010', ] - in6_rules = [ + in6_filter_rules = [ '# Generated by ip6tables-save v1.4.4 on Tue Jan 18 23:47:56 2011', '*filter', ':INPUT ACCEPT [349155:75810423]', @@ -314,23 +389,34 @@ class IptablesFirewallTestCase(test.TestCase): instance_ref = db.instance_get(admin_ctxt, instance_ref['id']) # self.fw.add_instance(instance_ref) - def fake_iptables_execute(cmd, process_input=None): - if cmd == 'sudo ip6tables-save -t filter': - return '\n'.join(self.in6_rules), None - if cmd == 'sudo iptables-save -t filter': - return '\n'.join(self.in_rules), None - if cmd == 'sudo iptables-restore': - self.out_rules = process_input.split('\n') + def fake_iptables_execute(*cmd, **kwargs): + process_input = kwargs.get('process_input', None) + if cmd == ('sudo', 'ip6tables-save', '-t', 'filter'): + return '\n'.join(self.in6_filter_rules), None + if cmd == ('sudo', 'iptables-save', '-t', 'filter'): + return '\n'.join(self.in_filter_rules), None + if cmd == ('sudo', 'iptables-save', '-t', 'nat'): + return '\n'.join(self.in_nat_rules), None + if cmd == ('sudo', 'iptables-restore'): + lines = process_input.split('\n') + if '*filter' in lines: + self.out_rules = lines return '', '' - if cmd == 'sudo ip6tables-restore': - self.out6_rules = process_input.split('\n') + if cmd == ('sudo', 'ip6tables-restore'): + lines = process_input.split('\n') + if '*filter' in lines: + self.out6_rules = lines return '', '' - self.fw.execute = fake_iptables_execute + print cmd, kwargs + + from nova.network import linux_net + linux_net.iptables_manager.execute = fake_iptables_execute self.fw.prepare_instance_filter(instance_ref) self.fw.apply_instance_filter(instance_ref) - in_rules = filter(lambda l: not l.startswith('#'), self.in_rules) + in_rules = filter(lambda l: not l.startswith('#'), + self.in_filter_rules) for rule in in_rules: if not 'nova' in rule: self.assertTrue(rule in self.out_rules, @@ -353,18 +439,20 @@ class IptablesFirewallTestCase(test.TestCase): self.assertTrue(security_group_chain, "The security group chain wasn't added") - self.assertTrue('-A %s -p icmp -s 192.168.11.0/24 -j ACCEPT' % \ - security_group_chain in self.out_rules, + regex = re.compile('-A .* -p icmp -s 192.168.11.0/24 -j ACCEPT') + self.assertTrue(len(filter(regex.match, self.out_rules)) > 0, "ICMP acceptance rule wasn't added") - self.assertTrue('-A %s -p icmp -s 192.168.11.0/24 -m icmp --icmp-type ' - '8 -j ACCEPT' % security_group_chain in self.out_rules, + regex = re.compile('-A .* -p icmp -s 192.168.11.0/24 -m icmp ' + '--icmp-type 8 -j ACCEPT') + self.assertTrue(len(filter(regex.match, self.out_rules)) > 0, "ICMP Echo Request acceptance rule wasn't added") - self.assertTrue('-A %s -p tcp -s 192.168.10.0/24 -m multiport ' - '--dports 80:81 -j ACCEPT' % security_group_chain \ - in self.out_rules, + regex = re.compile('-A .* -p tcp -s 192.168.10.0/24 -m multiport ' + '--dports 80:81 -j ACCEPT') + self.assertTrue(len(filter(regex.match, self.out_rules)) > 0, "TCP port 80/81 acceptance rule wasn't added") + db.instance_destroy(admin_ctxt, instance_ref['id']) class NWFilterTestCase(test.TestCase): @@ -388,6 +476,7 @@ class NWFilterTestCase(test.TestCase): def tearDown(self): self.manager.delete_project(self.project) self.manager.delete_user(self.user) + super(NWFilterTestCase, self).tearDown() def test_cidr_rule_nwfilter_xml(self): cloud_controller = cloud.CloudController() @@ -514,3 +603,4 @@ class NWFilterTestCase(test.TestCase): self.fw.apply_instance_filter(instance) _ensure_all_called() self.teardown_security_group() + db.instance_destroy(admin_ctxt, instance_ref['id']) diff --git a/nova/tests/test_volume.py b/nova/tests/test_volume.py index b40ca004b..f698c85b5 100644 --- a/nova/tests/test_volume.py +++ b/nova/tests/test_volume.py @@ -99,7 +99,7 @@ class VolumeTestCase(test.TestCase): def test_run_attach_detach_volume(self): """Make sure volume can be attached and detached from instance.""" inst = {} - inst['image_id'] = 'ami-test' + inst['image_id'] = 1 inst['reservation_id'] = 'r-fakeres' inst['launch_time'] = '10' inst['user_id'] = 'fake' diff --git a/nova/tests/test_xenapi.py b/nova/tests/test_xenapi.py index 9f5b266f3..c26dc8639 100644 --- a/nova/tests/test_xenapi.py +++ b/nova/tests/test_xenapi.py @@ -31,7 +31,9 @@ from nova.compute import power_state from nova.virt import xenapi_conn from nova.virt.xenapi import fake as xenapi_fake from nova.virt.xenapi import volume_utils +from nova.virt.xenapi import vm_utils from nova.virt.xenapi.vmops import SimpleDH +from nova.virt.xenapi.vmops import VMOps from nova.tests.db import fakes as db_fakes from nova.tests.xenapi import stubs from nova.tests.glance import stubs as glance_stubs @@ -141,6 +143,10 @@ class XenAPIVolumeTestCase(test.TestCase): self.stubs.UnsetAll() +def reset_network(*args): + pass + + class XenAPIVMTestCase(test.TestCase): """ Unit tests for VM operations @@ -162,6 +168,8 @@ class XenAPIVMTestCase(test.TestCase): stubs.stubout_session(self.stubs, stubs.FakeSessionForVMTests) stubs.stubout_get_this_vm_uuid(self.stubs) stubs.stubout_stream_disk(self.stubs) + stubs.stubout_is_vdi_pv(self.stubs) + self.stubs.Set(VMOps, 'reset_network', reset_network) glance_stubs.stubout_glance_client(self.stubs, glance_stubs.FakeGlance) self.conn = xenapi_conn.get_connection(False) @@ -225,7 +233,7 @@ class XenAPIVMTestCase(test.TestCase): vm = vms[0] # Check that m1.large above turned into the right thing. - instance_type = instance_types.INSTANCE_TYPES['m1.large'] + instance_type = db.instance_type_get_by_name(conn, 'm1.large') mem_kib = long(instance_type['memory_mb']) << 10 mem_bytes = str(mem_kib << 10) vcpus = instance_type['vcpus'] @@ -243,7 +251,8 @@ class XenAPIVMTestCase(test.TestCase): # Check that the VM is running according to XenAPI. self.assertEquals(vm['power_state'], 'Running') - def _test_spawn(self, image_id, kernel_id, ramdisk_id): + def _test_spawn(self, image_id, kernel_id, ramdisk_id, + instance_type="m1.large"): stubs.stubout_session(self.stubs, stubs.FakeSessionForVMTests) values = {'name': 1, 'id': 1, @@ -252,7 +261,7 @@ class XenAPIVMTestCase(test.TestCase): 'image_id': image_id, 'kernel_id': kernel_id, 'ramdisk_id': ramdisk_id, - 'instance_type': 'm1.large', + 'instance_type': instance_type, 'mac_address': 'aa:bb:cc:dd:ee:ff', } conn = xenapi_conn.get_connection(False) @@ -260,6 +269,12 @@ class XenAPIVMTestCase(test.TestCase): conn.spawn(instance) self.check_vm_record(conn) + def test_spawn_not_enough_memory(self): + FLAGS.xenapi_image_service = 'glance' + self.assertRaises(Exception, + self._test_spawn, + 1, 2, 3, "m1.xlarge") + def test_spawn_raw_objectstore(self): FLAGS.xenapi_image_service = 'objectstore' self._test_spawn(1, None, None) @@ -270,11 +285,17 @@ class XenAPIVMTestCase(test.TestCase): def test_spawn_raw_glance(self): FLAGS.xenapi_image_service = 'glance' - self._test_spawn(1, None, None) + self._test_spawn(glance_stubs.FakeGlance.IMAGE_RAW, None, None) + + def test_spawn_vhd_glance(self): + FLAGS.xenapi_image_service = 'glance' + self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None) def test_spawn_glance(self): FLAGS.xenapi_image_service = 'glance' - self._test_spawn(1, 2, 3) + self._test_spawn(glance_stubs.FakeGlance.IMAGE_MACHINE, + glance_stubs.FakeGlance.IMAGE_KERNEL, + glance_stubs.FakeGlance.IMAGE_RAMDISK) def tearDown(self): super(XenAPIVMTestCase, self).tearDown() @@ -323,3 +344,113 @@ class XenAPIDiffieHellmanTestCase(test.TestCase): def tearDown(self): super(XenAPIDiffieHellmanTestCase, self).tearDown() + + +class XenAPIMigrateInstance(test.TestCase): + """ + Unit test for verifying migration-related actions + """ + + def setUp(self): + super(XenAPIMigrateInstance, self).setUp() + self.stubs = stubout.StubOutForTesting() + FLAGS.target_host = '127.0.0.1' + FLAGS.xenapi_connection_url = 'test_url' + FLAGS.xenapi_connection_password = 'test_pass' + db_fakes.stub_out_db_instance_api(self.stubs) + stubs.stub_out_get_target(self.stubs) + xenapi_fake.reset() + self.manager = manager.AuthManager() + self.user = self.manager.create_user('fake', 'fake', 'fake', + admin=True) + self.project = self.manager.create_project('fake', 'fake', 'fake') + self.values = {'name': 1, 'id': 1, + 'project_id': self.project.id, + 'user_id': self.user.id, + 'image_id': 1, + 'kernel_id': None, + 'ramdisk_id': None, + 'instance_type': 'm1.large', + 'mac_address': 'aa:bb:cc:dd:ee:ff', + } + stubs.stub_out_migration_methods(self.stubs) + glance_stubs.stubout_glance_client(self.stubs, + glance_stubs.FakeGlance) + + def tearDown(self): + super(XenAPIMigrateInstance, self).tearDown() + self.manager.delete_project(self.project) + self.manager.delete_user(self.user) + self.stubs.UnsetAll() + + def test_migrate_disk_and_power_off(self): + instance = db.instance_create(self.values) + stubs.stubout_session(self.stubs, stubs.FakeSessionForMigrationTests) + conn = xenapi_conn.get_connection(False) + conn.migrate_disk_and_power_off(instance, '127.0.0.1') + + def test_finish_resize(self): + instance = db.instance_create(self.values) + stubs.stubout_session(self.stubs, stubs.FakeSessionForMigrationTests) + conn = xenapi_conn.get_connection(False) + conn.finish_resize(instance, dict(base_copy='hurr', cow='durr')) + + +class XenAPIDetermineDiskImageTestCase(test.TestCase): + """ + Unit tests for code that detects the ImageType + """ + def setUp(self): + super(XenAPIDetermineDiskImageTestCase, self).setUp() + glance_stubs.stubout_glance_client(self.stubs, + glance_stubs.FakeGlance) + + class FakeInstance(object): + pass + + self.fake_instance = FakeInstance() + self.fake_instance.id = 42 + + def assert_disk_type(self, disk_type): + dt = vm_utils.VMHelper.determine_disk_image_type( + self.fake_instance) + self.assertEqual(disk_type, dt) + + def test_instance_disk(self): + """ + If a kernel is specified then the image type is DISK (aka machine) + """ + FLAGS.xenapi_image_service = 'objectstore' + self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_MACHINE + self.fake_instance.kernel_id = glance_stubs.FakeGlance.IMAGE_KERNEL + self.assert_disk_type(vm_utils.ImageType.DISK) + + def test_instance_disk_raw(self): + """ + If the kernel isn't specified, and we're not using Glance, then + DISK_RAW is assumed. + """ + FLAGS.xenapi_image_service = 'objectstore' + self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_RAW + self.fake_instance.kernel_id = None + self.assert_disk_type(vm_utils.ImageType.DISK_RAW) + + def test_glance_disk_raw(self): + """ + If we're using Glance, then defer to the image_type field, which in + this case will be 'raw'. + """ + FLAGS.xenapi_image_service = 'glance' + self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_RAW + self.fake_instance.kernel_id = None + self.assert_disk_type(vm_utils.ImageType.DISK_RAW) + + def test_glance_disk_vhd(self): + """ + If we're using Glance, then defer to the image_type field, which in + this case will be 'vhd'. + """ + FLAGS.xenapi_image_service = 'glance' + self.fake_instance.image_id = glance_stubs.FakeGlance.IMAGE_VHD + self.fake_instance.kernel_id = None + self.assert_disk_type(vm_utils.ImageType.DISK_VHD) diff --git a/nova/tests/test_zones.py b/nova/tests/test_zones.py new file mode 100644 index 000000000..5a52a0506 --- /dev/null +++ b/nova/tests/test_zones.py @@ -0,0 +1,172 @@ +# Copyright 2010 United States Government as represented by the +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Tests For ZoneManager +""" + +import datetime +import mox +import novaclient + +from nova import context +from nova import db +from nova import flags +from nova import service +from nova import test +from nova import rpc +from nova import utils +from nova.auth import manager as auth_manager +from nova.scheduler import zone_manager + +FLAGS = flags.FLAGS + + +class FakeZone: + """Represents a fake zone from the db""" + def __init__(self, *args, **kwargs): + for k, v in kwargs.iteritems(): + setattr(self, k, v) + + +def exploding_novaclient(zone): + """Used when we want to simulate a novaclient call failing.""" + raise Exception("kaboom") + + +class ZoneManagerTestCase(test.TestCase): + """Test case for zone manager""" + def test_ping(self): + zm = zone_manager.ZoneManager() + self.mox.StubOutWithMock(zm, '_refresh_from_db') + self.mox.StubOutWithMock(zm, '_poll_zones') + zm._refresh_from_db(mox.IgnoreArg()) + zm._poll_zones(mox.IgnoreArg()) + + self.mox.ReplayAll() + zm.ping(None) + self.mox.VerifyAll() + + def test_refresh_from_db_new(self): + zm = zone_manager.ZoneManager() + + self.mox.StubOutWithMock(db, 'zone_get_all') + db.zone_get_all(mox.IgnoreArg()).AndReturn([ + FakeZone(id=1, api_url='http://foo.com', username='user1', + password='pass1'), + ]) + + self.assertEquals(len(zm.zone_states), 0) + + self.mox.ReplayAll() + zm._refresh_from_db(None) + self.mox.VerifyAll() + + self.assertEquals(len(zm.zone_states), 1) + self.assertEquals(zm.zone_states[1].username, 'user1') + + def test_refresh_from_db_replace_existing(self): + zm = zone_manager.ZoneManager() + zone_state = zone_manager.ZoneState() + zone_state.update_credentials(FakeZone(id=1, api_url='http://foo.com', + username='user1', password='pass1')) + zm.zone_states[1] = zone_state + + self.mox.StubOutWithMock(db, 'zone_get_all') + db.zone_get_all(mox.IgnoreArg()).AndReturn([ + FakeZone(id=1, api_url='http://foo.com', username='user2', + password='pass2'), + ]) + + self.assertEquals(len(zm.zone_states), 1) + + self.mox.ReplayAll() + zm._refresh_from_db(None) + self.mox.VerifyAll() + + self.assertEquals(len(zm.zone_states), 1) + self.assertEquals(zm.zone_states[1].username, 'user2') + + def test_refresh_from_db_missing(self): + zm = zone_manager.ZoneManager() + zone_state = zone_manager.ZoneState() + zone_state.update_credentials(FakeZone(id=1, api_url='http://foo.com', + username='user1', password='pass1')) + zm.zone_states[1] = zone_state + + self.mox.StubOutWithMock(db, 'zone_get_all') + db.zone_get_all(mox.IgnoreArg()).AndReturn([]) + + self.assertEquals(len(zm.zone_states), 1) + + self.mox.ReplayAll() + zm._refresh_from_db(None) + self.mox.VerifyAll() + + self.assertEquals(len(zm.zone_states), 0) + + def test_refresh_from_db_add_and_delete(self): + zm = zone_manager.ZoneManager() + zone_state = zone_manager.ZoneState() + zone_state.update_credentials(FakeZone(id=1, api_url='http://foo.com', + username='user1', password='pass1')) + zm.zone_states[1] = zone_state + + self.mox.StubOutWithMock(db, 'zone_get_all') + + db.zone_get_all(mox.IgnoreArg()).AndReturn([ + FakeZone(id=2, api_url='http://foo.com', username='user2', + password='pass2'), + ]) + self.assertEquals(len(zm.zone_states), 1) + + self.mox.ReplayAll() + zm._refresh_from_db(None) + self.mox.VerifyAll() + + self.assertEquals(len(zm.zone_states), 1) + self.assertEquals(zm.zone_states[2].username, 'user2') + + def test_poll_zone(self): + self.mox.StubOutWithMock(zone_manager, '_call_novaclient') + zone_manager._call_novaclient(mox.IgnoreArg()).AndReturn( + dict(name='zohan', capabilities='hairdresser')) + + zone_state = zone_manager.ZoneState() + zone_state.update_credentials(FakeZone(id=2, + api_url='http://foo.com', username='user2', + password='pass2')) + zone_state.attempt = 1 + + self.mox.ReplayAll() + zone_manager._poll_zone(zone_state) + self.mox.VerifyAll() + self.assertEquals(zone_state.attempt, 0) + self.assertEquals(zone_state.name, 'zohan') + + def test_poll_zone_fails(self): + self.stubs.Set(zone_manager, "_call_novaclient", exploding_novaclient) + + zone_state = zone_manager.ZoneState() + zone_state.update_credentials(FakeZone(id=2, + api_url='http://foo.com', username='user2', + password='pass2')) + zone_state.attempt = FLAGS.zone_failures_to_offline - 1 + + self.mox.ReplayAll() + zone_manager._poll_zone(zone_state) + self.mox.VerifyAll() + self.assertEquals(zone_state.attempt, 3) + self.assertFalse(zone_state.is_active) + self.assertEquals(zone_state.name, None) diff --git a/nova/tests/xenapi/stubs.py b/nova/tests/xenapi/stubs.py index 624995ada..70d46a1fb 100644 --- a/nova/tests/xenapi/stubs.py +++ b/nova/tests/xenapi/stubs.py @@ -20,6 +20,7 @@ from nova.virt import xenapi_conn from nova.virt.xenapi import fake from nova.virt.xenapi import volume_utils from nova.virt.xenapi import vm_utils +from nova.virt.xenapi import vmops def stubout_instance_snapshot(stubs): @@ -27,7 +28,7 @@ def stubout_instance_snapshot(stubs): def fake_fetch_image(cls, session, instance_id, image, user, project, type): # Stubout wait_for_task - def fake_wait_for_task(self, id, task): + def fake_wait_for_task(self, task, id): class FakeEvent: def send(self, value): @@ -130,6 +131,12 @@ def stubout_stream_disk(stubs): stubs.Set(vm_utils, '_stream_disk', f) +def stubout_is_vdi_pv(stubs): + def f(_1): + return False + stubs.Set(vm_utils, '_is_vdi_pv', f) + + class FakeSessionForVMTests(fake.SessionBase): """ Stubs out a XenAPISession for VM tests """ def __init__(self, uri): @@ -171,6 +178,12 @@ class FakeSessionForVMTests(fake.SessionBase): def VM_destroy(self, session_ref, vm_ref): fake.destroy_vm(vm_ref) + def SR_scan(self, session_ref, sr_ref): + pass + + def VDI_set_name_label(self, session_ref, vdi_ref, name_label): + pass + class FakeSessionForVolumeTests(fake.SessionBase): """ Stubs out a XenAPISession for Volume tests """ @@ -205,3 +218,60 @@ class FakeSessionForVolumeFailedTests(FakeSessionForVolumeTests): def SR_forget(self, _1, ref): pass + + +class FakeSessionForMigrationTests(fake.SessionBase): + """Stubs out a XenAPISession for Migration tests""" + def __init__(self, uri): + super(FakeSessionForMigrationTests, self).__init__(uri) + + def VDI_get_by_uuid(*args): + return 'hurr' + + def VM_start(self, _1, ref, _2, _3): + vm = fake.get_record('VM', ref) + if vm['power_state'] != 'Halted': + raise fake.Failure(['VM_BAD_POWER_STATE', ref, 'Halted', + vm['power_state']]) + vm['power_state'] = 'Running' + vm['is_a_template'] = False + vm['is_control_domain'] = False + + +def stub_out_migration_methods(stubs): + def fake_get_snapshot(self, instance): + return 'foo', 'bar' + + @classmethod + def fake_get_vdi(cls, session, vm_ref): + vdi_ref = fake.create_vdi(name_label='derp', read_only=False, + sr_ref='herp', sharable=False) + vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) + return vdi_ref, {'uuid': vdi_rec['uuid'], } + + def fake_shutdown(self, inst, vm, method='clean'): + pass + + @classmethod + def fake_sr(cls, session, *args): + pass + + @classmethod + def fake_get_sr_path(cls, *args): + return "fake" + + def fake_destroy(*args, **kwargs): + pass + + def fake_reset_network(*args, **kwargs): + pass + + stubs.Set(vmops.VMOps, '_destroy', fake_destroy) + stubs.Set(vm_utils.VMHelper, 'scan_default_sr', fake_sr) + stubs.Set(vm_utils.VMHelper, 'scan_sr', fake_sr) + stubs.Set(vmops.VMOps, '_get_snapshot', fake_get_snapshot) + stubs.Set(vm_utils.VMHelper, 'get_vdi_for_vm_safely', fake_get_vdi) + stubs.Set(xenapi_conn.XenAPISession, 'wait_for_task', lambda x, y, z: None) + stubs.Set(vm_utils.VMHelper, 'get_sr_path', fake_get_sr_path) + stubs.Set(vmops.VMOps, 'reset_network', fake_reset_network) + stubs.Set(vmops.VMOps, '_shutdown', fake_shutdown) diff --git a/nova/twistd.py b/nova/twistd.py index 6390a8144..c07ed991f 100644 --- a/nova/twistd.py +++ b/nova/twistd.py @@ -43,8 +43,6 @@ else: FLAGS = flags.FLAGS -flags.DEFINE_string('logdir', None, 'directory to keep log files in ' - '(will be prepended to $logfile)') class TwistdServerOptions(ServerOptions): @@ -150,6 +148,7 @@ def WrapTwistedOptions(wrapped): options.insert(0, '') args = FLAGS(options) + logging.setup() argv = args[1:] # ignore subcommands @@ -260,7 +259,6 @@ def serve(filename): print 'usage: %s [options] [start|stop|restart]' % argv[0] sys.exit(1) - logging.basicConfig() logging.debug(_("Full set of FLAGS:")) for flag in FLAGS: logging.debug("%s : %s" % (flag, FLAGS.get(flag, None))) diff --git a/nova/utils.py b/nova/utils.py index 5f5225289..87e726394 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -2,6 +2,7 @@ # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -20,30 +21,37 @@ System-level utilities and helper functions. """ +import base64 import datetime +import functools import inspect import json +import lockfile +import netaddr import os import random -import subprocess +import re import socket +import string import struct import sys import time +import types from xml.sax import saxutils -import re -import netaddr from eventlet import event from eventlet import greenthread - +from eventlet.green import subprocess +None from nova import exception from nova.exception import ProcessExecutionError +from nova import flags from nova import log as logging LOG = logging.getLogger("nova.utils") TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" +FLAGS = flags.FLAGS def import_class(import_str): @@ -53,7 +61,7 @@ def import_class(import_str): __import__(mod_str) return getattr(sys.modules[mod_str], class_str) except (ImportError, ValueError, AttributeError), exc: - logging.debug(_('Inner Exception: %s'), exc) + LOG.debug(_('Inner Exception: %s'), exc) raise exception.NotFound(_('Class %s cannot be found') % class_str) @@ -121,35 +129,90 @@ def fetchfile(url, target): # c.perform() # c.close() # fp.close() - execute("curl --fail %s -o %s" % (url, target)) + execute("curl", "--fail", url, "-o", target) + + +def execute(*cmd, **kwargs): + process_input = kwargs.get('process_input', None) + addl_env = kwargs.get('addl_env', None) + check_exit_code = kwargs.get('check_exit_code', 0) + stdin = kwargs.get('stdin', subprocess.PIPE) + stdout = kwargs.get('stdout', subprocess.PIPE) + stderr = kwargs.get('stderr', subprocess.PIPE) + attempts = kwargs.get('attempts', 1) + cmd = map(str, cmd) + + while attempts > 0: + attempts -= 1 + try: + LOG.debug(_("Running cmd (subprocess): %s"), ' '.join(cmd)) + env = os.environ.copy() + if addl_env: + env.update(addl_env) + obj = subprocess.Popen(cmd, stdin=stdin, + stdout=stdout, stderr=stderr, env=env) + result = None + if process_input != None: + result = obj.communicate(process_input) + else: + result = obj.communicate() + obj.stdin.close() + if obj.returncode: + LOG.debug(_("Result was %s") % obj.returncode) + if type(check_exit_code) == types.IntType \ + and obj.returncode != check_exit_code: + (stdout, stderr) = result + raise ProcessExecutionError(exit_code=obj.returncode, + stdout=stdout, + stderr=stderr, + cmd=' '.join(cmd)) + # NOTE(termie): this appears to be necessary to let the subprocess + # call clean something up in between calls, without + # it two execute calls in a row hangs the second one + greenthread.sleep(0) + return result + except ProcessExecutionError: + if not attempts: + raise + else: + LOG.debug(_("%r failed. Retrying."), cmd) + greenthread.sleep(random.randint(20, 200) / 100.0) -def execute(cmd, process_input=None, addl_env=None, check_exit_code=True): - LOG.debug(_("Running cmd (subprocess): %s"), cmd) - env = os.environ.copy() +def ssh_execute(ssh, cmd, process_input=None, + addl_env=None, check_exit_code=True): + LOG.debug(_("Running cmd (SSH): %s"), ' '.join(cmd)) if addl_env: - env.update(addl_env) - obj = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) - result = None - if process_input != None: - result = obj.communicate(process_input) - else: - result = obj.communicate() - obj.stdin.close() - if obj.returncode: - LOG.debug(_("Result was %s") % obj.returncode) - if check_exit_code and obj.returncode != 0: - (stdout, stderr) = result - raise ProcessExecutionError(exit_code=obj.returncode, - stdout=stdout, - stderr=stderr, - cmd=cmd) - # NOTE(termie): this appears to be necessary to let the subprocess call - # clean something up in between calls, without it two - # execute calls in a row hangs the second one - greenthread.sleep(0) - return result + raise exception.Error("Environment not supported over SSH") + + if process_input: + # This is (probably) fixable if we need it... + raise exception.Error("process_input not supported over SSH") + + stdin_stream, stdout_stream, stderr_stream = ssh.exec_command(cmd) + channel = stdout_stream.channel + + #stdin.write('process_input would go here') + #stdin.flush() + + # NOTE(justinsb): This seems suspicious... + # ...other SSH clients have buffering issues with this approach + stdout = stdout_stream.read() + stderr = stderr_stream.read() + stdin_stream.close() + + exit_status = channel.recv_exit_status() + + # exit_status == -1 if no exit code was returned + if exit_status != -1: + LOG.debug(_("Result was %s") % exit_status) + if check_exit_code and exit_status != 0: + raise exception.ProcessExecutionError(exit_code=exit_status, + stdout=stdout, + stderr=stderr, + cmd=' '.join(cmd)) + + return (stdout, stderr) def abspath(s): @@ -180,9 +243,9 @@ def debug(arg): return arg -def runthis(prompt, cmd, check_exit_code=True): - LOG.debug(_("Running %s"), (cmd)) - rv, err = execute(cmd, check_exit_code=check_exit_code) +def runthis(prompt, *cmd, **kwargs): + LOG.debug(_("Running %s"), (" ".join(cmd))) + rv, err = execute(*cmd, **kwargs) def generate_uid(topic, size=8): @@ -199,13 +262,22 @@ def generate_mac(): return ':'.join(map(lambda x: "%02x" % x, mac)) +def generate_password(length=20): + """Generate a random sequence of letters and digits + to be used as a password. Note that this is not intended + to represent the ultimate in security. + """ + chrs = string.letters + string.digits + return "".join([random.choice(chrs) for i in xrange(length)]) + + def last_octet(address): return int(address.split(".")[-1]) def get_my_linklocal(interface): try: - if_str = execute("ip -f inet6 -o addr show %s" % interface) + if_str = execute("ip", "-f", "inet6", "-o", "addr", "show", interface) condition = "\s+inet6\s+([0-9a-f:]+)/\d+\s+scope\s+link" links = [re.search(condition, x) for x in if_str[0].split('\n')] address = [w.group(1) for w in links if w is not None] @@ -440,3 +512,76 @@ def dumps(value): def loads(s): return json.loads(s) + + +def synchronized(name): + def wrap(f): + @functools.wraps(f) + def inner(*args, **kwargs): + lock = lockfile.FileLock(os.path.join(FLAGS.lock_path, + 'nova-%s.lock' % name)) + with lock: + return f(*args, **kwargs) + return inner + return wrap + + +def ensure_b64_encoding(val): + """Safety method to ensure that values expected to be base64-encoded + actually are. If they are, the value is returned unchanged. Otherwise, + the encoded value is returned. + """ + try: + dummy = base64.decode(val) + return val + except TypeError: + return base64.b64encode(val) + + +def get_from_path(items, path): + """ Returns a list of items matching the specified path. Takes an + XPath-like expression e.g. prop1/prop2/prop3, and for each item in items, + looks up items[prop1][prop2][prop3]. Like XPath, if any of the + intermediate results are lists it will treat each list item individually. + A 'None' in items or any child expressions will be ignored, this function + will not throw because of None (anywhere) in items. The returned list + will contain no None values.""" + + if path is None: + raise exception.Error("Invalid mini_xpath") + + (first_token, sep, remainder) = path.partition("/") + + if first_token == "": + raise exception.Error("Invalid mini_xpath") + + results = [] + + if items is None: + return results + + if not isinstance(items, types.ListType): + # Wrap single objects in a list + items = [items] + + for item in items: + if item is None: + continue + get_method = getattr(item, "get", None) + if get_method is None: + continue + child = get_method(first_token) + if child is None: + continue + if isinstance(child, types.ListType): + # Flatten intermediate lists + for x in child: + results.append(x) + else: + results.append(child) + + if not sep: + # No more tokens + return results + else: + return get_from_path(results, remainder) diff --git a/nova/virt/disk.py b/nova/virt/disk.py index c5565abfa..5d499c42c 100644 --- a/nova/virt/disk.py +++ b/nova/virt/disk.py @@ -38,6 +38,10 @@ flags.DEFINE_integer('minimum_root_size', 1024 * 1024 * 1024 * 10, 'minimum size in bytes of root partition') flags.DEFINE_integer('block_size', 1024 * 1024 * 256, 'block_size to use for dd') +flags.DEFINE_integer('timeout_nbd', 10, + 'time to wait for a NBD device coming up') +flags.DEFINE_integer('max_nbd_devices', 16, + 'maximum number of possible nbd devices') def extend(image, size): @@ -45,10 +49,10 @@ def extend(image, size): file_size = os.path.getsize(image) if file_size >= size: return - utils.execute('truncate -s %s %s' % (size, image)) + utils.execute('truncate', '-s', size, image) # NOTE(vish): attempts to resize filesystem - utils.execute('e2fsck -fp %s' % image, check_exit_code=False) - utils.execute('resize2fs %s' % image, check_exit_code=False) + utils.execute('e2fsck', '-fp', image, check_exit_code=False) + utils.execute('resize2fs', image, check_exit_code=False) def inject_data(image, key=None, net=None, partition=None, nbd=False): @@ -64,7 +68,7 @@ def inject_data(image, key=None, net=None, partition=None, nbd=False): try: if not partition is None: # create partition - out, err = utils.execute('sudo kpartx -a %s' % device) + out, err = utils.execute('sudo', 'kpartx', '-a', device) if err: raise exception.Error(_('Failed to load partition: %s') % err) mapped_device = '/dev/mapper/%sp%s' % (device.split('/')[-1], @@ -80,13 +84,14 @@ def inject_data(image, key=None, net=None, partition=None, nbd=False): mapped_device) # Configure ext2fs so that it doesn't auto-check every N boots - out, err = utils.execute('sudo tune2fs -c 0 -i 0 %s' % mapped_device) + out, err = utils.execute('sudo', 'tune2fs', + '-c', 0, '-i', 0, mapped_device) tmpdir = tempfile.mkdtemp() try: # mount loopback to dir out, err = utils.execute( - 'sudo mount %s %s' % (mapped_device, tmpdir)) + 'sudo', 'mount', mapped_device, tmpdir) if err: raise exception.Error(_('Failed to mount filesystem: %s') % err) @@ -99,13 +104,13 @@ def inject_data(image, key=None, net=None, partition=None, nbd=False): _inject_net_into_fs(net, tmpdir) finally: # unmount device - utils.execute('sudo umount %s' % mapped_device) + utils.execute('sudo', 'umount', mapped_device) finally: # remove temporary directory - utils.execute('rmdir %s' % tmpdir) + utils.execute('rmdir', tmpdir) if not partition is None: # remove partitions - utils.execute('sudo kpartx -d %s' % device) + utils.execute('sudo', 'kpartx', '-d', device) finally: _unlink_device(device, nbd) @@ -114,16 +119,16 @@ def _link_device(image, nbd): """Link image to device using loopback or nbd""" if nbd: device = _allocate_device() - utils.execute('sudo qemu-nbd -c %s %s' % (device, image)) + utils.execute('sudo', 'qemu-nbd', '-c', device, image) # NOTE(vish): this forks into another process, so give it a chance # to set up before continuuing - for i in xrange(10): + for i in xrange(FLAGS.timeout_nbd): if os.path.exists("/sys/block/%s/pid" % os.path.basename(device)): return device time.sleep(1) raise exception.Error(_('nbd device %s did not show up') % device) else: - out, err = utils.execute('sudo losetup --find --show %s' % image) + out, err = utils.execute('sudo', 'losetup', '--find', '--show', image) if err: raise exception.Error(_('Could not attach image to loopback: %s') % err) @@ -133,13 +138,13 @@ def _link_device(image, nbd): def _unlink_device(device, nbd): """Unlink image from device using loopback or nbd""" if nbd: - utils.execute('sudo qemu-nbd -d %s' % device) + utils.execute('sudo', 'qemu-nbd', '-d', device) _free_device(device) else: - utils.execute('sudo losetup --detach %s' % device) + utils.execute('sudo', 'losetup', '--detach', device) -_DEVICES = ['/dev/nbd%s' % i for i in xrange(16)] +_DEVICES = ['/dev/nbd%s' % i for i in xrange(FLAGS.max_nbd_devices)] def _allocate_device(): @@ -166,11 +171,12 @@ def _inject_key_into_fs(key, fs): fs is the path to the base of the filesystem into which to inject the key. """ sshdir = os.path.join(fs, 'root', '.ssh') - utils.execute('sudo mkdir -p %s' % sshdir) # existing dir doesn't matter - utils.execute('sudo chown root %s' % sshdir) - utils.execute('sudo chmod 700 %s' % sshdir) + utils.execute('sudo', 'mkdir', '-p', sshdir) # existing dir doesn't matter + utils.execute('sudo', 'chown', 'root', sshdir) + utils.execute('sudo', 'chmod', '700', sshdir) keyfile = os.path.join(sshdir, 'authorized_keys') - utils.execute('sudo tee -a %s' % keyfile, '\n' + key.strip() + '\n') + utils.execute('sudo', 'tee', '-a', keyfile, + process_input='\n' + key.strip() + '\n') def _inject_net_into_fs(net, fs): @@ -179,8 +185,8 @@ def _inject_net_into_fs(net, fs): net is the contents of /etc/network/interfaces. """ netdir = os.path.join(os.path.join(fs, 'etc'), 'network') - utils.execute('sudo mkdir -p %s' % netdir) # existing dir doesn't matter - utils.execute('sudo chown root:root %s' % netdir) - utils.execute('sudo chmod 755 %s' % netdir) + utils.execute('sudo', 'mkdir', '-p', netdir) # existing dir doesn't matter + utils.execute('sudo', 'chown', 'root:root', netdir) + utils.execute('sudo', 'chmod', 755, netdir) netfile = os.path.join(netdir, 'interfaces') - utils.execute('sudo tee %s' % netfile, net) + utils.execute('sudo', 'tee', netfile, net) diff --git a/nova/virt/fake.py b/nova/virt/fake.py index 161445b86..c744acf91 100644 --- a/nova/virt/fake.py +++ b/nova/virt/fake.py @@ -139,6 +139,24 @@ class FakeConnection(object): """ pass + def get_host_ip_addr(self): + """ + Retrieves the IP address of the dom0 + """ + pass + + def resize(self, instance, flavor): + """ + Resizes/Migrates the specified instance. + + The flavor parameter determines whether or not the instance RAM and + disk space are modified, and if so, to what size. + + The work will be done asynchronously. This function returns a task + that allows the caller to detect when it is complete. + """ + pass + def set_admin_password(self, instance, new_pass): """ Set the root password on the specified instance. @@ -152,6 +170,21 @@ class FakeConnection(object): """ pass + def inject_file(self, instance, b64_path, b64_contents): + """ + Writes a file on the specified instance. + + The first parameter is an instance of nova.compute.service.Instance, + and so the instance is being specified as instance.name. The second + parameter is the base64-encoded path to which the file is to be + written on the instance; the third is the contents of the file, also + base64-encoded. + + The work will be done asynchronously. This function returns a + task that allows the caller to detect when it is complete. + """ + pass + def rescue(self, instance): """ Rescue the specified instance. @@ -164,6 +197,19 @@ class FakeConnection(object): """ pass + def migrate_disk_and_power_off(self, instance, dest): + """ + Transfers the disk of a running instance in multiple phases, turning + off the instance before the end. + """ + pass + + def attach_disk(self, instance, disk_info): + """ + Attaches the disk to an instance given the metadata disk_info + """ + pass + def pause(self, instance, callback): """ Pause the specified instance. @@ -304,7 +350,9 @@ class FakeConnection(object): return 'FAKE CONSOLE OUTPUT' def get_ajax_console(self, instance): - return 'http://fakeajaxconsole.com/?token=FAKETOKEN' + return {'token': 'FAKETOKEN', + 'host': 'fakeajaxconsole.com', + 'port': 6969} def get_console_pool_info(self, console_type): return {'address': '127.0.0.1', diff --git a/nova/virt/images.py b/nova/virt/images.py index 7a6fef330..2e3f2ee4d 100644 --- a/nova/virt/images.py +++ b/nova/virt/images.py @@ -28,29 +28,32 @@ import time import urllib2 import urlparse +from nova import context from nova import flags from nova import log as logging from nova import utils from nova.auth import manager from nova.auth import signer -from nova.objectstore import image FLAGS = flags.FLAGS -flags.DEFINE_bool('use_s3', True, - 'whether to get images from s3 or use local copy') - LOG = logging.getLogger('nova.virt.images') -def fetch(image, path, user, project): - if FLAGS.use_s3: - f = _fetch_s3_image - else: - f = _fetch_local_image - return f(image, path, user, project) +def fetch(image_id, path, _user, _project): + # TODO(vish): Improve context handling and add owner and auth data + # when it is added to glance. Right now there is no + # auth checking in glance, so we assume that access was + # checked before we got here. + image_service = utils.import_object(FLAGS.image_service) + with open(path, "wb") as image_file: + elevated = context.get_admin_context() + metadata = image_service.get(elevated, image_id, image_file) + return metadata +# NOTE(vish): The methods below should be unnecessary, but I'm leaving +# them in case the glance client does not work on windows. def _fetch_image_no_curl(url, path, headers): request = urllib2.Request(url) for (k, v) in headers.iteritems(): @@ -94,8 +97,7 @@ def _fetch_s3_image(image, path, user, project): cmd += ['-H', '\'%s: %s\'' % (k, v)] cmd += ['-o', path] - cmd_out = ' '.join(cmd) - return utils.execute(cmd_out) + return utils.execute(*cmd) def _fetch_local_image(image, path, user, project): @@ -103,13 +105,15 @@ def _fetch_local_image(image, path, user, project): if sys.platform.startswith('win'): return shutil.copy(source, path) else: - return utils.execute('cp %s %s' % (source, path)) + return utils.execute('cp', source, path) def _image_path(path): return os.path.join(FLAGS.images_path, path) +# TODO(vish): xenapi should use the glance client code directly instead +# of retrieving the image using this method. def image_url(image): if FLAGS.image_service == "nova.image.glance.GlanceImageService": return "http://%s:%s/images/%s" % (FLAGS.glance_host, diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 4e0fd106f..61ef256f9 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -44,9 +44,8 @@ import uuid from xml.dom import minidom -from eventlet import greenthread -from eventlet import event from eventlet import tpool +from eventlet import semaphore import IPy @@ -55,8 +54,8 @@ from nova import db from nova import exception from nova import flags from nova import log as logging +#from nova import test from nova import utils -#from nova.api import context from nova.auth import manager from nova.compute import instance_types from nova.compute import power_state @@ -362,7 +361,7 @@ class LibvirtConnection(object): raise exception.APIError("resume not supported for libvirt") @exception.wrap_exception - def rescue(self, instance): + def rescue(self, instance, callback=None): self.destroy(instance, False) xml = self.to_xml(instance, rescue=True) @@ -392,7 +391,7 @@ class LibvirtConnection(object): return timer.start(interval=0.5, now=True) @exception.wrap_exception - def unrescue(self, instance): + def unrescue(self, instance, callback=None): # NOTE(vish): Because reboot destroys and recreates an instance using # the normal xml file, we can just call reboot here self.reboot(instance) @@ -438,8 +437,10 @@ class LibvirtConnection(object): if virsh_output.startswith('/dev/'): LOG.info(_("cool, it's a device")) - out, err = utils.execute("sudo dd if=%s iflag=nonblock" % - virsh_output, check_exit_code=False) + out, err = utils.execute('sudo', 'dd', + "if=%s" % virsh_output, + 'iflag=nonblock', + check_exit_code=False) return out else: return '' @@ -461,11 +462,11 @@ class LibvirtConnection(object): console_log = os.path.join(FLAGS.instances_path, instance['name'], 'console.log') - utils.execute('sudo chown %d %s' % (os.getuid(), console_log)) + utils.execute('sudo', 'chown', os.getuid(), console_log) if FLAGS.libvirt_type == 'xen': # Xen is special - virsh_output = utils.execute("virsh ttyconsole %s" % + virsh_output = utils.execute('virsh', 'ttyconsole', instance['name']) data = self._flush_xen_console(virsh_output) fpath = self._append_to_file(data, console_log) @@ -482,9 +483,10 @@ class LibvirtConnection(object): port = random.randint(int(start_port), int(end_port)) # netcat will exit with 0 only if the port is in use, # so a nonzero return value implies it is unused - cmd = 'netcat 0.0.0.0 %s -w 1 </dev/null || echo free' % (port) - stdout, stderr = utils.execute(cmd) - if stdout.strip() == 'free': + cmd = 'netcat', '0.0.0.0', port, '-w', '1' + try: + stdout, stderr = utils.execute(*cmd, process_input='') + except ProcessExecutionError: return port raise Exception(_('Unable to find an open port')) @@ -511,7 +513,10 @@ class LibvirtConnection(object): subprocess.Popen(cmd, shell=True) return {'token': token, 'host': host, 'port': port} - def _cache_image(self, fn, target, fname, cow=False, *args, **kwargs): + _image_sems = {} + + @staticmethod + def _cache_image(fn, target, fname, cow=False, *args, **kwargs): """Wrapper for a method that creates an image that caches the image. This wrapper will save the image into a common store and create a @@ -530,14 +535,21 @@ class LibvirtConnection(object): if not os.path.exists(base_dir): os.mkdir(base_dir) base = os.path.join(base_dir, fname) - if not os.path.exists(base): - fn(target=base, *args, **kwargs) + + if fname not in LibvirtConnection._image_sems: + LibvirtConnection._image_sems[fname] = semaphore.Semaphore() + with LibvirtConnection._image_sems[fname]: + if not os.path.exists(base): + fn(target=base, *args, **kwargs) + if not LibvirtConnection._image_sems[fname].locked(): + del LibvirtConnection._image_sems[fname] + if cow: - utils.execute('qemu-img create -f qcow2 -o ' - 'cluster_size=2M,backing_file=%s %s' - % (base, target)) + utils.execute('qemu-img', 'create', '-f', 'qcow2', '-o', + 'cluster_size=2M,backing_file=%s' % base, + target) else: - utils.execute('cp %s %s' % (base, target)) + utils.execute('cp', base, target) def _fetch_image(self, target, image_id, user, project, size=None): """Grab image and optionally attempt to resize it""" @@ -547,7 +559,7 @@ class LibvirtConnection(object): def _create_local(self, target, local_gb): """Create a blank image of specified size""" - utils.execute('truncate %s -s %dG' % (target, local_gb)) + utils.execute('truncate', target, '-s', "%dG" % local_gb) # TODO(vish): should we format disk by default? def _create_image(self, inst, libvirt_xml, suffix='', disk_images=None): @@ -558,7 +570,7 @@ class LibvirtConnection(object): fname + suffix) # ensure directories exist and are writable - utils.execute('mkdir -p %s' % basepath(suffix='')) + utils.execute('mkdir', '-p', basepath(suffix='')) LOG.info(_('instance %s: Creating image'), inst['name']) f = open(basepath('libvirt.xml'), 'w') @@ -578,21 +590,23 @@ class LibvirtConnection(object): 'ramdisk_id': inst['ramdisk_id']} if disk_images['kernel_id']: + fname = '%08x' % int(disk_images['kernel_id']) self._cache_image(fn=self._fetch_image, target=basepath('kernel'), - fname=disk_images['kernel_id'], + fname=fname, image_id=disk_images['kernel_id'], user=user, project=project) if disk_images['ramdisk_id']: + fname = '%08x' % int(disk_images['ramdisk_id']) self._cache_image(fn=self._fetch_image, target=basepath('ramdisk'), - fname=disk_images['ramdisk_id'], + fname=fname, image_id=disk_images['ramdisk_id'], user=user, project=project) - root_fname = disk_images['image_id'] + root_fname = '%08x' % int(disk_images['image_id']) size = FLAGS.minimum_root_size if inst['instance_type'] == 'm1.tiny' or suffix == '.rescue': size = None @@ -606,7 +620,7 @@ class LibvirtConnection(object): user=user, project=project, size=size) - type_data = instance_types.INSTANCE_TYPES[inst['instance_type']] + type_data = instance_types.get_instance_type(inst['instance_type']) if type_data['local_gb']: self._cache_image(fn=self._create_local, @@ -658,7 +672,7 @@ class LibvirtConnection(object): ' data into image %(img_id)s (%(e)s)') % locals()) if FLAGS.libvirt_type == 'uml': - utils.execute('sudo chown root %s' % basepath('disk')) + utils.execute('sudo', 'chown', 'root', basepath('disk')) def to_xml(self, instance, rescue=False): # TODO(termie): cache? @@ -667,7 +681,8 @@ class LibvirtConnection(object): instance['id']) # FIXME(vish): stick this in db instance_type = instance['instance_type'] - instance_type = instance_types.INSTANCE_TYPES[instance_type] + # instance_type = test.INSTANCE_TYPES[instance_type] + instance_type = instance_types.get_instance_type(instance_type) ip_address = db.instance_get_fixed_address(context.get_admin_context(), instance['id']) # Assume that the gateway also acts as the dhcp server. @@ -1206,10 +1221,14 @@ class NWFilterFirewall(FirewallDriver): class IptablesFirewallDriver(FirewallDriver): def __init__(self, execute=None, **kwargs): - self.execute = execute or utils.execute + from nova.network import linux_net + self.iptables = linux_net.iptables_manager self.instances = {} self.nwfilter = NWFilterFirewall(kwargs['get_connection']) + self.iptables.ipv4['filter'].add_chain('sg-fallback') + self.iptables.ipv4['filter'].add_rule('sg-fallback', '-j DROP') + def setup_basic_filtering(self, instance): """Use NWFilter from libvirt for this.""" return self.nwfilter.setup_basic_filtering(instance) @@ -1218,126 +1237,97 @@ class IptablesFirewallDriver(FirewallDriver): """No-op. Everything is done in prepare_instance_filter""" pass - def remove_instance(self, instance): + def unfilter_instance(self, instance): if instance['id'] in self.instances: del self.instances[instance['id']] + self.remove_filters_for_instance(instance) + self.iptables.apply() else: LOG.info(_('Attempted to unfilter instance %s which is not ' 'filtered'), instance['id']) - def add_instance(self, instance): + def prepare_instance_filter(self, instance): self.instances[instance['id']] = instance + self.add_filters_for_instance(instance) + self.iptables.apply() - def unfilter_instance(self, instance): - self.remove_instance(instance) - self.apply_ruleset() + def add_filters_for_instance(self, instance): + chain_name = self._instance_chain_name(instance) - def prepare_instance_filter(self, instance): - self.add_instance(instance) - self.apply_ruleset() - - def apply_ruleset(self): - current_filter, _ = self.execute('sudo iptables-save -t filter') - current_lines = current_filter.split('\n') - new_filter = self.modify_rules(current_lines, 4) - self.execute('sudo iptables-restore', - process_input='\n'.join(new_filter)) - if(FLAGS.use_ipv6): - current_filter, _ = self.execute('sudo ip6tables-save -t filter') - current_lines = current_filter.split('\n') - new_filter = self.modify_rules(current_lines, 6) - self.execute('sudo ip6tables-restore', - process_input='\n'.join(new_filter)) + self.iptables.ipv4['filter'].add_chain(chain_name) + ipv4_address = self._ip_for_instance(instance) + self.iptables.ipv4['filter'].add_rule('local', + '-d %s -j $%s' % + (ipv4_address, chain_name)) + + if FLAGS.use_ipv6: + self.iptables.ipv6['filter'].add_chain(chain_name) + ipv6_address = self._ip_for_instance_v6(instance) + self.iptables.ipv6['filter'].add_rule('local', + '-d %s -j $%s' % + (ipv6_address, + chain_name)) + + ipv4_rules, ipv6_rules = self.instance_rules(instance) + + for rule in ipv4_rules: + self.iptables.ipv4['filter'].add_rule(chain_name, rule) + + if FLAGS.use_ipv6: + for rule in ipv6_rules: + self.iptables.ipv6['filter'].add_rule(chain_name, rule) + + def remove_filters_for_instance(self, instance): + chain_name = self._instance_chain_name(instance) + + self.iptables.ipv4['filter'].remove_chain(chain_name) + if FLAGS.use_ipv6: + self.iptables.ipv6['filter'].remove_chain(chain_name) - def modify_rules(self, current_lines, ip_version=4): + def instance_rules(self, instance): ctxt = context.get_admin_context() - # Remove any trace of nova rules. - new_filter = filter(lambda l: 'nova-' not in l, current_lines) - - seen_chains = False - for rules_index in range(len(new_filter)): - if not seen_chains: - if new_filter[rules_index].startswith(':'): - seen_chains = True - elif seen_chains == 1: - if not new_filter[rules_index].startswith(':'): - break - our_chains = [':nova-fallback - [0:0]'] - our_rules = ['-A nova-fallback -j DROP'] - - our_chains += [':nova-local - [0:0]'] - our_rules += ['-A FORWARD -j nova-local'] - our_rules += ['-A OUTPUT -j nova-local'] - - security_groups = {} - # Add our chains - # First, we add instance chains and rules - for instance_id in self.instances: - instance = self.instances[instance_id] - chain_name = self._instance_chain_name(instance) - if(ip_version == 4): - ip_address = self._ip_for_instance(instance) - elif(ip_version == 6): - ip_address = self._ip_for_instance_v6(instance) - - our_chains += [':%s - [0:0]' % chain_name] - - # Jump to the per-instance chain - our_rules += ['-A nova-local -d %s -j %s' % (ip_address, - chain_name)] - - # Always drop invalid packets - our_rules += ['-A %s -m state --state ' - 'INVALID -j DROP' % (chain_name,)] - - # Allow established connections - our_rules += ['-A %s -m state --state ' - 'ESTABLISHED,RELATED -j ACCEPT' % (chain_name,)] - - # Jump to each security group chain in turn - for security_group in \ - db.security_group_get_by_instance(ctxt, - instance['id']): - security_groups[security_group['id']] = security_group - - sg_chain_name = self._security_group_chain_name( - security_group['id']) + ipv4_rules = [] + ipv6_rules = [] - our_rules += ['-A %s -j %s' % (chain_name, sg_chain_name)] - - if(ip_version == 4): - # Allow DHCP responses - dhcp_server = self._dhcp_server_for_instance(instance) - our_rules += ['-A %s -s %s -p udp --sport 67 --dport 68 ' - '-j ACCEPT ' % (chain_name, dhcp_server)] - #Allow project network traffic - if (FLAGS.allow_project_net_traffic): - cidr = self._project_cidr_for_instance(instance) - our_rules += ['-A %s -s %s -j ACCEPT' % (chain_name, cidr)] - elif(ip_version == 6): - # Allow RA responses - ra_server = self._ra_server_for_instance(instance) - if ra_server: - our_rules += ['-A %s -s %s -p icmpv6 -j ACCEPT' % - (chain_name, ra_server + "/128")] - #Allow project network traffic - if (FLAGS.allow_project_net_traffic): - cidrv6 = self._project_cidrv6_for_instance(instance) - our_rules += ['-A %s -s %s -j ACCEPT' % - (chain_name, cidrv6)] - - # If nothing matches, jump to the fallback chain - our_rules += ['-A %s -j nova-fallback' % (chain_name,)] + # Always drop invalid packets + ipv4_rules += ['-m state --state ' 'INVALID -j DROP'] + ipv6_rules += ['-m state --state ' 'INVALID -j DROP'] - # then, security group chains and rules - for security_group_id in security_groups: - chain_name = self._security_group_chain_name(security_group_id) - our_chains += [':%s - [0:0]' % chain_name] + # Allow established connections + ipv4_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] + ipv6_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] - rules = \ - db.security_group_rule_get_by_security_group(ctxt, - security_group_id) + dhcp_server = self._dhcp_server_for_instance(instance) + ipv4_rules += ['-s %s -p udp --sport 67 --dport 68 ' + '-j ACCEPT' % (dhcp_server,)] + + #Allow project network traffic + if FLAGS.allow_project_net_traffic: + cidr = self._project_cidr_for_instance(instance) + ipv4_rules += ['-s %s -j ACCEPT' % (cidr,)] + + # We wrap these in FLAGS.use_ipv6 because they might cause + # a DB lookup. The other ones are just list operations, so + # they're not worth the clutter. + if FLAGS.use_ipv6: + # Allow RA responses + ra_server = self._ra_server_for_instance(instance) + if ra_server: + ipv6_rules += ['-s %s/128 -p icmpv6 -j ACCEPT' % (ra_server,)] + + #Allow project network traffic + if FLAGS.allow_project_net_traffic: + cidrv6 = self._project_cidrv6_for_instance(instance) + ipv6_rules += ['-s %s -j ACCEPT' % (cidrv6,)] + + security_groups = db.security_group_get_by_instance(ctxt, + instance['id']) + + # then, security group chains and rules + for security_group in security_groups: + rules = db.security_group_rule_get_by_security_group(ctxt, + security_group['id']) for rule in rules: logging.info('%r', rule) @@ -1348,14 +1338,16 @@ class IptablesFirewallDriver(FirewallDriver): continue version = _get_ip_version(rule.cidr) - if version != ip_version: - continue + if version == 4: + rules = ipv4_rules + else: + rules = ipv6_rules protocol = rule.protocol if version == 6 and rule.protocol == 'icmp': protocol = 'icmpv6' - args = ['-A', chain_name, '-p', protocol, '-s', rule.cidr] + args = ['-p', protocol, '-s', rule.cidr] if rule.protocol in ['udp', 'tcp']: if rule.from_port == rule.to_port: @@ -1376,32 +1368,39 @@ class IptablesFirewallDriver(FirewallDriver): icmp_type_arg += '/%s' % icmp_code if icmp_type_arg: - if(ip_version == 4): + if version == 4: args += ['-m', 'icmp', '--icmp-type', icmp_type_arg] - elif(ip_version == 6): + elif version == 6: args += ['-m', 'icmp6', '--icmpv6-type', icmp_type_arg] args += ['-j ACCEPT'] - our_rules += [' '.join(args)] + rules += [' '.join(args)] + + ipv4_rules += ['-j $sg-fallback'] + ipv6_rules += ['-j $sg-fallback'] - new_filter[rules_index:rules_index] = our_rules - new_filter[rules_index:rules_index] = our_chains - logging.info('new_filter: %s', '\n'.join(new_filter)) - return new_filter + return ipv4_rules, ipv6_rules def refresh_security_group_members(self, security_group): pass def refresh_security_group_rules(self, security_group): - self.apply_ruleset() + for instance in self.instances.values(): + # We use the semaphore to make sure noone applies the rule set + # after we've yanked the existing rules but before we've put in + # the new ones. + with self.iptables.semaphore: + self.remove_filters_for_instance(instance) + self.add_filters_for_instance(instance) + self.iptables.apply() def _security_group_chain_name(self, security_group_id): return 'nova-sg-%s' % (security_group_id,) def _instance_chain_name(self, instance): - return 'nova-inst-%s' % (instance['id'],) + return 'inst-%s' % (instance['id'],) def _ip_for_instance(self, instance): return db.instance_get_fixed_address(context.get_admin_context(), diff --git a/nova/virt/xenapi/fake.py b/nova/virt/xenapi/fake.py index e8352771c..ba12d4d3a 100644 --- a/nova/virt/xenapi/fake.py +++ b/nova/virt/xenapi/fake.py @@ -286,6 +286,13 @@ class SessionBase(object): rec['currently_attached'] = False rec['device'] = '' + def host_compute_free_memory(self, _1, ref): + #Always return 12GB available + return 12 * 1024 * 1024 * 1024 + + def host_call_plugin(*args): + return 'herp' + def xenapi_request(self, methodname, params): if methodname.startswith('login'): self._login(methodname, params) @@ -397,7 +404,7 @@ class SessionBase(object): field in _db_content[cls][ref]): return _db_content[cls][ref][field] - LOG.debuug(_('Raising NotImplemented')) + LOG.debug(_('Raising NotImplemented')) raise NotImplementedError( _('xenapi.fake does not have an implementation for %s or it has ' 'been called with the wrong number of arguments') % name) diff --git a/nova/virt/xenapi/vm_utils.py b/nova/virt/xenapi/vm_utils.py index 4bbd522c1..4e6c71446 100644 --- a/nova/virt/xenapi/vm_utils.py +++ b/nova/virt/xenapi/vm_utils.py @@ -24,6 +24,7 @@ import pickle import re import time import urllib +import uuid from xml.dom import minidom from eventlet import event @@ -63,11 +64,14 @@ class ImageType: 0 - kernel/ramdisk image (goes on dom0's filesystem) 1 - disk image (local SR, partitioned by objectstore plugin) 2 - raw disk image (local SR, NOT partitioned by plugin) + 3 - vhd disk image (local SR, NOT inspected by XS, PV assumed for + linux, HVM assumed for Windows) """ KERNEL_RAMDISK = 0 DISK = 1 DISK_RAW = 2 + DISK_VHD = 3 class VMHelper(HelperBase): @@ -82,7 +86,8 @@ class VMHelper(HelperBase): the pv_kernel flag indicates whether the guest is HVM or PV """ - instance_type = instance_types.INSTANCE_TYPES[instance.instance_type] + instance_type = instance_types.\ + get_instance_type(instance.instance_type) mem = str(long(instance_type['memory_mb']) * 1024 * 1024) vcpus = str(instance_type['vcpus']) rec = { @@ -139,6 +144,17 @@ class VMHelper(HelperBase): return vm_ref @classmethod + def ensure_free_mem(cls, session, instance): + instance_type = instance_types.get_instance_type( + instance.instance_type) + mem = long(instance_type['memory_mb']) * 1024 * 1024 + #get free memory from host + host = session.get_xenapi_host() + host_free_mem = long(session.get_xenapi().host. + compute_free_memory(host)) + return host_free_mem >= mem + + @classmethod def create_vbd(cls, session, vm_ref, vdi_ref, userdevice, bootable): """Create a VBD record. Returns a Deferred that gives the new VBD reference.""" @@ -191,19 +207,17 @@ class VMHelper(HelperBase): """Destroy VBD from host database""" try: task = session.call_xenapi('Async.VBD.destroy', vbd_ref) - #FIXME(armando): find a solution to missing instance_id - #with Josh Kearney - session.wait_for_task(0, task) + session.wait_for_task(task) except cls.XenAPI.Failure, exc: LOG.exception(exc) raise StorageError(_('Unable to destroy VBD %s') % vbd_ref) @classmethod - def create_vif(cls, session, vm_ref, network_ref, mac_address): + def create_vif(cls, session, vm_ref, network_ref, mac_address, dev="0"): """Create a VIF record. Returns a Deferred that gives the new VIF reference.""" vif_rec = {} - vif_rec['device'] = '0' + vif_rec['device'] = dev vif_rec['network'] = network_ref vif_rec['VM'] = vm_ref vif_rec['MAC'] = mac_address @@ -239,24 +253,40 @@ class VMHelper(HelperBase): return vdi_ref @classmethod + def get_vdi_for_vm_safely(cls, session, vm_ref): + vdi_refs = VMHelper.lookup_vm_vdis(session, vm_ref) + if vdi_refs is None: + raise Exception(_("No VDIs found for VM %s") % vm_ref) + else: + num_vdis = len(vdi_refs) + if num_vdis != 1: + raise Exception( + _("Unexpected number of VDIs (%(num_vdis)s) found" + " for VM %(vm_ref)s") % locals()) + + vdi_ref = vdi_refs[0] + vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref) + return vdi_ref, vdi_rec + + @classmethod def create_snapshot(cls, session, instance_id, vm_ref, label): - """ Creates Snapshot (Template) VM, Snapshot VBD, Snapshot VDI, - Snapshot VHD - """ + """Creates Snapshot (Template) VM, Snapshot VBD, Snapshot VDI, + Snapshot VHD""" #TODO(sirp): Add quiesce and VSS locking support when Windows support # is added LOG.debug(_("Snapshotting VM %(vm_ref)s with label '%(label)s'...") % locals()) - vm_vdi_ref, vm_vdi_rec = get_vdi_for_vm_safely(session, vm_ref) + vm_vdi_ref, vm_vdi_rec = cls.get_vdi_for_vm_safely(session, vm_ref) vm_vdi_uuid = vm_vdi_rec["uuid"] sr_ref = vm_vdi_rec["SR"] original_parent_uuid = get_vhd_parent_uuid(session, vm_vdi_ref) task = session.call_xenapi('Async.VM.snapshot', vm_ref, label) - template_vm_ref = session.wait_for_task(instance_id, task) - template_vdi_rec = get_vdi_for_vm_safely(session, template_vm_ref)[1] + template_vm_ref = session.wait_for_task(task, instance_id) + template_vdi_rec = cls.get_vdi_for_vm_safely(session, + template_vm_ref)[1] template_vdi_uuid = template_vdi_rec["uuid"] LOG.debug(_('Created snapshot %(template_vm_ref)s from' @@ -266,29 +296,53 @@ class VMHelper(HelperBase): session, instance_id, sr_ref, vm_vdi_ref, original_parent_uuid) #TODO(sirp): we need to assert only one parent, not parents two deep - return template_vm_ref, [template_vdi_uuid, parent_uuid] + template_vdi_uuids = {'image': parent_uuid, + 'snap': template_vdi_uuid} + return template_vm_ref, template_vdi_uuids + + @classmethod + def get_sr(cls, session, sr_label='slices'): + """Finds the SR named by the given name label and returns + the UUID""" + return session.call_xenapi('SR.get_by_name_label', sr_label)[0] + + @classmethod + def get_sr_path(cls, session): + """Return the path to our storage repository + + This is used when we're dealing with VHDs directly, either by taking + snapshots or by restoring an image in the DISK_VHD format. + """ + sr_ref = safe_find_sr(session) + sr_rec = session.get_xenapi().SR.get_record(sr_ref) + sr_uuid = sr_rec["uuid"] + return os.path.join(FLAGS.xenapi_sr_base_path, sr_uuid) @classmethod def upload_image(cls, session, instance_id, vdi_uuids, image_id): """ Requests that the Glance plugin bundle the specified VDIs and push them into Glance using the specified human-friendly name. """ + # NOTE(sirp): Currently we only support uploading images as VHD, there + # is no RAW equivalent (yet) logging.debug(_("Asking xapi to upload %(vdi_uuids)s as" " ID %(image_id)s") % locals()) params = {'vdi_uuids': vdi_uuids, 'image_id': image_id, 'glance_host': FLAGS.glance_host, - 'glance_port': FLAGS.glance_port} + 'glance_port': FLAGS.glance_port, + 'sr_path': cls.get_sr_path(session)} kwargs = {'params': pickle.dumps(params)} - task = session.async_call_plugin('glance', 'put_vdis', kwargs) - session.wait_for_task(instance_id, task) + task = session.async_call_plugin('glance', 'upload_vhd', kwargs) + session.wait_for_task(task, instance_id) @classmethod - def fetch_image(cls, session, instance_id, image, user, project, type): + def fetch_image(cls, session, instance_id, image, user, project, + image_type): """ - type is interpreted as an ImageType instance + image_type is interpreted as an ImageType instance Related flags: xenapi_image_service = ['glance', 'objectstore'] glance_address = 'address for glance services' @@ -298,35 +352,80 @@ class VMHelper(HelperBase): if FLAGS.xenapi_image_service == 'glance': return cls._fetch_image_glance(session, instance_id, image, - access, type) + access, image_type) else: return cls._fetch_image_objectstore(session, instance_id, image, - access, user.secret, type) + access, user.secret, + image_type) + + @classmethod + def _fetch_image_glance_vhd(cls, session, instance_id, image, access, + image_type): + LOG.debug(_("Asking xapi to fetch vhd image %(image)s") + % locals()) + + sr_ref = safe_find_sr(session) + + # NOTE(sirp): The Glance plugin runs under Python 2.4 which does not + # have the `uuid` module. To work around this, we generate the uuids + # here (under Python 2.6+) and pass them as arguments + uuid_stack = [str(uuid.uuid4()) for i in xrange(2)] + + params = {'image_id': image, + 'glance_host': FLAGS.glance_host, + 'glance_port': FLAGS.glance_port, + 'uuid_stack': uuid_stack, + 'sr_path': cls.get_sr_path(session)} + + kwargs = {'params': pickle.dumps(params)} + task = session.async_call_plugin('glance', 'download_vhd', kwargs) + vdi_uuid = session.wait_for_task(task, instance_id) + + cls.scan_sr(session, instance_id, sr_ref) + + # Set the name-label to ease debugging + vdi_ref = session.get_xenapi().VDI.get_by_uuid(vdi_uuid) + name_label = get_name_label_for_image(image) + session.get_xenapi().VDI.set_name_label(vdi_ref, name_label) + + LOG.debug(_("xapi 'download_vhd' returned VDI UUID %(vdi_uuid)s") + % locals()) + return vdi_uuid @classmethod - def _fetch_image_glance(cls, session, instance_id, image, access, type): - sr = find_sr(session) - if sr is None: - raise exception.NotFound('Cannot find SR to write VDI to') + def _fetch_image_glance_disk(cls, session, instance_id, image, access, + image_type): + """Fetch the image from Glance + + NOTE: + Unlike _fetch_image_glance_vhd, this method does not use the Glance + plugin; instead, it streams the disks through domU to the VDI + directly. - c = glance.client.Client(FLAGS.glance_host, FLAGS.glance_port) + """ + # FIXME(sirp): Since the Glance plugin seems to be required for the + # VHD disk, it may be worth using the plugin for both VHD and RAW and + # DISK restores + sr_ref = safe_find_sr(session) - meta, image_file = c.get_image(image) + client = glance.client.Client(FLAGS.glance_host, FLAGS.glance_port) + meta, image_file = client.get_image(image) virtual_size = int(meta['size']) vdi_size = virtual_size LOG.debug(_("Size for image %(image)s:%(virtual_size)d") % locals()) - if type == ImageType.DISK: + + if image_type == ImageType.DISK: # Make room for MBR. vdi_size += MBR_SIZE_BYTES - vdi = cls.create_vdi(session, sr, _('Glance image %s') % image, - vdi_size, False) + name_label = get_name_label_for_image(image) + vdi = cls.create_vdi(session, sr_ref, name_label, vdi_size, False) with_vdi_attached_here(session, vdi, False, lambda dev: - _stream_disk(dev, type, + _stream_disk(dev, image_type, virtual_size, image_file)) - if (type == ImageType.KERNEL_RAMDISK): + if image_type == ImageType.KERNEL_RAMDISK: #we need to invoke a plugin for copying VDI's #content into proper path LOG.debug(_("Copying VDI %s to /boot/guest on dom0"), vdi) @@ -336,7 +435,7 @@ class VMHelper(HelperBase): #let the plugin copy the correct number of bytes args['image-size'] = str(vdi_size) task = session.async_call_plugin('glance', fn, args) - filename = session.wait_for_task(instance_id, task) + filename = session.wait_for_task(task, instance_id) #remove the VDI as it is not needed anymore session.get_xenapi().VDI.destroy(vdi) LOG.debug(_("Kernel/Ramdisk VDI %s destroyed"), vdi) @@ -345,27 +444,99 @@ class VMHelper(HelperBase): return session.get_xenapi().VDI.get_uuid(vdi) @classmethod + def determine_disk_image_type(cls, instance): + """Disk Image Types are used to determine where the kernel will reside + within an image. To figure out which type we're dealing with, we use + the following rules: + + 1. If we're using Glance, we can use the image_type field to + determine the image_type + + 2. If we're not using Glance, then we need to deduce this based on + whether a kernel_id is specified. + """ + def log_disk_format(image_type): + pretty_format = {ImageType.KERNEL_RAMDISK: 'KERNEL_RAMDISK', + ImageType.DISK: 'DISK', + ImageType.DISK_RAW: 'DISK_RAW', + ImageType.DISK_VHD: 'DISK_VHD'} + disk_format = pretty_format[image_type] + image_id = instance.image_id + instance_id = instance.id + LOG.debug(_("Detected %(disk_format)s format for image " + "%(image_id)s, instance %(instance_id)s") % locals()) + + def determine_from_glance(): + glance_disk_format2nova_type = { + 'ami': ImageType.DISK, + 'aki': ImageType.KERNEL_RAMDISK, + 'ari': ImageType.KERNEL_RAMDISK, + 'raw': ImageType.DISK_RAW, + 'vhd': ImageType.DISK_VHD} + client = glance.client.Client(FLAGS.glance_host, FLAGS.glance_port) + meta = client.get_image_meta(instance.image_id) + disk_format = meta['disk_format'] + try: + return glance_disk_format2nova_type[disk_format] + except KeyError: + raise exception.NotFound( + _("Unrecognized disk_format '%(disk_format)s'") + % locals()) + + def determine_from_instance(): + if instance.kernel_id: + return ImageType.DISK + else: + return ImageType.DISK_RAW + + # FIXME(sirp): can we unify the ImageService and xenapi_image_service + # abstractions? + if FLAGS.xenapi_image_service == 'glance': + image_type = determine_from_glance() + else: + image_type = determine_from_instance() + + log_disk_format(image_type) + return image_type + + @classmethod + def _fetch_image_glance(cls, session, instance_id, image, access, + image_type): + if image_type == ImageType.DISK_VHD: + return cls._fetch_image_glance_vhd( + session, instance_id, image, access, image_type) + else: + return cls._fetch_image_glance_disk( + session, instance_id, image, access, image_type) + + @classmethod def _fetch_image_objectstore(cls, session, instance_id, image, access, - secret, type): + secret, image_type): url = images.image_url(image) LOG.debug(_("Asking xapi to fetch %(url)s as %(access)s") % locals()) - fn = (type != ImageType.KERNEL_RAMDISK) and 'get_vdi' or 'get_kernel' + if image_type == ImageType.KERNEL_RAMDISK: + fn = 'get_kernel' + else: + fn = 'get_vdi' args = {} args['src_url'] = url args['username'] = access args['password'] = secret args['add_partition'] = 'false' args['raw'] = 'false' - if type != ImageType.KERNEL_RAMDISK: + if image_type != ImageType.KERNEL_RAMDISK: args['add_partition'] = 'true' - if type == ImageType.DISK_RAW: + if image_type == ImageType.DISK_RAW: args['raw'] = 'true' task = session.async_call_plugin('objectstore', fn, args) - uuid = session.wait_for_task(instance_id, task) + uuid = session.wait_for_task(task, instance_id) return uuid @classmethod def lookup_image(cls, session, instance_id, vdi_ref): + """ + Determine if VDI is using a PV kernel + """ if FLAGS.xenapi_image_service == 'glance': return cls._lookup_image_glance(session, vdi_ref) else: @@ -378,31 +549,19 @@ class VMHelper(HelperBase): args = {} args['vdi-ref'] = vdi_ref task = session.async_call_plugin('objectstore', fn, args) - pv_str = session.wait_for_task(instance_id, task) + pv_str = session.wait_for_task(task, instance_id) pv = None if pv_str.lower() == 'true': pv = True elif pv_str.lower() == 'false': pv = False - LOG.debug(_("PV Kernel in VDI:%d"), pv) + LOG.debug(_("PV Kernel in VDI:%s"), pv) return pv @classmethod def _lookup_image_glance(cls, session, vdi_ref): LOG.debug(_("Looking up vdi %s for PV kernel"), vdi_ref) - - def is_vdi_pv(dev): - LOG.debug(_("Running pygrub against %s"), dev) - output = os.popen('pygrub -qn /dev/%s' % dev) - for line in output.readlines(): - #try to find kernel string - m = re.search('(?<=kernel:)/.*(?:>)', line) - if m and m.group(0).find('xen') != -1: - LOG.debug(_("Found Xen kernel %s") % m.group(0)) - return True - LOG.debug(_("No Xen kernel found. Booting HVM.")) - return False - return with_vdi_attached_here(session, vdi_ref, True, is_vdi_pv) + return with_vdi_attached_here(session, vdi_ref, True, _is_vdi_pv) @classmethod def lookup(cls, session, i): @@ -440,6 +599,14 @@ class VMHelper(HelperBase): return None @classmethod + def lookup_kernel_ramdisk(cls, session, vm): + vm_rec = session.get_xenapi().VM.get_record(vm) + if 'PV_kernel' in vm_rec and 'PV_ramdisk' in vm_rec: + return (vm_rec['PV_kernel'], vm_rec['PV_ramdisk']) + else: + return (None, None) + + @classmethod def compile_info(cls, record): """Fill record with VM status information""" LOG.info(_("(VM_UTILS) xenserver vm state -> |%s|"), @@ -478,6 +645,21 @@ class VMHelper(HelperBase): except cls.XenAPI.Failure as e: return {"Unable to retrieve diagnostics": e} + @classmethod + def scan_sr(cls, session, instance_id=None, sr_ref=None): + """Scans the SR specified by sr_ref""" + if sr_ref: + LOG.debug(_("Re-scanning SR %s"), sr_ref) + task = session.call_xenapi('Async.SR.scan', sr_ref) + session.wait_for_task(task, instance_id) + + @classmethod + def scan_default_sr(cls, session): + """Looks for the system default SR and triggers a re-scan""" + #FIXME(sirp/mdietz): refactor scan_default_sr in there + sr_ref = cls.get_sr(session) + session.call_xenapi('SR.scan', sr_ref) + def get_rrd(host, uuid): """Return the VM RRD XML as a string""" @@ -520,12 +702,6 @@ def get_vhd_parent_uuid(session, vdi_ref): return None -def scan_sr(session, instance_id, sr_ref): - LOG.debug(_("Re-scanning SR %s"), sr_ref) - task = session.call_xenapi('Async.SR.scan', sr_ref) - session.wait_for_task(instance_id, task) - - def wait_for_vhd_coalesce(session, instance_id, sr_ref, vdi_ref, original_parent_uuid): """ Spin until the parent VHD is coalesced into its parent VHD @@ -550,7 +726,7 @@ def wait_for_vhd_coalesce(session, instance_id, sr_ref, vdi_ref, " %(max_attempts)d), giving up...") % locals()) raise exception.Error(msg) - scan_sr(session, instance_id, sr_ref) + VMHelper.scan_sr(session, instance_id, sr_ref) parent_uuid = get_vhd_parent_uuid(session, vdi_ref) if original_parent_uuid and (parent_uuid != original_parent_uuid): LOG.debug(_("Parent %(parent_uuid)s doesn't match original parent" @@ -581,7 +757,18 @@ def get_vdi_for_vm_safely(session, vm_ref): return vdi_ref, vdi_rec +def safe_find_sr(session): + """Same as find_sr except raises a NotFound exception if SR cannot be + determined + """ + sr_ref = find_sr(session) + if sr_ref is None: + raise exception.NotFound(_('Cannot find SR to read/write VDI')) + return sr_ref + + def find_sr(session): + """Return the storage repository to hold VM images""" host = session.get_xenapi_host() srs = session.get_xenapi().SR.get_all() for sr in srs: @@ -696,9 +883,22 @@ def get_this_vm_ref(session): return session.get_xenapi().VM.get_by_uuid(get_this_vm_uuid()) -def _stream_disk(dev, type, virtual_size, image_file): +def _is_vdi_pv(dev): + LOG.debug(_("Running pygrub against %s"), dev) + output = os.popen('pygrub -qn /dev/%s' % dev) + for line in output.readlines(): + #try to find kernel string + m = re.search('(?<=kernel:)/.*(?:>)', line) + if m and m.group(0).find('xen') != -1: + LOG.debug(_("Found Xen kernel %s") % m.group(0)) + return True + LOG.debug(_("No Xen kernel found. Booting HVM.")) + return False + + +def _stream_disk(dev, image_type, virtual_size, image_file): offset = 0 - if type == ImageType.DISK: + if image_type == ImageType.DISK: offset = MBR_SIZE_BYTES _write_partition(virtual_size, dev) @@ -717,13 +917,17 @@ def _write_partition(virtual_size, dev): LOG.debug(_('Writing partition table %(primary_first)d %(primary_last)d' ' to %(dest)s...') % locals()) - def execute(cmd, process_input=None, check_exit_code=True): - return utils.execute(cmd=cmd, - process_input=process_input, - check_exit_code=check_exit_code) + def execute(*cmd, **kwargs): + return utils.execute(*cmd, **kwargs) - execute('parted --script %s mklabel msdos' % dest) - execute('parted --script %s mkpart primary %ds %ds' % - (dest, primary_first, primary_last)) + execute('parted', '--script', dest, 'mklabel', 'msdos') + execute('parted', '--script', dest, 'mkpart', 'primary', + '%ds' % primary_first, + '%ds' % primary_last) LOG.debug(_('Writing partition table %s done.'), dest) + + +def get_name_label_for_image(image): + # TODO(sirp): This should eventually be the URI for the Glance image + return _('Glance image %s') % image diff --git a/nova/virt/xenapi/vmops.py b/nova/virt/xenapi/vmops.py index e84ce20c4..562ecd4d5 100644 --- a/nova/virt/xenapi/vmops.py +++ b/nova/virt/xenapi/vmops.py @@ -22,6 +22,7 @@ Management class for VM-related functions (spawn, reboot, etc). import json import M2Crypto import os +import pickle import subprocess import tempfile import uuid @@ -49,6 +50,7 @@ class VMOps(object): def __init__(self, session): self.XenAPI = session.get_imported_xenapi() self._session = session + VMHelper.XenAPI = self.XenAPI def list_instances(self): @@ -60,112 +62,185 @@ class VMOps(object): vms.append(rec["name_label"]) return vms + def _start(self, instance, vm_ref=None): + """Power on a VM instance""" + if not vm_ref: + vm_ref = VMHelper.lookup(self._session, instance.name) + if vm_ref is None: + raise exception(_('Attempted to power on non-existent instance' + ' bad instance id %s') % instance.id) + LOG.debug(_("Starting instance %s"), instance.name) + self._session.call_xenapi('VM.start', vm_ref, False, False) + + def create_disk(self, instance): + user = AuthManager().get_user(instance.user_id) + project = AuthManager().get_project(instance.project_id) + disk_image_type = VMHelper.determine_disk_image_type(instance) + vdi_uuid = VMHelper.fetch_image(self._session, instance.id, + instance.image_id, user, project, disk_image_type) + return vdi_uuid + def spawn(self, instance): + vdi_uuid = self.create_disk(instance) + self._spawn_with_disk(instance, vdi_uuid=vdi_uuid) + + def _spawn_with_disk(self, instance, vdi_uuid): """Create VM instance""" - vm = VMHelper.lookup(self._session, instance.name) + instance_name = instance.name + vm = VMHelper.lookup(self._session, instance_name) if vm is not None: raise exception.Duplicate(_('Attempted to create' - ' non-unique name %s') % instance.name) - - bridge = db.network_get_by_instance(context.get_admin_context(), - instance['id'])['bridge'] - network_ref = \ - NetworkHelper.find_network_with_bridge(self._session, bridge) + ' non-unique name %s') % instance_name) + + #ensure enough free memory is available + if not VMHelper.ensure_free_mem(self._session, instance): + LOG.exception(_('instance %(instance_name)s: not enough free ' + 'memory') % locals()) + db.instance_set_state(context.get_admin_context(), + instance['id'], + power_state.SHUTDOWN) + return user = AuthManager().get_user(instance.user_id) project = AuthManager().get_project(instance.project_id) - #if kernel is not present we must download a raw disk - if instance.kernel_id: - disk_image_type = ImageType.DISK - else: - disk_image_type = ImageType.DISK_RAW - vdi_uuid = VMHelper.fetch_image(self._session, instance.id, - instance.image_id, user, project, disk_image_type) + + kernel = ramdisk = pv_kernel = None + + # Are we building from a pre-existing disk? vdi_ref = self._session.call_xenapi('VDI.get_by_uuid', vdi_uuid) - #Have a look at the VDI and see if it has a PV kernel - pv_kernel = False - if not instance.kernel_id: + + disk_image_type = VMHelper.determine_disk_image_type(instance) + if disk_image_type == ImageType.DISK_RAW: + # Have a look at the VDI and see if it has a PV kernel pv_kernel = VMHelper.lookup_image(self._session, instance.id, vdi_ref) - kernel = None + elif disk_image_type == ImageType.DISK_VHD: + # TODO(sirp): Assuming PV for now; this will need to be + # configurable as Windows will use HVM. + pv_kernel = True + if instance.kernel_id: kernel = VMHelper.fetch_image(self._session, instance.id, instance.kernel_id, user, project, ImageType.KERNEL_RAMDISK) - ramdisk = None + if instance.ramdisk_id: ramdisk = VMHelper.fetch_image(self._session, instance.id, instance.ramdisk_id, user, project, ImageType.KERNEL_RAMDISK) + vm_ref = VMHelper.create_vm(self._session, instance, kernel, ramdisk, pv_kernel) - VMHelper.create_vbd(self._session, vm_ref, vdi_ref, 0, True) + VMHelper.create_vbd(session=self._session, vm_ref=vm_ref, + vdi_ref=vdi_ref, userdevice=0, bootable=True) + + # inject_network_info and create vifs + networks = self.inject_network_info(instance) + self.create_vifs(instance, networks) - if network_ref: - VMHelper.create_vif(self._session, vm_ref, - network_ref, instance.mac_address) LOG.debug(_('Starting VM %s...'), vm_ref) - self._session.call_xenapi('VM.start', vm_ref, False, False) - instance_name = instance.name + self._start(instance, vm_ref) LOG.info(_('Spawning VM %(instance_name)s created %(vm_ref)s.') - % locals()) - + % locals()) + + def _inject_onset_files(): + onset_files = instance.onset_files + if onset_files: + # Check if this is a JSON-encoded string and convert if needed. + if isinstance(onset_files, basestring): + try: + onset_files = json.loads(onset_files) + except ValueError: + LOG.exception(_("Invalid value for onset_files: '%s'") + % onset_files) + onset_files = [] + # Inject any files, if specified + for path, contents in instance.onset_files: + LOG.debug(_("Injecting file path: '%s'") % path) + self.inject_file(instance, path, contents) # NOTE(armando): Do we really need to do this in virt? + # NOTE(tr3buchet): not sure but wherever we do it, we need to call + # reset_network afterwards timer = utils.LoopingCall(f=None) def _wait_for_boot(): try: - state = self.get_info(instance['name'])['state'] + state = self.get_info(instance_name)['state'] db.instance_set_state(context.get_admin_context(), instance['id'], state) if state == power_state.RUNNING: - LOG.debug(_('Instance %s: booted'), instance['name']) + LOG.debug(_('Instance %s: booted'), instance_name) timer.stop() + _inject_onset_files() + return True except Exception, exc: LOG.warn(exc) LOG.exception(_('instance %s: failed to boot'), - instance['name']) + instance_name) db.instance_set_state(context.get_admin_context(), instance['id'], power_state.SHUTDOWN) timer.stop() + return False timer.f = _wait_for_boot + + # call to reset network to configure network from xenstore + self.reset_network(instance) + return timer.start(interval=0.5, now=True) def _get_vm_opaque_ref(self, instance_or_vm): """Refactored out the common code of many methods that receive either a vm name or a vm instance, and want a vm instance in return. """ - vm = None - try: - if instance_or_vm.startswith("OpaqueRef:"): - # Got passed an opaque ref; return it + # if instance_or_vm is a string it must be opaque ref or instance name + if isinstance(instance_or_vm, basestring): + obj = None + try: + # check for opaque ref + obj = self._session.get_xenapi().VM.get_record(instance_or_vm) return instance_or_vm - else: - # Must be the instance name + except self.XenAPI.Failure: + # wasn't an opaque ref, must be an instance name + instance_name = instance_or_vm + + # if instance_or_vm is an int/long it must be instance id + elif isinstance(instance_or_vm, (int, long)): + ctx = context.get_admin_context() + try: + instance_obj = db.instance_get(ctx, instance_or_vm) + instance_name = instance_obj.name + except exception.NotFound: + # The unit tests screw this up, as they use an integer for + # the vm name. I'd fix that up, but that's a matter for + # another bug report. So for now, just try with the passed + # value instance_name = instance_or_vm - except (AttributeError, KeyError): - # Note the the KeyError will only happen with fakes.py - # Not a string; must be an ID or a vm instance - if isinstance(instance_or_vm, (int, long)): - ctx = context.get_admin_context() - try: - instance_obj = db.instance_get(ctx, instance_or_vm) - instance_name = instance_obj.name - except exception.NotFound: - # The unit tests screw this up, as they use an integer for - # the vm name. I'd fix that up, but that's a matter for - # another bug report. So for now, just try with the passed - # value - instance_name = instance_or_vm - else: - instance_name = instance_or_vm.name - vm = VMHelper.lookup(self._session, instance_name) - if vm is None: - raise Exception(_('Instance not present %s') % instance_name) - return vm + + # otherwise instance_or_vm is an instance object + else: + instance_name = instance_or_vm.name + vm_ref = VMHelper.lookup(self._session, instance_name) + if vm_ref is None: + raise exception.NotFound( + _('Instance not present %s') % instance_name) + return vm_ref + + def _acquire_bootlock(self, vm): + """Prevent an instance from booting""" + self._session.call_xenapi( + "VM.set_blocked_operations", + vm, + {"start": ""}) + + def _release_bootlock(self, vm): + """Allow an instance to boot""" + self._session.call_xenapi( + "VM.remove_from_blocked_operations", + vm, + "start") def snapshot(self, instance, image_id): - """ Create snapshot from a running VM instance + """Create snapshot from a running VM instance :param instance: instance to be snapshotted :param image_id: id of image to upload to @@ -186,7 +261,20 @@ class VMOps(object): that will bundle the VHDs together and then push the bundle into Glance. """ + template_vm_ref = None + try: + template_vm_ref, template_vdi_uuids = self._get_snapshot(instance) + # call plugin to ship snapshot off to glance + VMHelper.upload_image( + self._session, instance.id, template_vdi_uuids, image_id) + finally: + if template_vm_ref: + self._destroy(instance, template_vm_ref, + shutdown=False, destroy_kernel_ramdisk=False) + logging.debug(_("Finished snapshot and upload for VM %s"), instance) + + def _get_snapshot(self, instance): #TODO(sirp): Add quiesce and VSS locking support when Windows support # is added @@ -197,25 +285,95 @@ class VMOps(object): try: template_vm_ref, template_vdi_uuids = VMHelper.create_snapshot( self._session, instance.id, vm_ref, label) + return template_vm_ref, template_vdi_uuids except self.XenAPI.Failure, exc: logging.error(_("Unable to Snapshot %(vm_ref)s: %(exc)s") % locals()) return + def migrate_disk_and_power_off(self, instance, dest): + """Copies a VHD from one host machine to another + + :param instance: the instance that owns the VHD in question + :param dest: the destination host machine + :param disk_type: values are 'primary' or 'cow' + """ + vm_ref = VMHelper.lookup(self._session, instance.name) + + # The primary VDI becomes the COW after the snapshot, and we can + # identify it via the VBD. The base copy is the parent_uuid returned + # from the snapshot creation + + base_copy_uuid = cow_uuid = None + template_vdi_uuids = template_vm_ref = None try: - # call plugin to ship snapshot off to glance - VMHelper.upload_image( - self._session, instance.id, template_vdi_uuids, image_id) + # transfer the base copy + template_vm_ref, template_vdi_uuids = self._get_snapshot(instance) + base_copy_uuid = template_vdi_uuids[1] + vdi_ref, vm_vdi_rec = \ + VMHelper.get_vdi_for_vm_safely(self._session, vm_ref) + cow_uuid = vm_vdi_rec['uuid'] + + params = {'host': dest, + 'vdi_uuid': base_copy_uuid, + 'instance_id': instance.id, + 'sr_path': VMHelper.get_sr_path(self._session)} + + task = self._session.async_call_plugin('migration', 'transfer_vhd', + {'params': pickle.dumps(params)}) + self._session.wait_for_task(task, instance.id) + + # Now power down the instance and transfer the COW VHD + self._shutdown(instance, vm_ref, method='clean') + + params = {'host': dest, + 'vdi_uuid': cow_uuid, + 'instance_id': instance.id, + 'sr_path': VMHelper.get_sr_path(self._session), } + + task = self._session.async_call_plugin('migration', 'transfer_vhd', + {'params': pickle.dumps(params)}) + self._session.wait_for_task(task, instance.id) + finally: - self._destroy(instance, template_vm_ref, shutdown=False) + if template_vm_ref: + self._destroy(instance, template_vm_ref, + shutdown=False, destroy_kernel_ramdisk=False) - logging.debug(_("Finished snapshot and upload for VM %s"), instance) + # TODO(mdietz): we could also consider renaming these to something + # sensible so we don't need to blindly pass around dictionaries + return {'base_copy': base_copy_uuid, 'cow': cow_uuid} + + def attach_disk(self, instance, base_copy_uuid, cow_uuid): + """Links the base copy VHD to the COW via the XAPI plugin""" + vm_ref = VMHelper.lookup(self._session, instance.name) + new_base_copy_uuid = str(uuid.uuid4()) + new_cow_uuid = str(uuid.uuid4()) + params = {'instance_id': instance.id, + 'old_base_copy_uuid': base_copy_uuid, + 'old_cow_uuid': cow_uuid, + 'new_base_copy_uuid': new_base_copy_uuid, + 'new_cow_uuid': new_cow_uuid, + 'sr_path': VMHelper.get_sr_path(self._session), } + + task = self._session.async_call_plugin('migration', + 'move_vhds_into_sr', {'params': pickle.dumps(params)}) + self._session.wait_for_task(task, instance.id) + + # Now we rescan the SR so we find the VHDs + VMHelper.scan_default_sr(self._session) + + return new_cow_uuid + + def resize(self, instance, flavor): + """Resize a running instance by changing it's RAM and disk size """ + raise NotImplementedError() def reboot(self, instance): """Reboot VM instance""" vm = self._get_vm_opaque_ref(instance) task = self._session.call_xenapi('Async.VM.clean_reboot', vm) - self._session.wait_for_task(instance.id, task) + self._session.wait_for_task(task, instance.id) def set_admin_password(self, instance, new_pass): """Set the root/admin password on the VM instance. This is done via @@ -255,22 +413,58 @@ class VMOps(object): raise RuntimeError(resp_dict['message']) return resp_dict['message'] - def _shutdown(self, instance, vm): - """Shutdown an instance """ + def inject_file(self, instance, b64_path, b64_contents): + """Write a file to the VM instance. The path to which it is to be + written and the contents of the file need to be supplied; both should + be base64-encoded to prevent errors with non-ASCII characters being + transmitted. If the agent does not support file injection, or the user + has disabled it, a NotImplementedError will be raised. + """ + # Files/paths *should* be base64-encoded at this point, but + # double-check to make sure. + b64_path = utils.ensure_b64_encoding(b64_path) + b64_contents = utils.ensure_b64_encoding(b64_contents) + + # Need to uniquely identify this request. + transaction_id = str(uuid.uuid4()) + args = {'id': transaction_id, 'b64_path': b64_path, + 'b64_contents': b64_contents} + # If the agent doesn't support file injection, a NotImplementedError + # will be raised with the appropriate message. + resp = self._make_agent_call('inject_file', instance, '', args) + resp_dict = json.loads(resp) + if resp_dict['returncode'] != '0': + # There was some other sort of error; the message will contain + # a description of the error. + raise RuntimeError(resp_dict['message']) + return resp_dict['message'] + + def _shutdown(self, instance, vm, hard=True): + """Shutdown an instance""" state = self.get_info(instance['name'])['state'] if state == power_state.SHUTDOWN: LOG.warn(_("VM %(vm)s already halted, skipping shutdown...") % locals()) return + instance_id = instance.id + LOG.debug(_("Shutting down VM for Instance %(instance_id)s") + % locals()) try: - task = self._session.call_xenapi('Async.VM.hard_shutdown', vm) - self._session.wait_for_task(instance.id, task) + task = None + if hard: + task = self._session.call_xenapi("Async.VM.hard_shutdown", vm) + else: + task = self._session.call_xenapi('Async.VM.clean_shutdown', vm) + self._session.wait_for_task(task, instance.id) except self.XenAPI.Failure, exc: LOG.exception(exc) def _destroy_vdis(self, instance, vm): """Destroys all VDIs associated with a VM """ + instance_id = instance.id + LOG.debug(_("Destroying VDIs for Instance %(instance_id)s") + % locals()) vdis = VMHelper.lookup_vm_vdis(self._session, vm) if not vdis: @@ -279,18 +473,60 @@ class VMOps(object): for vdi in vdis: try: task = self._session.call_xenapi('Async.VDI.destroy', vdi) - self._session.wait_for_task(instance.id, task) + self._session.wait_for_task(task, instance.id) except self.XenAPI.Failure, exc: LOG.exception(exc) + def _destroy_kernel_ramdisk(self, instance, vm): + """ + Three situations can occur: + + 1. We have neither a ramdisk nor a kernel, in which case we are a + RAW image and can omit this step + + 2. We have one or the other, in which case, we should flag as an + error + + 3. We have both, in which case we safely remove both the kernel + and the ramdisk. + """ + instance_id = instance.id + if not instance.kernel_id and not instance.ramdisk_id: + # 1. No kernel or ramdisk + LOG.debug(_("Instance %(instance_id)s using RAW or VHD, " + "skipping kernel and ramdisk deletion") % locals()) + return + + if not (instance.kernel_id and instance.ramdisk_id): + # 2. We only have kernel xor ramdisk + raise exception.NotFound( + _("Instance %(instance_id)s has a kernel or ramdisk but not " + "both" % locals())) + + # 3. We have both kernel and ramdisk + (kernel, ramdisk) = VMHelper.lookup_kernel_ramdisk( + self._session, vm) + + LOG.debug(_("Removing kernel/ramdisk files")) + + args = {'kernel-file': kernel, 'ramdisk-file': ramdisk} + task = self._session.async_call_plugin( + 'glance', 'remove_kernel_ramdisk', args) + self._session.wait_for_task(task, instance.id) + + LOG.debug(_("kernel/ramdisk files removed")) + def _destroy_vm(self, instance, vm): """Destroys a VM record """ + instance_id = instance.id try: task = self._session.call_xenapi('Async.VM.destroy', vm) - self._session.wait_for_task(instance.id, task) + self._session.wait_for_task(task, instance_id) except self.XenAPI.Failure, exc: LOG.exception(exc) + LOG.debug(_("Instance %(instance_id)s VM destroyed") % locals()) + def destroy(self, instance): """ Destroy VM instance @@ -298,32 +534,37 @@ class VMOps(object): This is the method exposed by xenapi_conn.destroy(). The rest of the destroy_* methods are internal. """ + instance_id = instance.id + LOG.info(_("Destroying VM for Instance %(instance_id)s") % locals()) vm = VMHelper.lookup(self._session, instance.name) return self._destroy(instance, vm, shutdown=True) - def _destroy(self, instance, vm, shutdown=True): + def _destroy(self, instance, vm, shutdown=True, + destroy_kernel_ramdisk=True): """ Destroys VM instance by performing: - 1. A shutdown if requested - 2. Destroying associated VDIs - 3. Destroying that actual VM record + 1. A shutdown if requested + 2. Destroying associated VDIs + 3. Destroying kernel and ramdisk files (if necessary) + 4. Destroying that actual VM record """ if vm is None: - # Don't complain, just return. This lets us clean up instances - # that have already disappeared from the underlying platform. + LOG.warning(_("VM is not present, skipping destroy...")) return if shutdown: self._shutdown(instance, vm) self._destroy_vdis(instance, vm) + if destroy_kernel_ramdisk: + self._destroy_kernel_ramdisk(instance, vm) self._destroy_vm(instance, vm) def _wait_with_callback(self, instance_id, task, callback): ret = None try: - ret = self._session.wait_for_task(instance_id, task) + ret = self._session.wait_for_task(task, instance_id) except self.XenAPI.Failure, exc: LOG.exception(exc) callback(ret) @@ -352,6 +593,78 @@ class VMOps(object): task = self._session.call_xenapi('Async.VM.resume', vm, False, True) self._wait_with_callback(instance.id, task, callback) + def rescue(self, instance, callback): + """Rescue the specified instance + - shutdown the instance VM + - set 'bootlock' to prevent the instance from starting in rescue + - spawn a rescue VM (the vm name-label will be instance-N-rescue) + + """ + rescue_vm = VMHelper.lookup(self._session, instance.name + "-rescue") + if rescue_vm: + raise RuntimeError(_( + "Instance is already in Rescue Mode: %s" % instance.name)) + + vm = self._get_vm_opaque_ref(instance) + self._shutdown(instance, vm) + self._acquire_bootlock(vm) + + instance._rescue = True + self.spawn(instance) + rescue_vm = self._get_vm_opaque_ref(instance) + + vbd = self._session.get_xenapi().VM.get_VBDs(vm)[0] + vdi_ref = self._session.get_xenapi().VBD.get_record(vbd)["VDI"] + vbd_ref = VMHelper.create_vbd( + self._session, + rescue_vm, + vdi_ref, + 1, + False) + + self._session.call_xenapi("Async.VBD.plug", vbd_ref) + + def unrescue(self, instance, callback): + """Unrescue the specified instance + - unplug the instance VM's disk from the rescue VM + - teardown the rescue VM + - release the bootlock to allow the instance VM to start + + """ + rescue_vm = VMHelper.lookup(self._session, instance.name + "-rescue") + + if not rescue_vm: + raise exception.NotFound(_( + "Instance is not in Rescue Mode: %s" % instance.name)) + + original_vm = self._get_vm_opaque_ref(instance) + vbds = self._session.get_xenapi().VM.get_VBDs(rescue_vm) + + instance._rescue = False + + for vbd_ref in vbds: + vbd = self._session.get_xenapi().VBD.get_record(vbd_ref) + if vbd["userdevice"] == "1": + VMHelper.unplug_vbd(self._session, vbd_ref) + VMHelper.destroy_vbd(self._session, vbd_ref) + + task1 = self._session.call_xenapi("Async.VM.hard_shutdown", rescue_vm) + self._session.wait_for_task(task1, instance.id) + + vdis = VMHelper.lookup_vm_vdis(self._session, rescue_vm) + for vdi in vdis: + try: + task = self._session.call_xenapi('Async.VDI.destroy', vdi) + self._session.wait_for_task(task, instance.id) + except self.XenAPI.Failure: + continue + + task2 = self._session.call_xenapi('Async.VM.destroy', rescue_vm) + self._session.wait_for_task(task2, instance.id) + + self._release_bootlock(original_vm) + self._start(instance, original_vm) + def get_info(self, instance): """Return data about VM instance""" vm = self._get_vm_opaque_ref(instance) @@ -374,6 +687,102 @@ class VMOps(object): # TODO: implement this! return 'http://fakeajaxconsole/fake_url' + def inject_network_info(self, instance): + """ + Generate the network info and make calls to place it into the + xenstore and the xenstore param list + + """ + # TODO(tr3buchet) - remove comment in multi-nic + # I've decided to go ahead and consider multiple IPs and networks + # at this stage even though they aren't implemented because these will + # be needed for multi-nic and there was no sense writing it for single + # network/single IP and then having to turn around and re-write it + vm_opaque_ref = self._get_vm_opaque_ref(instance.id) + logging.debug(_("injecting network info to xenstore for vm: |%s|"), + vm_opaque_ref) + admin_context = context.get_admin_context() + IPs = db.fixed_ip_get_all_by_instance(admin_context, instance['id']) + networks = db.network_get_all_by_instance(admin_context, + instance['id']) + for network in networks: + network_IPs = [ip for ip in IPs if ip.network_id == network.id] + + def ip_dict(ip): + return { + "ip": ip.address, + "netmask": network["netmask"], + "enabled": "1"} + + def ip6_dict(ip6): + return { + "ip": ip6.addressV6, + "netmask": ip6.netmaskV6, + "gateway": ip6.gatewayV6, + "enabled": "1"} + + mac_id = instance.mac_address.replace(':', '') + location = 'vm-data/networking/%s' % mac_id + mapping = { + 'label': network['label'], + 'gateway': network['gateway'], + 'mac': instance.mac_address, + 'dns': [network['dns']], + 'ips': [ip_dict(ip) for ip in network_IPs], + 'ip6s': [ip6_dict(ip) for ip in network_IPs]} + + self.write_to_param_xenstore(vm_opaque_ref, {location: mapping}) + + try: + self.write_to_xenstore(vm_opaque_ref, location, + mapping['location']) + except KeyError: + # catch KeyError for domid if instance isn't running + pass + + return networks + + def create_vifs(self, instance, networks=None): + """ + Creates vifs for an instance + + """ + vm_opaque_ref = self._get_vm_opaque_ref(instance.id) + logging.debug(_("creating vif(s) for vm: |%s|"), vm_opaque_ref) + if networks is None: + networks = db.network_get_all_by_instance(admin_context, + instance['id']) + # TODO(tr3buchet) - remove comment in multi-nic + # this bit here about creating the vifs will be updated + # in multi-nic to handle multiple IPs on the same network + # and multiple networks + # for now it works as there is only one of each + for network in networks: + bridge = network['bridge'] + network_ref = \ + NetworkHelper.find_network_with_bridge(self._session, bridge) + + if network_ref: + try: + device = "1" if instance._rescue else "0" + except AttributeError: + device = "0" + + VMHelper.create_vif( + self._session, + vm_opaque_ref, + network_ref, + instance.mac_address, + device) + + def reset_network(self, instance): + """ + Creates uuid arg to pass to make_agent_call and calls it. + + """ + args = {'id': str(uuid.uuid4())} + resp = self._make_agent_call('resetnetwork', instance, '', args) + def list_from_xenstore(self, vm, path): """Runs the xenstore-ls command to get a listing of all records from 'path' downward. Returns a dict with the sub-paths as keys, @@ -434,7 +843,7 @@ class VMOps(object): args.update(addl_args) try: task = self._session.async_call_plugin(plugin, method, args) - ret = self._session.wait_for_task(instance_id, task) + ret = self._session.wait_for_task(task, instance_id) except self.XenAPI.Failure, e: ret = None err_trace = e.details[-1] @@ -443,6 +852,11 @@ class VMOps(object): if 'TIMEOUT:' in err_msg: LOG.error(_('TIMEOUT: The call to %(method)s timed out. ' 'VM id=%(instance_id)s; args=%(strargs)s') % locals()) + elif 'NOT IMPLEMENTED:' in err_msg: + LOG.error(_('NOT IMPLEMENTED: The call to %(method)s is not' + ' supported by the agent. VM id=%(instance_id)s;' + ' args=%(strargs)s') % locals()) + raise NotImplementedError(err_msg) else: LOG.error(_('The call to %(method)s returned an error: %(e)s. ' 'VM id=%(instance_id)s; args=%(strargs)s') % locals()) diff --git a/nova/virt/xenapi/volumeops.py b/nova/virt/xenapi/volumeops.py index d89a6f995..757ecf5ad 100644 --- a/nova/virt/xenapi/volumeops.py +++ b/nova/virt/xenapi/volumeops.py @@ -83,7 +83,7 @@ class VolumeOps(object): try: task = self._session.call_xenapi('Async.VBD.plug', vbd_ref) - self._session.wait_for_task(vol_rec['deviceNumber'], task) + self._session.wait_for_task(task, vol_rec['deviceNumber']) except self.XenAPI.Failure, exc: LOG.exception(exc) VolumeHelper.destroy_iscsi_storage(self._session, diff --git a/nova/virt/xenapi_conn.py b/nova/virt/xenapi_conn.py index a0b0499b8..b63a5f8c3 100644 --- a/nova/virt/xenapi_conn.py +++ b/nova/virt/xenapi_conn.py @@ -100,6 +100,8 @@ flags.DEFINE_integer('xenapi_vhd_coalesce_max_attempts', 5, 'Max number of times to poll for VHD to coalesce.' ' Used only if connection_type=xenapi.') +flags.DEFINE_string('xenapi_sr_base_path', '/var/run/sr-mount', + 'Base path to the storage repository') flags.DEFINE_string('target_host', None, 'iSCSI Target Host') @@ -156,10 +158,20 @@ class XenAPIConnection(object): """Create VM instance""" self._vmops.spawn(instance) + def finish_resize(self, instance, disk_info): + """Completes a resize, turning on the migrated instance""" + vdi_uuid = self._vmops.attach_disk(instance, disk_info['base_copy'], + disk_info['cow']) + self._vmops._spawn_with_disk(instance, vdi_uuid) + def snapshot(self, instance, image_id): """ Create snapshot from a running VM instance """ self._vmops.snapshot(instance, image_id) + def resize(self, instance, flavor): + """Resize a VM instance""" + raise NotImplementedError() + def reboot(self, instance): """Reboot VM instance""" self._vmops.reboot(instance) @@ -168,6 +180,12 @@ class XenAPIConnection(object): """Set the root/admin password on the VM instance""" self._vmops.set_admin_password(instance, new_pass) + def inject_file(self, instance, b64_path, b64_contents): + """Create a file on the VM instance. The file path and contents + should be base64-encoded. + """ + self._vmops.inject_file(instance, b64_path, b64_contents) + def destroy(self, instance): """Destroy VM instance""" self._vmops.destroy(instance) @@ -180,6 +198,11 @@ class XenAPIConnection(object): """Unpause paused VM instance""" self._vmops.unpause(instance, callback) + def migrate_disk_and_power_off(self, instance, dest): + """Transfers the VHD of a running instance to another host, then shuts + off the instance copies over the COW disk""" + return self._vmops.migrate_disk_and_power_off(instance, dest) + def suspend(self, instance, callback): """suspend the specified instance""" self._vmops.suspend(instance, callback) @@ -188,6 +211,22 @@ class XenAPIConnection(object): """resume the specified instance""" self._vmops.resume(instance, callback) + def rescue(self, instance, callback): + """Rescue the specified instance""" + self._vmops.rescue(instance, callback) + + def unrescue(self, instance, callback): + """Unrescue the specified instance""" + self._vmops.unrescue(instance, callback) + + def reset_network(self, instance): + """reset networking for specified instance""" + self._vmops.reset_network(instance) + + def inject_network_info(self, instance): + """inject network info for specified instance""" + self._vmops.inject_network_info(instance) + def get_info(self, instance_id): """Return data about VM instance""" return self._vmops.get_info(instance_id) @@ -204,6 +243,10 @@ class XenAPIConnection(object): """Return link to instance's ajax console""" return self._vmops.get_ajax_console(instance) + def get_host_ip_addr(self): + xs_url = urlparse.urlparse(FLAGS.xenapi_connection_url) + return xs_url.netloc + def attach_volume(self, instance_name, device_path, mountpoint): """Attach volume storage to VM instance""" return self._volumeops.attach_volume(instance_name, @@ -263,7 +306,7 @@ class XenAPISession(object): self._session.xenapi.Async.host.call_plugin, self.get_xenapi_host(), plugin, fn, args) - def wait_for_task(self, id, task): + def wait_for_task(self, task, id=None): """Return the result of the given task. The task is polled until it completes. Not re-entrant.""" done = event.Event() @@ -290,10 +333,11 @@ class XenAPISession(object): try: name = self._session.xenapi.task.get_name_label(task) status = self._session.xenapi.task.get_status(task) - action = dict( - instance_id=int(id), - action=name[0:255], # Ensure action is never > 255 - error=None) + if id: + action = dict( + instance_id=int(id), + action=name[0:255], # Ensure action is never > 255 + error=None) if status == "pending": return elif status == "success": @@ -307,7 +351,9 @@ class XenAPISession(object): LOG.warn(_("Task [%(name)s] %(task)s status:" " %(status)s %(error_info)s") % locals()) done.send_exception(self.XenAPI.Failure(error_info)) - db.instance_action_create(context.get_admin_context(), action) + + if id: + db.instance_action_create(context.get_admin_context(), action) except self.XenAPI.Failure, exc: LOG.warn(exc) done.send_exception(*sys.exc_info()) diff --git a/nova/volume/api.py b/nova/volume/api.py index 478c83486..2f4494845 100644 --- a/nova/volume/api.py +++ b/nova/volume/api.py @@ -49,7 +49,7 @@ class API(base.Base): options = { 'size': size, - 'user_id': context.user.id, + 'user_id': context.user_id, 'project_id': context.project_id, 'availability_zone': FLAGS.storage_availability_zone, 'status': "creating", @@ -85,7 +85,7 @@ class API(base.Base): return self.db.volume_get(context, volume_id) def get_all(self, context): - if context.user.is_admin(): + if context.is_admin: return self.db.volume_get_all(context) return self.db.volume_get_all_by_project(context, context.project_id) diff --git a/nova/volume/driver.py b/nova/volume/driver.py index da7307733..45cc800e7 100644 --- a/nova/volume/driver.py +++ b/nova/volume/driver.py @@ -21,6 +21,7 @@ Drivers for volumes. """ import time +import os from nova import exception from nova import flags @@ -36,6 +37,8 @@ flags.DEFINE_string('aoe_eth_dev', 'eth0', 'Which device to export the volumes on') flags.DEFINE_string('num_shell_tries', 3, 'number of times to attempt to run flakey shell commands') +flags.DEFINE_string('num_iscsi_scan_tries', 3, + 'number of times to rescan iSCSI target to find volume') flags.DEFINE_integer('num_shelves', 100, 'Number of vblade shelves') @@ -62,14 +65,14 @@ class VolumeDriver(object): self._execute = execute self._sync_exec = sync_exec - def _try_execute(self, command): + def _try_execute(self, *command): # NOTE(vish): Volume commands can partially fail due to timing, but # running them a second time on failure will usually # recover nicely. tries = 0 while True: try: - self._execute(command) + self._execute(*command) return True except exception.ProcessExecutionError: tries = tries + 1 @@ -81,34 +84,35 @@ class VolumeDriver(object): def check_for_setup_error(self): """Returns an error if prerequisites aren't met""" - out, err = self._execute("sudo vgs --noheadings -o name") + out, err = self._execute('sudo', 'vgs', '--noheadings', '-o', 'name') volume_groups = out.split() if not FLAGS.volume_group in volume_groups: raise exception.Error(_("volume group %s doesn't exist") % FLAGS.volume_group) def create_volume(self, volume): - """Creates a logical volume.""" + """Creates a logical volume. Can optionally return a Dictionary of + changes to the volume object to be persisted.""" if int(volume['size']) == 0: sizestr = '100M' else: sizestr = '%sG' % volume['size'] - self._try_execute("sudo lvcreate -L %s -n %s %s" % - (sizestr, + self._try_execute('sudo', 'lvcreate', '-L', sizestr, '-n', volume['name'], - FLAGS.volume_group)) + FLAGS.volume_group) def delete_volume(self, volume): """Deletes a logical volume.""" try: - self._try_execute("sudo lvdisplay %s/%s" % + self._try_execute('sudo', 'lvdisplay', + '%s/%s' % (FLAGS.volume_group, volume['name'])) except Exception as e: # If the volume isn't present, then don't attempt to delete return True - self._try_execute("sudo lvremove -f %s/%s" % + self._try_execute('sudo', 'lvremove', '-f', "%s/%s" % (FLAGS.volume_group, volume['name'])) @@ -123,7 +127,8 @@ class VolumeDriver(object): raise NotImplementedError() def create_export(self, context, volume): - """Exports the volume.""" + """Exports the volume. Can optionally return a Dictionary of changes + to the volume object to be persisted.""" raise NotImplementedError() def remove_export(self, context, volume): @@ -163,12 +168,13 @@ class AOEDriver(VolumeDriver): blade_id) = self.db.volume_allocate_shelf_and_blade(context, volume['id']) self._try_execute( - "sudo vblade-persist setup %s %s %s /dev/%s/%s" % - (shelf_id, + 'sudo', 'vblade-persist', 'setup', + shelf_id, blade_id, FLAGS.aoe_eth_dev, - FLAGS.volume_group, - volume['name'])) + "/dev/%s/%s" % + (FLAGS.volume_group, + volume['name'])) # NOTE(vish): The standard _try_execute does not work here # because these methods throw errors if other # volumes on this host are in the process of @@ -177,9 +183,9 @@ class AOEDriver(VolumeDriver): # just wait a bit for the current volume to # be ready and ignore any errors. time.sleep(2) - self._execute("sudo vblade-persist auto all", + self._execute('sudo', 'vblade-persist', 'auto', 'all', check_exit_code=False) - self._execute("sudo vblade-persist start all", + self._execute('sudo', 'vblade-persist', 'start', 'all', check_exit_code=False) def remove_export(self, context, volume): @@ -187,15 +193,15 @@ class AOEDriver(VolumeDriver): (shelf_id, blade_id) = self.db.volume_get_shelf_and_blade(context, volume['id']) - self._try_execute("sudo vblade-persist stop %s %s" % - (shelf_id, blade_id)) - self._try_execute("sudo vblade-persist destroy %s %s" % - (shelf_id, blade_id)) + self._try_execute('sudo', 'vblade-persist', 'stop', + shelf_id, blade_id) + self._try_execute('sudo', 'vblade-persist', 'destroy', + shelf_id, blade_id) def discover_volume(self, _volume): """Discover volume on a remote host.""" - self._execute("sudo aoe-discover") - self._execute("sudo aoe-stat", check_exit_code=False) + self._execute('sudo', 'aoe-discover') + self._execute('sudo', 'aoe-stat', check_exit_code=False) def undiscover_volume(self, _volume): """Undiscover volume on a remote host.""" @@ -222,7 +228,18 @@ class FakeAOEDriver(AOEDriver): class ISCSIDriver(VolumeDriver): - """Executes commands relating to ISCSI volumes.""" + """Executes commands relating to ISCSI volumes. + + We make use of model provider properties as follows: + + :provider_location: if present, contains the iSCSI target information + in the same format as an ietadm discovery + i.e. '<ip>:<port>,<portal> <target IQN>' + + :provider_auth: if present, contains a space-separated triple: + '<auth method> <auth username> <auth password>'. + `CHAP` is the only auth_method in use at the moment. + """ def ensure_export(self, context, volume): """Synchronously recreates an export for a logical volume.""" @@ -236,13 +253,16 @@ class ISCSIDriver(VolumeDriver): iscsi_name = "%s%s" % (FLAGS.iscsi_target_prefix, volume['name']) volume_path = "/dev/%s/%s" % (FLAGS.volume_group, volume['name']) - self._sync_exec("sudo ietadm --op new " - "--tid=%s --params Name=%s" % - (iscsi_target, iscsi_name), + self._sync_exec('sudo', 'ietadm', '--op', 'new', + "--tid=%s" % iscsi_target, + '--params', + "Name=%s" % iscsi_name, check_exit_code=False) - self._sync_exec("sudo ietadm --op new --tid=%s " - "--lun=0 --params Path=%s,Type=fileio" % - (iscsi_target, volume_path), + self._sync_exec('sudo', 'ietadm', '--op', 'new', + "--tid=%s" % iscsi_target, + '--lun=0', + '--params', + "Path=%s,Type=fileio" % volume_path, check_exit_code=False) def _ensure_iscsi_targets(self, context, host): @@ -263,12 +283,13 @@ class ISCSIDriver(VolumeDriver): volume['host']) iscsi_name = "%s%s" % (FLAGS.iscsi_target_prefix, volume['name']) volume_path = "/dev/%s/%s" % (FLAGS.volume_group, volume['name']) - self._execute("sudo ietadm --op new " - "--tid=%s --params Name=%s" % + self._execute('sudo', 'ietadm', '--op', 'new', + '--tid=%s --params Name=%s' % (iscsi_target, iscsi_name)) - self._execute("sudo ietadm --op new --tid=%s " - "--lun=0 --params Path=%s,Type=fileio" % - (iscsi_target, volume_path)) + self._execute('sudo', 'ietadm', '--op', 'new', + '--tid=%s' % iscsi_target, + '--lun=0', '--params', + 'Path=%s,Type=fileio' % volume_path) def remove_export(self, context, volume): """Removes an export for a logical volume.""" @@ -283,51 +304,162 @@ class ISCSIDriver(VolumeDriver): try: # ietadm show will exit with an error # this export has already been removed - self._execute("sudo ietadm --op show --tid=%s " % iscsi_target) + self._execute('sudo', 'ietadm', '--op', 'show', + '--tid=%s' % iscsi_target) except Exception as e: LOG.info(_("Skipping remove_export. No iscsi_target " + "is presently exported for volume: %d"), volume['id']) return - self._execute("sudo ietadm --op delete --tid=%s " - "--lun=0" % iscsi_target) - self._execute("sudo ietadm --op delete --tid=%s" % - iscsi_target) + self._execute('sudo', 'ietadm', '--op', 'delete', + '--tid=%s' % iscsi_target, + '--lun=0') + self._execute('sudo', 'ietadm', '--op', 'delete', + '--tid=%s' % iscsi_target) + + def _do_iscsi_discovery(self, volume): + #TODO(justinsb): Deprecate discovery and use stored info + #NOTE(justinsb): Discovery won't work with CHAP-secured targets (?) + LOG.warn(_("ISCSI provider_location not stored, using discovery")) + + volume_name = volume['name'] - def _get_name_and_portal(self, volume_name, host): - """Gets iscsi name and portal from volume name and host.""" - (out, _err) = self._execute("sudo iscsiadm -m discovery -t " - "sendtargets -p %s" % host) + (out, _err) = self._execute('sudo', 'iscsiadm', '-m', 'discovery', + '-t', 'sendtargets', '-p', volume['host']) for target in out.splitlines(): if FLAGS.iscsi_ip_prefix in target and volume_name in target: - (location, _sep, iscsi_name) = target.partition(" ") - break - iscsi_portal = location.split(",")[0] - return (iscsi_name, iscsi_portal) + return target + return None + + def _get_iscsi_properties(self, volume): + """Gets iscsi configuration + + We ideally get saved information in the volume entity, but fall back + to discovery if need be. Discovery may be completely removed in future + The properties are: + + :target_discovered: boolean indicating whether discovery was used + + :target_iqn: the IQN of the iSCSI target + + :target_portal: the portal of the iSCSI target + + :auth_method:, :auth_username:, :auth_password: + + the authentication details. Right now, either auth_method is not + present meaning no authentication, or auth_method == `CHAP` + meaning use CHAP with the specified credentials. + """ + + properties = {} + + location = volume['provider_location'] + + if location: + # provider_location is the same format as iSCSI discovery output + properties['target_discovered'] = False + else: + location = self._do_iscsi_discovery(volume) + + if not location: + raise exception.Error(_("Could not find iSCSI export " + " for volume %s") % + (volume['name'])) + + LOG.debug(_("ISCSI Discovery: Found %s") % (location)) + properties['target_discovered'] = True + + (iscsi_target, _sep, iscsi_name) = location.partition(" ") + + iscsi_portal = iscsi_target.split(",")[0] + + properties['target_iqn'] = iscsi_name + properties['target_portal'] = iscsi_portal + + auth = volume['provider_auth'] + + if auth: + (auth_method, auth_username, auth_secret) = auth.split() + + properties['auth_method'] = auth_method + properties['auth_username'] = auth_username + properties['auth_password'] = auth_secret + + return properties + + def _run_iscsiadm(self, iscsi_properties, iscsi_command): + command = ("sudo iscsiadm -m node -T %s -p %s %s" % + (iscsi_properties['target_iqn'], + iscsi_properties['target_portal'], + iscsi_command)) + (out, err) = self._execute(command) + LOG.debug("iscsiadm %s: stdout=%s stderr=%s" % + (iscsi_command, out, err)) + return (out, err) + + def _iscsiadm_update(self, iscsi_properties, property_key, property_value): + iscsi_command = ("--op update -n %s -v %s" % + (property_key, property_value)) + return self._run_iscsiadm(iscsi_properties, iscsi_command) def discover_volume(self, volume): """Discover volume on a remote host.""" - iscsi_name, iscsi_portal = self._get_name_and_portal(volume['name'], - volume['host']) - self._execute("sudo iscsiadm -m node -T %s -p %s --login" % - (iscsi_name, iscsi_portal)) - self._execute("sudo iscsiadm -m node -T %s -p %s --op update " - "-n node.startup -v automatic" % - (iscsi_name, iscsi_portal)) - return "/dev/disk/by-path/ip-%s-iscsi-%s-lun-0" % (iscsi_portal, - iscsi_name) + iscsi_properties = self._get_iscsi_properties(volume) + + if not iscsi_properties['target_discovered']: + self._run_iscsiadm(iscsi_properties, "--op new") + + if iscsi_properties.get('auth_method'): + self._iscsiadm_update(iscsi_properties, + "node.session.auth.authmethod", + iscsi_properties['auth_method']) + self._iscsiadm_update(iscsi_properties, + "node.session.auth.username", + iscsi_properties['auth_username']) + self._iscsiadm_update(iscsi_properties, + "node.session.auth.password", + iscsi_properties['auth_password']) + + self._run_iscsiadm(iscsi_properties, "--login") + + self._iscsiadm_update(iscsi_properties, "node.startup", "automatic") + + mount_device = ("/dev/disk/by-path/ip-%s-iscsi-%s-lun-0" % + (iscsi_properties['target_portal'], + iscsi_properties['target_iqn'])) + + # The /dev/disk/by-path/... node is not always present immediately + # TODO(justinsb): This retry-with-delay is a pattern, move to utils? + tries = 0 + while not os.path.exists(mount_device): + if tries >= FLAGS.num_iscsi_scan_tries: + raise exception.Error(_("iSCSI device not found at %s") % + (mount_device)) + + LOG.warn(_("ISCSI volume not yet found at: %(mount_device)s. " + "Will rescan & retry. Try number: %(tries)s") % + locals()) + + # The rescan isn't documented as being necessary(?), but it helps + self._run_iscsiadm(iscsi_properties, "--rescan") + + tries = tries + 1 + if not os.path.exists(mount_device): + time.sleep(tries ** 2) + + if tries != 0: + LOG.debug(_("Found iSCSI node %(mount_device)s " + "(after %(tries)s rescans)") % + locals()) + + return mount_device def undiscover_volume(self, volume): """Undiscover volume on a remote host.""" - iscsi_name, iscsi_portal = self._get_name_and_portal(volume['name'], - volume['host']) - self._execute("sudo iscsiadm -m node -T %s -p %s --op update " - "-n node.startup -v manual" % - (iscsi_name, iscsi_portal)) - self._execute("sudo iscsiadm -m node -T %s -p %s --logout " % - (iscsi_name, iscsi_portal)) - self._execute("sudo iscsiadm -m node --op delete " - "--targetname %s" % iscsi_name) + iscsi_properties = self._get_iscsi_properties(volume) + self._iscsiadm_update(iscsi_properties, "node.startup", "manual") + self._run_iscsiadm(iscsi_properties, "--logout") + self._run_iscsiadm(iscsi_properties, "--op delete") class FakeISCSIDriver(ISCSIDriver): @@ -353,7 +485,7 @@ class RBDDriver(VolumeDriver): def check_for_setup_error(self): """Returns an error if prerequisites aren't met""" - (stdout, stderr) = self._execute("rados lspools") + (stdout, stderr) = self._execute('rados', 'lspools') pools = stdout.split("\n") if not FLAGS.rbd_pool in pools: raise exception.Error(_("rbd has no pool %s") % @@ -365,16 +497,13 @@ class RBDDriver(VolumeDriver): size = 100 else: size = int(volume['size']) * 1024 - self._try_execute("rbd --pool %s --size %d create %s" % - (FLAGS.rbd_pool, - size, - volume['name'])) + self._try_execute('rbd', '--pool', FLAGS.rbd_pool, + '--size', size, 'create', volume['name']) def delete_volume(self, volume): """Deletes a logical volume.""" - self._try_execute("rbd --pool %s rm %s" % - (FLAGS.rbd_pool, - volume['name'])) + self._try_execute('rbd', '--pool', FLAGS.rbd_pool, + 'rm', voluname['name']) def local_path(self, volume): """Returns the path of the rbd volume.""" @@ -409,7 +538,7 @@ class SheepdogDriver(VolumeDriver): def check_for_setup_error(self): """Returns an error if prerequisites aren't met""" try: - (out, err) = self._execute("collie cluster info") + (out, err) = self._execute('collie', 'cluster', 'info') if not out.startswith('running'): raise exception.Error(_("Sheepdog is not working: %s") % out) except exception.ProcessExecutionError: @@ -421,12 +550,13 @@ class SheepdogDriver(VolumeDriver): sizestr = '100M' else: sizestr = '%sG' % volume['size'] - self._try_execute("qemu-img create sheepdog:%s %s" % - (volume['name'], sizestr)) + self._try_execute('qemu-img', 'create', + "sheepdog:%s" % volume['name'], + sizestr) def delete_volume(self, volume): """Deletes a logical volume""" - self._try_execute("collie vdi delete %s" % volume['name']) + self._try_execute('collie', 'vdi', 'delete', volume['name']) def local_path(self, volume): return "sheepdog:%s" % volume['name'] diff --git a/nova/volume/manager.py b/nova/volume/manager.py index 6f8e25e19..3e8bc16b3 100644 --- a/nova/volume/manager.py +++ b/nova/volume/manager.py @@ -87,7 +87,7 @@ class VolumeManager(manager.Manager): if volume['status'] in ['available', 'in-use']: self.driver.ensure_export(ctxt, volume) else: - LOG.info(_("volume %s: skipping export"), volume_ref['name']) + LOG.info(_("volume %s: skipping export"), volume['name']) def create_volume(self, context, volume_id): """Creates and exports the volume.""" @@ -107,14 +107,18 @@ class VolumeManager(manager.Manager): vol_size = volume_ref['size'] LOG.debug(_("volume %(vol_name)s: creating lv of" " size %(vol_size)sG") % locals()) - self.driver.create_volume(volume_ref) + model_update = self.driver.create_volume(volume_ref) + if model_update: + self.db.volume_update(context, volume_ref['id'], model_update) LOG.debug(_("volume %s: creating export"), volume_ref['name']) - self.driver.create_export(context, volume_ref) - except Exception as e: + model_update = self.driver.create_export(context, volume_ref) + if model_update: + self.db.volume_update(context, volume_ref['id'], model_update) + except Exception: self.db.volume_update(context, volume_ref['id'], {'status': 'error'}) - raise e + raise now = datetime.datetime.utcnow() self.db.volume_update(context, @@ -137,11 +141,11 @@ class VolumeManager(manager.Manager): self.driver.remove_export(context, volume_ref) LOG.debug(_("volume %s: deleting"), volume_ref['name']) self.driver.delete_volume(volume_ref) - except Exception as e: + except Exception: self.db.volume_update(context, volume_ref['id'], {'status': 'error_deleting'}) - raise e + raise self.db.volume_destroy(context, volume_id) LOG.debug(_("volume %s: deleted successfully"), volume_ref['name']) diff --git a/nova/volume/san.py b/nova/volume/san.py new file mode 100644 index 000000000..9532c8116 --- /dev/null +++ b/nova/volume/san.py @@ -0,0 +1,585 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +Drivers for san-stored volumes. + +The unique thing about a SAN is that we don't expect that we can run the volume +controller on the SAN hardware. We expect to access it over SSH or some API. +""" + +import os +import paramiko + +from xml.etree import ElementTree + +from nova import exception +from nova import flags +from nova import log as logging +from nova.utils import ssh_execute +from nova.volume.driver import ISCSIDriver + +LOG = logging.getLogger("nova.volume.driver") +FLAGS = flags.FLAGS +flags.DEFINE_boolean('san_thin_provision', 'true', + 'Use thin provisioning for SAN volumes?') +flags.DEFINE_string('san_ip', '', + 'IP address of SAN controller') +flags.DEFINE_string('san_login', 'admin', + 'Username for SAN controller') +flags.DEFINE_string('san_password', '', + 'Password for SAN controller') +flags.DEFINE_string('san_privatekey', '', + 'Filename of private key to use for SSH authentication') +flags.DEFINE_string('san_clustername', '', + 'Cluster name to use for creating volumes') +flags.DEFINE_integer('san_ssh_port', 22, + 'SSH port to use with SAN') + + +class SanISCSIDriver(ISCSIDriver): + """ Base class for SAN-style storage volumes + + A SAN-style storage value is 'different' because the volume controller + probably won't run on it, so we need to access is over SSH or another + remote protocol. + """ + + def _build_iscsi_target_name(self, volume): + return "%s%s" % (FLAGS.iscsi_target_prefix, volume['name']) + + # discover_volume is still OK + # undiscover_volume is still OK + + def _connect_to_ssh(self): + ssh = paramiko.SSHClient() + #TODO(justinsb): We need a better SSH key policy + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + if FLAGS.san_password: + ssh.connect(FLAGS.san_ip, + port=FLAGS.san_ssh_port, + username=FLAGS.san_login, + password=FLAGS.san_password) + elif FLAGS.san_privatekey: + privatekeyfile = os.path.expanduser(FLAGS.san_privatekey) + # It sucks that paramiko doesn't support DSA keys + privatekey = paramiko.RSAKey.from_private_key_file(privatekeyfile) + ssh.connect(FLAGS.san_ip, + port=FLAGS.san_ssh_port, + username=FLAGS.san_login, + pkey=privatekey) + else: + raise exception.Error(_("Specify san_password or san_privatekey")) + return ssh + + def _run_ssh(self, command, check_exit_code=True): + #TODO(justinsb): SSH connection caching (?) + ssh = self._connect_to_ssh() + + #TODO(justinsb): Reintroduce the retry hack + ret = ssh_execute(ssh, command, check_exit_code=check_exit_code) + + ssh.close() + + return ret + + def ensure_export(self, context, volume): + """Synchronously recreates an export for a logical volume.""" + pass + + def create_export(self, context, volume): + """Exports the volume.""" + pass + + def remove_export(self, context, volume): + """Removes an export for a logical volume.""" + pass + + def check_for_setup_error(self): + """Returns an error if prerequisites aren't met""" + if not (FLAGS.san_password or FLAGS.san_privatekey): + raise exception.Error(_("Specify san_password or san_privatekey")) + + if not (FLAGS.san_ip): + raise exception.Error(_("san_ip must be set")) + + +def _collect_lines(data): + """ Split lines from data into an array, trimming them """ + matches = [] + for line in data.splitlines(): + match = line.strip() + matches.append(match) + + return matches + + +def _get_prefixed_values(data, prefix): + """Collect lines which start with prefix; with trimming""" + matches = [] + for line in data.splitlines(): + line = line.strip() + if line.startswith(prefix): + match = line[len(prefix):] + match = match.strip() + matches.append(match) + + return matches + + +class SolarisISCSIDriver(SanISCSIDriver): + """Executes commands relating to Solaris-hosted ISCSI volumes. + + Basic setup for a Solaris iSCSI server: + + pkg install storage-server SUNWiscsit + + svcadm enable stmf + + svcadm enable -r svc:/network/iscsi/target:default + + pfexec itadm create-tpg e1000g0 ${MYIP} + + pfexec itadm create-target -t e1000g0 + + + Then grant the user that will be logging on lots of permissions. + I'm not sure exactly which though: + + zfs allow justinsb create,mount,destroy rpool + + usermod -P'File System Management' justinsb + + usermod -P'Primary Administrator' justinsb + + Also make sure you can login using san_login & san_password/san_privatekey + """ + + def _view_exists(self, luid): + cmd = "pfexec /usr/sbin/stmfadm list-view -l %s" % (luid) + (out, _err) = self._run_ssh(cmd, + check_exit_code=False) + if "no views found" in out: + return False + + if "View Entry:" in out: + return True + + raise exception.Error("Cannot parse list-view output: %s" % (out)) + + def _get_target_groups(self): + """Gets list of target groups from host.""" + (out, _err) = self._run_ssh("pfexec /usr/sbin/stmfadm list-tg") + matches = _get_prefixed_values(out, 'Target group: ') + LOG.debug("target_groups=%s" % matches) + return matches + + def _target_group_exists(self, target_group_name): + return target_group_name not in self._get_target_groups() + + def _get_target_group_members(self, target_group_name): + (out, _err) = self._run_ssh("pfexec /usr/sbin/stmfadm list-tg -v %s" % + (target_group_name)) + matches = _get_prefixed_values(out, 'Member: ') + LOG.debug("members of %s=%s" % (target_group_name, matches)) + return matches + + def _is_target_group_member(self, target_group_name, iscsi_target_name): + return iscsi_target_name in ( + self._get_target_group_members(target_group_name)) + + def _get_iscsi_targets(self): + cmd = ("pfexec /usr/sbin/itadm list-target | " + "awk '{print $1}' | grep -v ^TARGET") + (out, _err) = self._run_ssh(cmd) + matches = _collect_lines(out) + LOG.debug("_get_iscsi_targets=%s" % (matches)) + return matches + + def _iscsi_target_exists(self, iscsi_target_name): + return iscsi_target_name in self._get_iscsi_targets() + + def _build_zfs_poolname(self, volume): + #TODO(justinsb): rpool should be configurable + zfs_poolname = 'rpool/%s' % (volume['name']) + return zfs_poolname + + def create_volume(self, volume): + """Creates a volume.""" + if int(volume['size']) == 0: + sizestr = '100M' + else: + sizestr = '%sG' % volume['size'] + + zfs_poolname = self._build_zfs_poolname(volume) + + thin_provision_arg = '-s' if FLAGS.san_thin_provision else '' + # Create a zfs volume + self._run_ssh("pfexec /usr/sbin/zfs create %s -V %s %s" % + (thin_provision_arg, + sizestr, + zfs_poolname)) + + def _get_luid(self, volume): + zfs_poolname = self._build_zfs_poolname(volume) + + cmd = ("pfexec /usr/sbin/sbdadm list-lu | " + "grep -w %s | awk '{print $1}'" % + (zfs_poolname)) + + (stdout, _stderr) = self._run_ssh(cmd) + + luid = stdout.strip() + return luid + + def _is_lu_created(self, volume): + luid = self._get_luid(volume) + return luid + + def delete_volume(self, volume): + """Deletes a volume.""" + zfs_poolname = self._build_zfs_poolname(volume) + self._run_ssh("pfexec /usr/sbin/zfs destroy %s" % + (zfs_poolname)) + + def local_path(self, volume): + # TODO(justinsb): Is this needed here? + escaped_group = FLAGS.volume_group.replace('-', '--') + escaped_name = volume['name'].replace('-', '--') + return "/dev/mapper/%s-%s" % (escaped_group, escaped_name) + + def ensure_export(self, context, volume): + """Synchronously recreates an export for a logical volume.""" + #TODO(justinsb): On bootup, this is called for every volume. + # It then runs ~5 SSH commands for each volume, + # most of which fetch the same info each time + # This makes initial start stupid-slow + self._do_export(volume, force_create=False) + + def create_export(self, context, volume): + self._do_export(volume, force_create=True) + + def _do_export(self, volume, force_create): + # Create a Logical Unit (LU) backed by the zfs volume + zfs_poolname = self._build_zfs_poolname(volume) + + if force_create or not self._is_lu_created(volume): + cmd = ("pfexec /usr/sbin/sbdadm create-lu /dev/zvol/rdsk/%s" % + (zfs_poolname)) + self._run_ssh(cmd) + + luid = self._get_luid(volume) + iscsi_name = self._build_iscsi_target_name(volume) + target_group_name = 'tg-%s' % volume['name'] + + # Create a iSCSI target, mapped to just this volume + if force_create or not self._target_group_exists(target_group_name): + self._run_ssh("pfexec /usr/sbin/stmfadm create-tg %s" % + (target_group_name)) + + # Yes, we add the initiatior before we create it! + # Otherwise, it complains that the target is already active + if force_create or not self._is_target_group_member(target_group_name, + iscsi_name): + self._run_ssh("pfexec /usr/sbin/stmfadm add-tg-member -g %s %s" % + (target_group_name, iscsi_name)) + if force_create or not self._iscsi_target_exists(iscsi_name): + self._run_ssh("pfexec /usr/sbin/itadm create-target -n %s" % + (iscsi_name)) + if force_create or not self._view_exists(luid): + self._run_ssh("pfexec /usr/sbin/stmfadm add-view -t %s %s" % + (target_group_name, luid)) + + #TODO(justinsb): Is this always 1? Does it matter? + iscsi_portal_interface = '1' + iscsi_portal = FLAGS.san_ip + ":3260," + iscsi_portal_interface + + db_update = {} + db_update['provider_location'] = ("%s %s" % + (iscsi_portal, + iscsi_name)) + + return db_update + + def remove_export(self, context, volume): + """Removes an export for a logical volume.""" + + # This is the reverse of _do_export + luid = self._get_luid(volume) + iscsi_name = self._build_iscsi_target_name(volume) + target_group_name = 'tg-%s' % volume['name'] + + if self._view_exists(luid): + self._run_ssh("pfexec /usr/sbin/stmfadm remove-view -l %s -a" % + (luid)) + + if self._iscsi_target_exists(iscsi_name): + self._run_ssh("pfexec /usr/sbin/stmfadm offline-target %s" % + (iscsi_name)) + self._run_ssh("pfexec /usr/sbin/itadm delete-target %s" % + (iscsi_name)) + + # We don't delete the tg-member; we delete the whole tg! + + if self._target_group_exists(target_group_name): + self._run_ssh("pfexec /usr/sbin/stmfadm delete-tg %s" % + (target_group_name)) + + if self._is_lu_created(volume): + self._run_ssh("pfexec /usr/sbin/sbdadm delete-lu %s" % + (luid)) + + +class HpSanISCSIDriver(SanISCSIDriver): + """Executes commands relating to HP/Lefthand SAN ISCSI volumes. + + We use the CLIQ interface, over SSH. + + Rough overview of CLIQ commands used: + + :createVolume: (creates the volume) + + :getVolumeInfo: (to discover the IQN etc) + + :getClusterInfo: (to discover the iSCSI target IP address) + + :assignVolumeChap: (exports it with CHAP security) + + The 'trick' here is that the HP SAN enforces security by default, so + normally a volume mount would need both to configure the SAN in the volume + layer and do the mount on the compute layer. Multi-layer operations are + not catered for at the moment in the nova architecture, so instead we + share the volume using CHAP at volume creation time. Then the mount need + only use those CHAP credentials, so can take place exclusively in the + compute layer. + """ + + def _cliq_run(self, verb, cliq_args): + """Runs a CLIQ command over SSH, without doing any result parsing""" + cliq_arg_strings = [] + for k, v in cliq_args.items(): + cliq_arg_strings.append(" %s=%s" % (k, v)) + cmd = verb + ''.join(cliq_arg_strings) + + return self._run_ssh(cmd) + + def _cliq_run_xml(self, verb, cliq_args, check_cliq_result=True): + """Runs a CLIQ command over SSH, parsing and checking the output""" + cliq_args['output'] = 'XML' + (out, _err) = self._cliq_run(verb, cliq_args) + + LOG.debug(_("CLIQ command returned %s"), out) + + result_xml = ElementTree.fromstring(out) + if check_cliq_result: + response_node = result_xml.find("response") + if response_node is None: + msg = (_("Malformed response to CLIQ command " + "%(verb)s %(cliq_args)s. Result=%(out)s") % + locals()) + raise exception.Error(msg) + + result_code = response_node.attrib.get("result") + + if result_code != "0": + msg = (_("Error running CLIQ command %(verb)s %(cliq_args)s. " + " Result=%(out)s") % + locals()) + raise exception.Error(msg) + + return result_xml + + def _cliq_get_cluster_info(self, cluster_name): + """Queries for info about the cluster (including IP)""" + cliq_args = {} + cliq_args['clusterName'] = cluster_name + cliq_args['searchDepth'] = '1' + cliq_args['verbose'] = '0' + + result_xml = self._cliq_run_xml("getClusterInfo", cliq_args) + + return result_xml + + def _cliq_get_cluster_vip(self, cluster_name): + """Gets the IP on which a cluster shares iSCSI volumes""" + cluster_xml = self._cliq_get_cluster_info(cluster_name) + + vips = [] + for vip in cluster_xml.findall("response/cluster/vip"): + vips.append(vip.attrib.get('ipAddress')) + + if len(vips) == 1: + return vips[0] + + _xml = ElementTree.tostring(cluster_xml) + msg = (_("Unexpected number of virtual ips for cluster " + " %(cluster_name)s. Result=%(_xml)s") % + locals()) + raise exception.Error(msg) + + def _cliq_get_volume_info(self, volume_name): + """Gets the volume info, including IQN""" + cliq_args = {} + cliq_args['volumeName'] = volume_name + result_xml = self._cliq_run_xml("getVolumeInfo", cliq_args) + + # Result looks like this: + #<gauche version="1.0"> + # <response description="Operation succeeded." name="CliqSuccess" + # processingTime="87" result="0"> + # <volume autogrowPages="4" availability="online" blockSize="1024" + # bytesWritten="0" checkSum="false" clusterName="Cluster01" + # created="2011-02-08T19:56:53Z" deleting="false" description="" + # groupName="Group01" initialQuota="536870912" isPrimary="true" + # iscsiIqn="iqn.2003-10.com.lefthandnetworks:group01:25366:vol-b" + # maxSize="6865387257856" md5="9fa5c8b2cca54b2948a63d833097e1ca" + # minReplication="1" name="vol-b" parity="0" replication="2" + # reserveQuota="536870912" scratchQuota="4194304" + # serialNumber="9fa5c8b2cca54b2948a63d833097e1ca0000000000006316" + # size="1073741824" stridePages="32" thinProvision="true"> + # <status description="OK" value="2"/> + # <permission access="rw" + # authGroup="api-34281B815713B78-(trimmed)51ADD4B7030853AA7" + # chapName="chapusername" chapRequired="true" id="25369" + # initiatorSecret="" iqn="" iscsiEnabled="true" + # loadBalance="true" targetSecret="supersecret"/> + # </volume> + # </response> + #</gauche> + + # Flatten the nodes into a dictionary; use prefixes to avoid collisions + volume_attributes = {} + + volume_node = result_xml.find("response/volume") + for k, v in volume_node.attrib.items(): + volume_attributes["volume." + k] = v + + status_node = volume_node.find("status") + if not status_node is None: + for k, v in status_node.attrib.items(): + volume_attributes["status." + k] = v + + # We only consider the first permission node + permission_node = volume_node.find("permission") + if not permission_node is None: + for k, v in status_node.attrib.items(): + volume_attributes["permission." + k] = v + + LOG.debug(_("Volume info: %(volume_name)s => %(volume_attributes)s") % + locals()) + return volume_attributes + + def create_volume(self, volume): + """Creates a volume.""" + cliq_args = {} + cliq_args['clusterName'] = FLAGS.san_clustername + #TODO(justinsb): Should we default to inheriting thinProvision? + cliq_args['thinProvision'] = '1' if FLAGS.san_thin_provision else '0' + cliq_args['volumeName'] = volume['name'] + if int(volume['size']) == 0: + cliq_args['size'] = '100MB' + else: + cliq_args['size'] = '%sGB' % volume['size'] + + self._cliq_run_xml("createVolume", cliq_args) + + volume_info = self._cliq_get_volume_info(volume['name']) + cluster_name = volume_info['volume.clusterName'] + iscsi_iqn = volume_info['volume.iscsiIqn'] + + #TODO(justinsb): Is this always 1? Does it matter? + cluster_interface = '1' + + cluster_vip = self._cliq_get_cluster_vip(cluster_name) + iscsi_portal = cluster_vip + ":3260," + cluster_interface + + model_update = {} + model_update['provider_location'] = ("%s %s" % + (iscsi_portal, + iscsi_iqn)) + + return model_update + + def delete_volume(self, volume): + """Deletes a volume.""" + cliq_args = {} + cliq_args['volumeName'] = volume['name'] + cliq_args['prompt'] = 'false' # Don't confirm + + self._cliq_run_xml("deleteVolume", cliq_args) + + def local_path(self, volume): + # TODO(justinsb): Is this needed here? + raise exception.Error(_("local_path not supported")) + + def ensure_export(self, context, volume): + """Synchronously recreates an export for a logical volume.""" + return self._do_export(context, volume, force_create=False) + + def create_export(self, context, volume): + return self._do_export(context, volume, force_create=True) + + def _do_export(self, context, volume, force_create): + """Supports ensure_export and create_export""" + volume_info = self._cliq_get_volume_info(volume['name']) + + is_shared = 'permission.authGroup' in volume_info + + model_update = {} + + should_export = False + + if force_create or not is_shared: + should_export = True + # Check that we have a project_id + project_id = volume['project_id'] + if not project_id: + project_id = context.project_id + + if project_id: + #TODO(justinsb): Use a real per-project password here + chap_username = 'proj_' + project_id + # HP/Lefthand requires that the password be >= 12 characters + chap_password = 'project_secret_' + project_id + else: + msg = (_("Could not determine project for volume %s, " + "can't export") % + (volume['name'])) + if force_create: + raise exception.Error(msg) + else: + LOG.warn(msg) + should_export = False + + if should_export: + cliq_args = {} + cliq_args['volumeName'] = volume['name'] + cliq_args['chapName'] = chap_username + cliq_args['targetSecret'] = chap_password + + self._cliq_run_xml("assignVolumeChap", cliq_args) + + model_update['provider_auth'] = ("CHAP %s %s" % + (chap_username, chap_password)) + + return model_update + + def remove_export(self, context, volume): + """Removes an export for a logical volume.""" + cliq_args = {} + cliq_args['volumeName'] = volume['name'] + + self._cliq_run_xml("unassignVolume", cliq_args) diff --git a/nova/wsgi.py b/nova/wsgi.py index e01cc1e1e..2d18da8fb 100644 --- a/nova/wsgi.py +++ b/nova/wsgi.py @@ -36,6 +36,7 @@ import webob.exc from paste import deploy +from nova import exception from nova import flags from nova import log as logging from nova import utils @@ -59,7 +60,6 @@ class Server(object): """Server class to manage multiple WSGI sockets and applications.""" def __init__(self, threads=1000): - logging.basicConfig() self.pool = eventlet.GreenPool(threads) def start(self, application, port, host='0.0.0.0', backlog=128): @@ -83,6 +83,35 @@ class Server(object): log=WritableLogger(logger)) +class Request(webob.Request): + + def best_match_content_type(self): + """ + Determine the most acceptable content-type based on the + query extension then the Accept header + """ + + parts = self.path.rsplit(".", 1) + + if len(parts) > 1: + format = parts[1] + if format in ["json", "xml"]: + return "application/{0}".format(parts[1]) + + ctypes = ["application/json", "application/xml"] + bm = self.accept.best_match(ctypes) + + return bm or "application/json" + + def get_content_type(self): + try: + ct = self.headers["Content-Type"] + assert ct in ("application/xml", "application/json") + return ct + except Exception: + raise webob.exc.HTTPBadRequest("Invalid content type") + + class Application(object): """Base WSGI application wrapper. Subclasses need to implement __call__.""" @@ -114,7 +143,7 @@ class Application(object): def __call__(self, environ, start_response): r"""Subclasses will probably want to implement __call__ like this: - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): # Any of the following objects work as responses: @@ -200,7 +229,7 @@ class Middleware(Application): """Do whatever you'd like to the response.""" return response - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): response = self.process_request(req) if response: @@ -213,7 +242,7 @@ class Debug(Middleware): """Helper class that can be inserted into any WSGI application chain to get information about the request and response.""" - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): print ("*" * 40) + " REQUEST ENVIRON" for key, value in req.environ.items(): @@ -277,7 +306,7 @@ class Router(object): self._router = routes.middleware.RoutesMiddleware(self._dispatch, self.map) - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): """ Route the incoming request to a controller based on self.map. @@ -286,7 +315,7 @@ class Router(object): return self._router @staticmethod - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def _dispatch(req): """ Called by self._router after matching the incoming request to a route @@ -305,11 +334,11 @@ class Controller(object): WSGI app that reads routing information supplied by RoutesMiddleware and calls the requested action method upon itself. All action methods must, in addition to their normal parameters, accept a 'req' argument - which is the incoming webob.Request. They raise a webob.exc exception, + which is the incoming wsgi.Request. They raise a webob.exc exception, or return a dict which will be serialized by requested content type. """ - @webob.dec.wsgify + @webob.dec.wsgify(RequestClass=Request) def __call__(self, req): """ Call the method specified in req.environ by RoutesMiddleware. @@ -319,32 +348,45 @@ class Controller(object): method = getattr(self, action) del arg_dict['controller'] del arg_dict['action'] + if 'format' in arg_dict: + del arg_dict['format'] arg_dict['req'] = req result = method(**arg_dict) + if type(result) is dict: - return self._serialize(result, req) + content_type = req.best_match_content_type() + body = self._serialize(result, content_type) + + response = webob.Response() + response.headers["Content-Type"] = content_type + response.body = body + return response + else: return result - def _serialize(self, data, request): + def _serialize(self, data, content_type): """ - Serialize the given dict to the response type requested in request. + Serialize the given dict to the provided content_type. Uses self._serialization_metadata if it exists, which is a dict mapping MIME types to information needed to serialize to that type. """ _metadata = getattr(type(self), "_serialization_metadata", {}) - serializer = Serializer(request.environ, _metadata) - return serializer.to_content_type(data) + serializer = Serializer(_metadata) + try: + return serializer.serialize(data, content_type) + except exception.InvalidContentType: + raise webob.exc.HTTPNotAcceptable() - def _deserialize(self, data, request): + def _deserialize(self, data, content_type): """ - Deserialize the request body to the response type requested in request. + Deserialize the request body to the specefied content type. Uses self._serialization_metadata if it exists, which is a dict mapping MIME types to information needed to serialize to that type. """ _metadata = getattr(type(self), "_serialization_metadata", {}) - serializer = Serializer(request.environ, _metadata) - return serializer.deserialize(data) + serializer = Serializer(_metadata) + return serializer.deserialize(data, content_type) class Serializer(object): @@ -352,50 +394,52 @@ class Serializer(object): Serializes and deserializes dictionaries to certain MIME types. """ - def __init__(self, environ, metadata=None): + def __init__(self, metadata=None): """ Create a serializer based on the given WSGI environment. 'metadata' is an optional dict mapping MIME types to information needed to serialize a dictionary to that type. """ self.metadata = metadata or {} - req = webob.Request.blank('', environ) - suffix = req.path_info.split('.')[-1].lower() - if suffix == 'json': - self.handler = self._to_json - elif suffix == 'xml': - self.handler = self._to_xml - elif 'application/json' in req.accept: - self.handler = self._to_json - elif 'application/xml' in req.accept: - self.handler = self._to_xml - else: - # This is the default - self.handler = self._to_json - def to_content_type(self, data): - """ - Serialize a dictionary into a string. + def _get_serialize_handler(self, content_type): + handlers = { + "application/json": self._to_json, + "application/xml": self._to_xml, + } + + try: + return handlers[content_type] + except Exception: + raise exception.InvalidContentType() - The format of the string will be decided based on the Content Type - requested in self.environ: by Accept: header, or by URL suffix. + def serialize(self, data, content_type): """ - return self.handler(data) + Serialize a dictionary into a string of the specified content type. + """ + return self._get_serialize_handler(content_type)(data) - def deserialize(self, datastring): + def deserialize(self, datastring, content_type): """ Deserialize a string to a dictionary. The string must be in the format of a supported MIME type. """ - datastring = datastring.strip() + return self.get_deserialize_handler(content_type)(datastring) + + def get_deserialize_handler(self, content_type): + handlers = { + "application/json": self._from_json, + "application/xml": self._from_xml, + } + try: - is_xml = (datastring[0] == '<') - if not is_xml: - return utils.loads(datastring) - return self._from_xml(datastring) - except: - return None + return handlers[content_type] + except Exception: + raise exception.InvalidContentType() + + def _from_json(self, datastring): + return utils.loads(datastring) def _from_xml(self, datastring): xmldata = self.metadata.get('application/xml', {}) @@ -515,10 +559,3 @@ def load_paste_app(filename, appname): except LookupError: pass return app - - -def paste_config_to_flags(config, mixins): - for k, v in mixins.iteritems(): - value = config.get(k, v) - converted_value = FLAGS[k].parser.Parse(value) - setattr(FLAGS, k, converted_value) |
