diff options
| author | Soren Hansen <soren.hansen@rackspace.com> | 2010-10-02 11:12:46 +0200 |
|---|---|---|
| committer | Soren Hansen <soren.hansen@rackspace.com> | 2010-10-02 11:12:46 +0200 |
| commit | 3e27f5dfae379e70af023134cbab02e18b450ce1 (patch) | |
| tree | e54336386b9334fcf57561c2d6717cd7fc753ee2 /nova | |
| parent | 6a0bf3e048da0f7a0c0daf8e25167452cb86bf73 (diff) | |
| parent | 4d13a8554459638387d772a23fffe6aaaab3348d (diff) | |
Merge trunk.
Diffstat (limited to 'nova')
32 files changed, 1214 insertions, 482 deletions
diff --git a/nova/api/cloud.py b/nova/api/cloud.py new file mode 100644 index 000000000..345677d4f --- /dev/null +++ b/nova/api/cloud.py @@ -0,0 +1,42 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Methods for API calls to control instances via AMQP. +""" + + +from nova import db +from nova import flags +from nova import rpc + +FLAGS = flags.FLAGS + + +def reboot(instance_id, context=None): + """Reboot the given instance. + + #TODO(gundlach) not actually sure what context is used for by ec2 here + -- I think we can just remove it and use None all the time. + """ + instance_ref = db.instance_get_by_ec2_id(None, instance_id) + host = instance_ref['host'] + rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), + {"method": "reboot_instance", + "args": {"context": None, + "instance_id": instance_ref['id']}}) diff --git a/nova/api/ec2/context.py b/nova/api/context.py index c53ba98d9..b66cfe468 100644 --- a/nova/api/ec2/context.py +++ b/nova/api/context.py @@ -31,3 +31,16 @@ class APIRequestContext(object): [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') for x in xrange(20)] ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -27,8 +27,8 @@ import webob.exc from nova import exception from nova import flags from nova import wsgi +from nova.api import context from nova.api.ec2 import apirequest -from nova.api.ec2 import context from nova.api.ec2 import admin from nova.api.ec2 import cloud from nova.auth import manager @@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user.id, role) for role in roles) - + class Executor(wsgi.Application): """Execute an EC2 API request. - Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and 'ec2.action_args' (all variables in WSGI environ.) Returns an XML response, or a 400 upon failure. """ diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index d3f54367b..79c95788b 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -36,6 +36,7 @@ from nova import quota from nova import rpc from nova import utils from nova.compute.instance_types import INSTANCE_TYPES +from nova.api import cloud from nova.api.ec2 import images @@ -684,12 +685,7 @@ class CloudController(object): def reboot_instances(self, context, instance_id, **kwargs): """instance_id is a list of instance ids""" for id_str in instance_id: - instance_ref = db.instance_get_by_ec2_id(context, id_str) - host = instance_ref['host'] - rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), - {"method": "reboot_instance", - "args": {"context": None, - "instance_id": instance_ref['id']}}) + cloud.reboot(id_str, context=context) return True def update_instance(self, context, instance_id, **kwargs): diff --git a/nova/api/rackspace/__init__.py b/nova/api/rackspace/__init__.py index 98802663f..89a4693ad 100644 --- a/nova/api/rackspace/__init__.py +++ b/nova/api/rackspace/__init__.py @@ -31,6 +31,7 @@ import webob from nova import flags from nova import utils from nova import wsgi +from nova.api.rackspace import faults from nova.api.rackspace import backup_schedules from nova.api.rackspace import flavors from nova.api.rackspace import images @@ -67,7 +68,7 @@ class AuthMiddleware(wsgi.Middleware): user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"]) if not user: - return webob.exc.HTTPUnauthorized() + return faults.Fault(webob.exc.HTTPUnauthorized()) if not req.environ.has_key('nova.context'): req.environ['nova.context'] = {} @@ -112,8 +113,10 @@ class RateLimitingMiddleware(wsgi.Middleware): delay = self.get_delay(action_name, username) if delay: # TODO(gundlach): Get the retry-after format correct. - raise webob.exc.HTTPRequestEntityTooLarge(headers={ - 'Retry-After': time.time() + delay}) + exc = webob.exc.HTTPRequestEntityTooLarge( + explanation='Too many requests.', + headers={'Retry-After': time.time() + delay}) + raise faults.Fault(exc) return self.application def get_delay(self, action_name, username): @@ -165,3 +168,23 @@ class APIRouter(wsgi.Router): controller=sharedipgroups.Controller()) super(APIRouter, self).__init__(mapper) + + +def limited(items, req): + """Return a slice of items according to requested offset and limit. + + items - a sliceable + req - wobob.Request possibly containing offset and limit GET variables. + offset is where to start in the list, and limit is the maximum number + of items to return. + + If limit is not specified, 0, or > 1000, defaults to 1000. + """ + offset = int(req.GET.get('offset', 0)) + limit = int(req.GET.get('limit', 0)) + if not limit: + limit = 1000 + limit = min(1000, limit) + range_end = offset + limit + return items[offset:range_end] + diff --git a/nova/api/rackspace/auth.py b/nova/api/rackspace/auth.py index 8bfb0753e..c45156ebd 100644 --- a/nova/api/rackspace/auth.py +++ b/nova/api/rackspace/auth.py @@ -11,6 +11,7 @@ from nova import db from nova import flags from nova import manager from nova import utils +from nova.api.rackspace import faults FLAGS = flags.FLAGS @@ -36,13 +37,13 @@ class BasicApiAuthManager(object): # honor it path_info = req.path_info if len(path_info) > 1: - return webob.exc.HTTPUnauthorized() + return faults.Fault(webob.exc.HTTPUnauthorized()) try: username, key = req.headers['X-Auth-User'], \ req.headers['X-Auth-Key'] except KeyError: - return webob.exc.HTTPUnauthorized() + return faults.Fault(webob.exc.HTTPUnauthorized()) username, key = req.headers['X-Auth-User'], req.headers['X-Auth-Key'] token, user = self._authorize_user(username, key) @@ -57,7 +58,7 @@ class BasicApiAuthManager(object): res.status = '204' return res else: - return webob.exc.HTTPUnauthorized() + return faults.Fault(webob.exc.HTTPUnauthorized()) def authorize_token(self, token_hash): """ retrieves user information from the datastore given a token diff --git a/nova/api/rackspace/backup_schedules.py b/nova/api/rackspace/backup_schedules.py index 46da778ee..cb83023bc 100644 --- a/nova/api/rackspace/backup_schedules.py +++ b/nova/api/rackspace/backup_schedules.py @@ -20,6 +20,7 @@ from webob import exc from nova import wsgi from nova.api.rackspace import _id_translator +from nova.api.rackspace import faults import nova.image.service class Controller(wsgi.Controller): @@ -27,12 +28,12 @@ class Controller(wsgi.Controller): pass def index(self, req, server_id): - return exc.HTTPNotFound() + return faults.Fault(exc.HTTPNotFound()) def create(self, req, server_id): """ No actual update method required, since the existing API allows both create and update through a POST """ - return exc.HTTPNotFound() + return faults.Fault(exc.HTTPNotFound()) def delete(self, req, server_id): - return exc.HTTPNotFound() + return faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/context.py b/nova/api/rackspace/context.py new file mode 100644 index 000000000..77394615b --- /dev/null +++ b/nova/api/rackspace/context.py @@ -0,0 +1,33 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + +class Project(object): + def __init__(self, user_id): + self.id = user_id + +class APIRequestContext(object): + """ This is an adapter class to get around all of the assumptions made in + the FlatNetworking """ + def __init__(self, user_id): + self.user_id = user_id + self.project = Project(user_id) diff --git a/nova/api/rackspace/faults.py b/nova/api/rackspace/faults.py new file mode 100644 index 000000000..32e5c866f --- /dev/null +++ b/nova/api/rackspace/faults.py @@ -0,0 +1,62 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import webob.dec +import webob.exc + +from nova import wsgi + + +class Fault(webob.exc.HTTPException): + + """An RS API fault response.""" + + _fault_names = { + 400: "badRequest", + 401: "unauthorized", + 403: "resizeNotAllowed", + 404: "itemNotFound", + 405: "badMethod", + 409: "inProgress", + 413: "overLimit", + 415: "badMediaType", + 501: "notImplemented", + 503: "serviceUnavailable"} + + def __init__(self, exception): + """Create a Fault for the given webob.exc.exception.""" + self.wrapped_exc = exception + + @webob.dec.wsgify + def __call__(self, req): + """Generate a WSGI response based on the exception passed to ctor.""" + # Replace the body with fault details. + code = self.wrapped_exc.status_int + fault_name = self._fault_names.get(code, "cloudServersFault") + fault_data = { + fault_name: { + 'code': code, + 'message': self.wrapped_exc.explanation}} + if code == 413: + retry = self.wrapped_exc.headers['Retry-After'] + fault_data[fault_name]['retryAfter'] = retry + # 'code' is an attribute on the fault tag itself + metadata = {'application/xml': {'attributes': {fault_name: 'code'}}} + serializer = wsgi.Serializer(req.environ, metadata) + self.wrapped_exc.body = serializer.to_content_type(fault_data) + return self.wrapped_exc diff --git a/nova/api/rackspace/flavors.py b/nova/api/rackspace/flavors.py index 3bcf170e5..916449854 100644 --- a/nova/api/rackspace/flavors.py +++ b/nova/api/rackspace/flavors.py @@ -15,9 +15,12 @@ # License for the specific language governing permissions and limitations # under the License. +from webob import exc + +from nova.api.rackspace import faults from nova.compute import instance_types from nova import wsgi -from webob import exc +import nova.api.rackspace class Controller(wsgi.Controller): """Flavor controller for the Rackspace API.""" @@ -38,6 +41,7 @@ class Controller(wsgi.Controller): def detail(self, req): """Return all flavors in detail.""" items = [self.show(req, id)['flavor'] for id in self._all_ids()] + items = nova.api.rackspace.limited(items, req) return dict(flavors=items) def show(self, req, id): @@ -47,7 +51,7 @@ class Controller(wsgi.Controller): item = dict(ram=val['memory_mb'], disk=val['local_gb'], id=val['flavorid'], name=name) return dict(flavor=item) - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def _all_ids(self): """Return the list of all flavorids.""" diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py index 11b058dec..4a7dd489c 100644 --- a/nova/api/rackspace/images.py +++ b/nova/api/rackspace/images.py @@ -19,7 +19,9 @@ from webob import exc from nova import wsgi from nova.api.rackspace import _id_translator +import nova.api.rackspace import nova.image.service +from nova.api.rackspace import faults class Controller(wsgi.Controller): @@ -45,6 +47,7 @@ class Controller(wsgi.Controller): def detail(self, req): """Return all public images in detail.""" data = self._service.index() + data = nova.api.rackspace.limited(data, req) for img in data: img['id'] = self._id_translator.to_rs_id(img['id']) return dict(images=data) @@ -58,14 +61,14 @@ class Controller(wsgi.Controller): def delete(self, req, id): # Only public images are supported for now. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def create(self, req): # Only public images are supported for now, so a request to # make a backup of a server cannot be supproted. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def update(self, req, id): # Users may not modify public images, and that's all that # we support for now. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index 4ab04bde7..11efd8aef 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -17,33 +17,45 @@ import time +import webob from webob import exc from nova import flags from nova import rpc from nova import utils from nova import wsgi +from nova.api import cloud from nova.api.rackspace import _id_translator +from nova.api.rackspace import context +from nova.api.rackspace import faults +from nova.compute import instance_types from nova.compute import power_state +import nova.api.rackspace import nova.image.service FLAGS = flags.FLAGS +flags.DEFINE_string('rs_network_manager', 'nova.network.manager.FlatManager', + 'Networking for rackspace') +def _instance_id_translator(): + """ Helper method for initializing an id translator for Rackspace instance + ids """ + return _id_translator.RackspaceAPIIdTranslator( "instance", 'nova') -def translator_instance(): +def _image_service(): """ Helper method for initializing the image id translator """ service = nova.image.service.ImageService.load() - return _id_translator.RackspaceAPIIdTranslator( - "image", service.__class__.__name__) + return (service, _id_translator.RackspaceAPIIdTranslator( + "image", service.__class__.__name__)) def _filter_params(inst_dict): """ Extracts all updatable parameters for a server update request """ - keys = ['name', 'adminPass'] + keys = dict(name='name', admin_pass='adminPass') new_attrs = {} - for k in keys: - if inst_dict.has_key(k): - new_attrs[k] = inst_dict[k] + for k, v in keys.items(): + if inst_dict.has_key(v): + new_attrs[k] = inst_dict[v] return new_attrs def _entity_list(entities): @@ -82,7 +94,6 @@ def _entity_inst(inst): class Controller(wsgi.Controller): """ The Server API controller for the Openstack API """ - _serialization_metadata = { 'application/xml': { @@ -101,42 +112,58 @@ class Controller(wsgi.Controller): def index(self, req): """ Returns a list of server names and ids for a given user """ - user_id = req.environ['nova.context']['user']['id'] - instance_list = self.db_driver.instance_get_all_by_user(None, user_id) - res = [_entity_inst(inst)['server'] for inst in instance_list] - return _entity_list(res) + return self._items(req, entity_maker=_entity_inst) def detail(self, req): """ Returns a list of server details for a given user """ + return self._items(req, entity_maker=_entity_detail) + + def _items(self, req, entity_maker): + """Returns a list of servers for a given user. + + entity_maker - either _entity_detail or _entity_inst + """ user_id = req.environ['nova.context']['user']['id'] - res = [_entity_detail(inst)['server'] for inst in - self.db_driver.instance_get_all_by_user(None, user_id)] + instance_list = self.db_driver.instance_get_all_by_user(None, user_id) + limited_list = nova.api.rackspace.limited(instance_list, req) + res = [entity_maker(inst)['server'] for inst in limited_list] return _entity_list(res) def show(self, req, id): """ Returns server details by server id """ + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + user_id = req.environ['nova.context']['user']['id'] - inst = self.db_driver.instance_get(None, id) + inst = self.db_driver.instance_get_by_ec2_id(None, inst_id) if inst: if inst.user_id == user_id: return _entity_detail(inst) - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def delete(self, req, id): """ Destroys a server """ + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get(None, id) + instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) if instance and instance['user_id'] == user_id: self.db_driver.instance_destroy(None, id) - return exc.HTTPAccepted() - return exc.HTTPNotFound() + return faults.Fault(exc.HTTPAccepted()) + return faults.Fault(exc.HTTPNotFound()) def create(self, req): """ Creates a new server for a given user """ - if not req.environ.has_key('inst_dict'): - return exc.HTTPUnprocessableEntity() - inst = self._build_server_instance(req) + env = self._deserialize(req.body, req) + if not env: + return faults.Fault(exc.HTTPUnprocessableEntity()) + + try: + inst = self._build_server_instance(req, env) + except Exception, e: + return faults.Fault(exc.HTTPUnprocessableEntity()) rpc.cast( FLAGS.compute_topic, { @@ -146,62 +173,127 @@ class Controller(wsgi.Controller): def update(self, req, id): """ Updates the server name or password """ - if not req.environ.has_key('inst_dict'): - return exc.HTTPUnprocessableEntity() + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get(None, id) - if not instance: - return exc.HTTPNotFound() + inst_dict = self._deserialize(req.body, req) + + if not inst_dict: + return faults.Fault(exc.HTTPUnprocessableEntity()) - attrs = req.environ['nova.context'].get('model_attributes', None) - if attrs: - self.db_driver.instance_update(None, id, _filter_params(attrs)) - return exc.HTTPNoContent() + instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) + if not instance or instance.user_id != user_id: + return faults.Fault(exc.HTTPNotFound()) + + self.db_driver.instance_update(None, id, + _filter_params(inst_dict['server'])) + return faults.Fault(exc.HTTPNoContent()) def action(self, req, id): """ multi-purpose method used to reboot, rebuild, and resize a server """ - if not req.environ.has_key('inst_dict'): - return exc.HTTPUnprocessableEntity() - - def _build_server_instance(self, req): + input_dict = self._deserialize(req.body, req) + try: + reboot_type = input_dict['reboot']['type'] + except Exception: + raise faults.Fault(webob.exc.HTTPNotImplemented()) + opaque_id = _instance_id_translator().from_rs_id(id) + cloud.reboot(opaque_id) + + def _build_server_instance(self, req, env): """Build instance data structure and save it to the data store.""" ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) inst = {} - env = req.environ['inst_dict'] + inst_id_trans = _instance_id_translator() + + user_id = req.environ['nova.context']['user']['id'] + + flavor_id = env['server']['flavorId'] + + instance_type, flavor = [(k, v) for k, v in + instance_types.INSTANCE_TYPES.iteritems() + if v['flavorid'] == flavor_id][0] image_id = env['server']['imageId'] - opaque_id = translator_instance().from_rs_id(image_id) + + img_service, image_id_trans = _image_service() - inst['name'] = env['server']['server_name'] - inst['image_id'] = opaque_id - inst['instance_type'] = env['server']['flavorId'] + opaque_image_id = image_id_trans.to_rs_id(image_id) + image = img_service.show(opaque_image_id) - user_id = req.environ['nova.context']['user']['id'] - inst['user_id'] = user_id + if not image: + raise Exception, "Image not found" + inst['server_name'] = env['server']['name'] + inst['image_id'] = opaque_image_id + inst['user_id'] = user_id inst['launch_time'] = ltime inst['mac_address'] = utils.generate_mac() + inst['project_id'] = user_id - inst['project_id'] = env['project']['id'] - inst['reservation_id'] = reservation - reservation = utils.generate_uid('r') + inst['state_description'] = 'scheduling' + inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel) + inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk) + inst['reservation_id'] = utils.generate_uid('r') - address = self.network.allocate_ip( - inst['user_id'], - inst['project_id'], - mac=inst['mac_address']) + inst['display_name'] = env['server']['name'] + inst['display_description'] = env['server']['name'] - inst['private_dns_name'] = str(address) - inst['bridge_name'] = network.BridgedNetwork.get_network_for_project( - inst['user_id'], - inst['project_id'], - 'default')['bridge_name'] + #TODO(dietz) this may be ill advised + key_pair_ref = self.db_driver.key_pair_get_all_by_user( + None, user_id)[0] + + inst['key_data'] = key_pair_ref['public_key'] + inst['key_name'] = key_pair_ref['name'] + + #TODO(dietz) stolen from ec2 api, see TODO there + inst['security_group'] = 'default' + + # Flavor related attributes + inst['instance_type'] = instance_type + inst['memory_mb'] = flavor['memory_mb'] + inst['vcpus'] = flavor['vcpus'] + inst['local_gb'] = flavor['local_gb'] ref = self.db_driver.instance_create(None, inst) - inst['id'] = ref.id + inst['id'] = inst_id_trans.to_rs_id(ref.ec2_id) + # TODO(dietz): this isn't explicitly necessary, but the networking + # calls depend on an object with a project_id property, and therefore + # should be cleaned up later + api_context = context.APIRequestContext(user_id) + + inst['mac_address'] = utils.generate_mac() + + #TODO(dietz) is this necessary? + inst['launch_index'] = 0 + + inst['hostname'] = ref.ec2_id + self.db_driver.instance_update(None, inst['id'], inst) + + network_manager = utils.import_object(FLAGS.rs_network_manager) + address = network_manager.allocate_fixed_ip(api_context, + inst['id']) + + # TODO(vish): This probably should be done in the scheduler + # network is setup when host is assigned + network_topic = self._get_network_topic(user_id) + rpc.call(network_topic, + {"method": "setup_fixed_ip", + "args": {"context": None, + "address": address}}) return inst - + def _get_network_topic(self, user_id): + """Retrieves the network host for a project""" + network_ref = self.db_driver.project_get_network(None, + user_id) + host = network_ref['host'] + if not host: + host = rpc.call(FLAGS.network_topic, + {"method": "set_network_host", + "args": {"context": None, + "project_id": user_id}}) + return self.db_driver.queue_get_for(None, FLAGS.network_topic, host) diff --git a/nova/db/api.py b/nova/db/api.py index c94e4239d..fb449ca57 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,11 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### @@ -209,6 +204,11 @@ def fixed_ip_disassociate(context, address): return IMPL.fixed_ip_disassociate(context, address) +def fixed_ip_disassociate_all_by_timeout(context, host, time): + """Disassociate old fixed ips from host""" + return IMPL.fixed_ip_disassociate_all_by_timeout(context, host, time) + + def fixed_ip_get_by_address(context, address): """Get a fixed ip by address or raise if it does not exist.""" return IMPL.fixed_ip_get_by_address(context, address) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 6633bfa6f..9d304f127 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,7 +19,7 @@ Implementation of SQLAlchemy backend """ -import sys +import warnings from nova import db from nova import exception @@ -29,38 +29,108 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. +def is_admin_context(context): + """Indicates if the request context is an administrator.""" + if not context: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) + return True + return context.is_admin - Currently just looks for a flag called deleted in the context dict. + +def is_user_context(context): + """Indicates if the request context is a normal user.""" + if not context: + return False + if not context.user or not context.project: + return False + return True + + +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. """ - if not hasattr(context, 'get'): + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() + + +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. + """ + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def can_read_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: return False - return context.get('deleted', False) + return context.read_deleted -################### +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +################### +@require_admin_context def service_destroy(context, service_id): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) +@require_admin_context +def service_get(context, service_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result + + +@require_admin_context def service_get_all_by_topic(context, topic): session = get_session() return session.query(models.Service @@ -70,7 +140,8 @@ def service_get_all_by_topic(context, topic): ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +@require_admin_context +def _service_get_all_topic_subquery(context, session, topic, subq, label): sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -81,6 +152,7 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label): ).all() +@require_admin_context def service_get_all_compute_sorted(context): session = get_session() with session.begin(): @@ -105,6 +177,7 @@ def service_get_all_compute_sorted(context): label) +@require_admin_context def service_get_all_network_sorted(context): session = get_session() with session.begin(): @@ -122,6 +195,7 @@ def service_get_all_network_sorted(context): label) +@require_admin_context def service_get_all_volume_sorted(context): session = get_session() with session.begin(): @@ -139,11 +213,22 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +@require_admin_context +def service_get_by_args(context, host, binary): + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result -def service_create(_context, values): +@require_admin_context +def service_create(context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -151,10 +236,11 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +@require_admin_context +def service_update(context, service_id, values): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = session_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -163,7 +249,9 @@ def service_update(_context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +@require_context +def floating_ip_allocate_address(context, host, project_id): + authorize_project_context(context, project_id) session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -182,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +@require_context +def floating_ip_create(context, values): floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -190,7 +279,9 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +@require_context +def floating_ip_count_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -198,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +@require_context +def floating_ip_deallocate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) -def floating_ip_destroy(_context, address): +@require_context +def floating_ip_destroy(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +@require_context +def floating_ip_disassociate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -241,7 +346,8 @@ def floating_ip_disassociate(_context, address): return fixed_ip_address -def floating_ip_get_all(_context): +@require_admin_context +def floating_ip_get_all(context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -249,7 +355,8 @@ def floating_ip_get_all(_context): ).all() -def floating_ip_get_all_by_host(_context, host): +@require_admin_context +def floating_ip_get_all_by_host(context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -257,7 +364,10 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_all_by_project(_context, project_id): + +@require_context +def floating_ip_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -265,24 +375,31 @@ def floating_ip_get_all_by_project(_context, project_id): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) +@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not session: + session = get_session() + + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) -def floating_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + return result ################### -def fixed_ip_associate(_context, address, instance_id): +@require_context +def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -293,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -315,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -331,44 +451,72 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -def fixed_ip_disassociate(_context, address): +@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) -def fixed_ip_get_by_address(_context, address): +@require_admin_context +def fixed_ip_disassociate_all_by_timeout(_context, host, time): session = get_session() - with session.begin(): - try: - return session.query(models.FixedIp - ).options(joinedload_all('instance') - ).filter_by(address=address - ).filter_by(deleted=False - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % address) - raise new_exc.__class__, new_exc, sys.exc_info()[2] + # NOTE(vish): The nested select is because sqlite doesn't support + # JOINs in UPDATEs. + result = session.execute('UPDATE fixed_ips SET instance_id = NULL, ' + 'leased = 0 ' + 'WHERE network_id IN (SELECT id FROM networks ' + 'WHERE host = :host) ' + 'AND updated_at < :time ' + 'AND instance_id IS NOT NULL ' + 'AND allocated = 0', + {'host': host, + 'time': time.isoformat()}) + return result.rowcount + + +@require_context +def fixed_ip_get_by_address(context, address, session=None): + if not session: + session = get_session() + result = session.query(models.FixedIp + ).filter_by(address=address + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance + return result -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance + + +@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): +@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -377,7 +525,8 @@ def fixed_ip_update(_context, address, values): ################### -def instance_create(_context, values): +@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -386,12 +535,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): + +@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -402,81 +553,130 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): +@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +@require_context +def instance_get(context, instance_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +@require_admin_context def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() + +@require_admin_context def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() + +@require_context def instance_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() -def instance_get_all_by_reservation(_context, reservation_id): +@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=can_read_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() +@require_context def instance_get_by_ec2_id(context, ec2_id): session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=False ).first() - if not instance_ref: + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): +@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): +@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -485,12 +685,14 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +@require_admin_context def instance_is_vpn(context, instance_id): # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +@require_admin_context def instance_set_state(context, instance_id, state, description=None): # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state @@ -502,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -514,7 +717,8 @@ def instance_update(_context, instance_id, values): ################### -def key_pair_create(_context, values): +@require_context +def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): key_pair_ref[key] = value @@ -522,16 +726,18 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +@require_context +def key_pair_destroy(context, user_id, name): + authorize_user_context(context, user_id) session = get_session() with session.begin(): - key_pair_ref = models.KeyPair.find_by_args(user_id, - name, - session=session) + key_pair_ref = key_pair_get(context, user_id, name, session=session) key_pair_ref.delete(session=session) -def key_pair_destroy_all_by_user(_context, user_id): +@require_context +def key_pair_destroy_all_by_user(context, user_id): + authorize_user_context(context, user_id) session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -539,11 +745,27 @@ def key_pair_destroy_all_by_user(_context, user_id): {'id': user_id}) -def key_pair_get(_context, user_id, name): - return models.KeyPair.find_by_args(user_id, name) +@require_context +def key_pair_get(context, user_id, name, session=None): + authorize_user_context(context, user_id) + if not session: + session = get_session() -def key_pair_get_all_by_user(_context, user_id): + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result + + +@require_context +def key_pair_get_all_by_user(context, user_id): + authorize_user_context(context, user_id) session = get_session() return session.query(models.KeyPair ).filter_by(user_id=user_id @@ -554,11 +776,16 @@ def key_pair_get_all_by_user(_context, user_id): ################### -def network_count(_context): - return models.Network.count() +@require_admin_context +def network_count(context): + session = get_session() + return session.query(models.Network + ).filter_by(deleted=can_read_deleted(context) + ).count() -def network_count_allocated_ips(_context, network_id): +@require_admin_context +def network_count_allocated_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -567,7 +794,8 @@ def network_count_allocated_ips(_context, network_id): ).count() -def network_count_available_ips(_context, network_id): +@require_admin_context +def network_count_available_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -577,7 +805,8 @@ def network_count_available_ips(_context, network_id): ).count() -def network_count_reserved_ips(_context, network_id): +@require_admin_context +def network_count_reserved_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -586,7 +815,8 @@ def network_count_reserved_ips(_context, network_id): ).count() -def network_create(_context, values): +@require_admin_context +def network_create(context, values): network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -594,7 +824,8 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +@require_admin_context +def network_destroy(context, network_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -612,14 +843,34 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +@require_admin_context +def network_get_associated_fixed_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -629,18 +880,22 @@ def network_get_associated_fixed_ips(_context, network_id): ).all() -def network_get_by_bridge(_context, bridge): +@require_admin_context +def network_get_by_bridge(context, bridge): session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + + return result -def network_get_index(_context, network_id): +@require_admin_context +def network_get_index(context, network_id): session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -648,19 +903,28 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +@require_admin_context +def network_index_count(context): + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=can_read_deleted(context) + ).count() -def network_index_create_safe(_context, values): +@require_admin_context +def network_index_create_safe(context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -670,29 +934,32 @@ def network_index_create_safe(_context, values): pass -def network_set_host(_context, network_id, host_id): +@require_admin_context +def network_set_host(context, network_id, host_id): session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + return network_ref['host'] -def network_update(_context, network_id, values): + +@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -701,15 +968,18 @@ def network_update(_context, network_id, values): ################### -def project_get_network(_context, project_id): +@require_context +def project_get_network(context, project_id): session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -719,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +@require_admin_context +def export_device_count(context): + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=can_read_deleted(context) + ).count() -def export_device_create(_context, values): +@require_admin_context +def export_device_create(context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -760,7 +1036,23 @@ def auth_create_token(_context, token): ################### -def quota_create(_context, values): +@require_admin_context +def quota_get(context, project_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +@require_admin_context +def quota_create(context, values): quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -768,30 +1060,29 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - - -def quota_update(_context, project_id, values): +@require_admin_context +def quota_update(context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +@require_admin_context +def quota_destroy(context, project_id): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### -def volume_allocate_shelf_and_blade(_context, volume_id): +@require_admin_context +def volume_allocate_shelf_and_blade(context, volume_id): session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -808,19 +1099,20 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +@require_admin_context +def volume_attached(context, volume_id, instance_id, mountpoint): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) -def volume_create(_context, values): +@require_context +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value @@ -829,13 +1121,14 @@ def volume_create(_context, values): with session.begin(): while volume_ref.ec2_id == None: ec2_id = utils.generate_uid(volume_ref.__prefix__) - if not volume_ec2_id_exists(_context, ec2_id, session=session): + if not volume_ec2_id_exists(context, ec2_id, session=session): volume_ref.ec2_id = ec2_id volume_ref.save(session=session) return volume_ref -def volume_data_get_for_project(_context, project_id): +@require_admin_context +def volume_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -846,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +@require_admin_context +def volume_destroy(context, volume_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -857,10 +1151,11 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +@require_admin_context +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -868,60 +1163,113 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) +@require_context +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) + return result -def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) +@require_admin_context +def volume_get_all(context): + return session.query(models.Volume + ).filter_by(deleted=can_read_deleted(context) + ).all() +@require_context def volume_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) + session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() +@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() - volume_ref = session.query(models.Volume + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=False ).first() - if not volume_ref: - raise exception.NotFound('Volume %s not found' % (ec2_id)) + else: + raise exception.NotAuthorized() - return volume_ref + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + return result + +@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Volume.id==ec2_id)).one()[0] + + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -def volume_get_instance(_context, volume_id): +@require_admin_context +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): + +@require_admin_context +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: - raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + if not result: + raise exception.NotFound('No export device found for volume %s' % + volume_id) + + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): +@require_context +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..1837a7584 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -50,44 +50,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" @@ -176,21 +138,6 @@ class Service(BASE, NovaBase): report_count = Column(Integer, nullable=False, default=0) disabled = Column(Boolean, default=False) - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Instance(BASE, NovaBase): """Represents a guest vm""" @@ -199,6 +146,8 @@ class Instance(BASE, NovaBase): id = Column(Integer, primary_key=True) ec2_id = Column(String(10), unique=True) + admin_pass = Column(String(255)) + user_id = Column(String(255)) project_id = Column(String(255)) @@ -282,7 +231,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -313,18 +266,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" @@ -333,8 +274,11 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') class KeyPair(BASE, NovaBase): @@ -352,26 +296,6 @@ class KeyPair(BASE, NovaBase): def str_id(self): return '%s.%s' % (self.user_id, self.name) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - user_id, _sep, name = str_id.partition('.') - return cls.find_by_str(user_id, name, session, deleted) - - @classmethod - def find_by_args(cls, user_id, name, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for user %s, name %s" % - (user_id, name)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Network(BASE, NovaBase): """Represents a network""" @@ -407,8 +331,12 @@ class NetworkIndex(BASE, NovaBase): id = Column(Integer, primary_key=True) index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + class AuthToken(BASE, NovaBase): """Represents an authorization token for all API transactions. Fields @@ -432,8 +360,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -442,19 +373,6 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" @@ -462,24 +380,14 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" diff --git a/nova/manager.py b/nova/manager.py index 94e4ae959..56ba7d3f6 100644 --- a/nova/manager.py +++ b/nova/manager.py @@ -22,6 +22,7 @@ Base class for managers of different parts of the system from nova import utils from nova import flags +from twisted.internet import defer FLAGS = flags.FLAGS flags.DEFINE_string('db_driver', 'nova.db.api', @@ -38,6 +39,11 @@ class Manager(object): db_driver = FLAGS.db_driver self.db = utils.import_object(db_driver) # pylint: disable-msg=C0103 + @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield + def init_host(self): """Do any initialization that needs to be run if this is a standalone service. diff --git a/nova/network/manager.py b/nova/network/manager.py index 62133ae92..c77062389 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -20,10 +20,12 @@ Network Hosts are responsible for allocating ips and setting up network """ +import datetime import logging import math import IPy +from twisted.internet import defer from nova import db from nova import exception @@ -62,7 +64,9 @@ flags.DEFINE_integer('cnt_vpn_clients', 5, flags.DEFINE_string('network_driver', 'nova.network.linux_net', 'Driver to use for network creation') flags.DEFINE_bool('update_dhcp_on_disassociate', False, - 'Whether to update dhcp when fixed_ip is disassocated') + 'Whether to update dhcp when fixed_ip is disassociated') +flags.DEFINE_integer('fixed_ip_disassociate_timeout', 600, + 'Seconds after which a deallocated ip is disassociated') class AddressAlreadyAllocated(exception.Error): @@ -94,7 +98,7 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) @@ -225,6 +229,19 @@ class FlatManager(NetworkManager): class VlanManager(NetworkManager): """Vlan network with dhcp""" + @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield super(VlanManager, self).periodic_tasks(context) + now = datetime.datetime.utcnow() + timeout = FLAGS.fixed_ip_disassociate_timeout + time = now - datetime.timedelta(seconds=timeout) + num = self.db.fixed_ip_disassociate_all_by_timeout(self, + self.host, + time) + if num: + logging.debug("Dissassociated %s stale fixed ip(s)", num) + def init_host(self): """Do any initialization that needs to be run if this is a standalone service. @@ -239,7 +256,7 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) @@ -249,14 +266,6 @@ class VlanManager(NetworkManager): """Returns a fixed ip to the pool""" self.db.fixed_ip_update(context, address, {'allocated': False}) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['leased']: - self.db.fixed_ip_disassociate(context, address) - # NOTE(vish): dhcp server isn't updated until next setup, this - # means there will stale entries in the conf file - # the code below will update the file if necessary - if FLAGS.update_dhcp_on_disassociate: - network_ref = self.db.fixed_ip_get_network(context, address) - self.driver.update_dhcp(context, network_ref['id']) def setup_fixed_ip(self, context, address): @@ -273,9 +282,6 @@ class VlanManager(NetworkManager): """Called by dhcp-bridge when ip is leased""" logging.debug("Leasing IP %s", address) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['allocated']: - logging.warn("IP %s leased that was already deallocated", address) - return instance_ref = fixed_ip_ref['instance'] if not instance_ref: raise exception.Error("IP %s leased that isn't associated" % @@ -286,14 +292,13 @@ class VlanManager(NetworkManager): self.db.fixed_ip_update(context, fixed_ip_ref['address'], {'leased': True}) + if not fixed_ip_ref['allocated']: + logging.warn("IP %s leased that was already deallocated", address) def release_fixed_ip(self, context, mac, address): """Called by dhcp-bridge when ip is released""" logging.debug("Releasing IP %s", address) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['leased']: - logging.warn("IP %s released that was not leased", address) - return instance_ref = fixed_ip_ref['instance'] if not instance_ref: raise exception.Error("IP %s released that isn't associated" % @@ -301,7 +306,11 @@ class VlanManager(NetworkManager): if instance_ref['mac_address'] != mac: raise exception.Error("IP %s released from bad mac %s vs %s" % (address, instance_ref['mac_address'], mac)) - self.db.fixed_ip_update(context, address, {'leased': False}) + if not fixed_ip_ref['leased']: + logging.warn("IP %s released that was not leased", address) + self.db.fixed_ip_update(context, + fixed_ip_ref['str_id'], + {'leased': False}) if not fixed_ip_ref['allocated']: self.db.fixed_ip_disassociate(context, address) # NOTE(vish): dhcp server isn't updated until next setup, this diff --git a/nova/objectstore/__init__.py b/nova/objectstore/__init__.py index b8890ac03..ecad9be7c 100644 --- a/nova/objectstore/__init__.py +++ b/nova/objectstore/__init__.py @@ -22,7 +22,7 @@ .. automodule:: nova.objectstore :platform: Unix - :synopsis: Currently a trivial file-based system, getting extended w/ mongo. + :synopsis: Currently a trivial file-based system, getting extended w/ swift. .. moduleauthor:: Jesse Andrews <jesse@ansolabs.com> .. moduleauthor:: Devin Carlen <devin.carlen@gmail.com> .. moduleauthor:: Vishvananda Ishaya <vishvananda@yahoo.com> diff --git a/nova/service.py b/nova/service.py index dcd2a09ef..a6c186896 100644 --- a/nova/service.py +++ b/nova/service.py @@ -37,7 +37,11 @@ from nova import utils FLAGS = flags.FLAGS flags.DEFINE_integer('report_interval', 10, - 'seconds between nodes reporting state to cloud', + 'seconds between nodes reporting state to datastore', + lower_bound=1) + +flags.DEFINE_integer('periodic_interval', 60, + 'seconds between running periodic tasks', lower_bound=1) @@ -81,7 +85,8 @@ class Service(object, service.Service): binary=None, topic=None, manager=None, - report_interval=None): + report_interval=None, + periodic_interval=None): """Instantiates class and passes back application object. Args: @@ -90,6 +95,7 @@ class Service(object, service.Service): topic, defaults to bin_name - "nova-" part manager, defaults to FLAGS.<topic>_manager report_interval, defaults to FLAGS.report_interval + periodic_interval, defaults to FLAGS.periodic_interval """ if not host: host = FLAGS.host @@ -101,6 +107,8 @@ class Service(object, service.Service): manager = FLAGS.get('%s_manager' % topic, None) if not report_interval: report_interval = FLAGS.report_interval + if not periodic_interval: + periodic_interval = FLAGS.periodic_interval logging.warn("Starting %s node", topic) service_obj = cls(host, binary, topic, manager) conn = rpc.Connection.instance() @@ -113,11 +121,14 @@ class Service(object, service.Service): topic='%s.%s' % (topic, host), proxy=service_obj) + consumer_all.attach_to_twisted() + consumer_node.attach_to_twisted() + pulse = task.LoopingCall(service_obj.report_state) pulse.start(interval=report_interval, now=False) - consumer_all.attach_to_twisted() - consumer_node.attach_to_twisted() + pulse = task.LoopingCall(service_obj.periodic_tasks) + pulse.start(interval=periodic_interval, now=False) # This is the parent service that twistd will be looking for when it # parses this file, return it so that we can get it into globals. @@ -133,6 +144,11 @@ class Service(object, service.Service): logging.warn("Service killed that has no database entry") @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield self.manager.periodic_tasks(context) + + @defer.inlineCallbacks def report_state(self, context=None): """Update the state of this service in the datastore.""" try: diff --git a/nova/tests/api/rackspace/__init__.py b/nova/tests/api/rackspace/__init__.py index 622cb4335..bfd0f87a7 100644 --- a/nova/tests/api/rackspace/__init__.py +++ b/nova/tests/api/rackspace/__init__.py @@ -17,6 +17,7 @@ import unittest +from nova.api.rackspace import limited from nova.api.rackspace import RateLimitingMiddleware from nova.tests.api.test_helper import * from webob import Request @@ -77,3 +78,31 @@ class RateLimitingMiddlewareTest(unittest.TestCase): self.assertEqual(middleware.limiter.__class__.__name__, "Limiter") middleware = RateLimitingMiddleware(APIStub(), service_host='foobar') self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy") + + +class LimiterTest(unittest.TestCase): + + def testLimiter(self): + items = range(2000) + req = Request.blank('/') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=3') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=2005') + self.assertEqual(limited(items, req), []) + req = Request.blank('/?limit=10') + self.assertEqual(limited(items, req), items[ :10]) + req = Request.blank('/?limit=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?limit=3000') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=1&limit=3') + self.assertEqual(limited(items, req), items[1:4]) + req = Request.blank('/?offset=3&limit=0') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3&limit=1500') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3000&limit=10') + self.assertEqual(limited(items, req), []) diff --git a/nova/tests/api/rackspace/auth.py b/nova/tests/api/rackspace/auth.py index a6e10970f..56677c2f4 100644 --- a/nova/tests/api/rackspace/auth.py +++ b/nova/tests/api/rackspace/auth.py @@ -1,12 +1,14 @@ -import webob -import webob.dec +import datetime import unittest + import stubout +import webob +import webob.dec + import nova.api import nova.api.rackspace.auth from nova import auth from nova.tests.api.rackspace import test_helper -import datetime class Test(unittest.TestCase): def setUp(self): diff --git a/nova/tests/api/rackspace/flavors.py b/nova/tests/api/rackspace/flavors.py index 7bd1ea1c4..d25a2e2be 100644 --- a/nova/tests/api/rackspace/flavors.py +++ b/nova/tests/api/rackspace/flavors.py @@ -38,7 +38,6 @@ class FlavorsTest(unittest.TestCase): def test_get_flavor_list(self): req = webob.Request.blank('/v1.0/flavors') res = req.get_response(nova.api.API()) - print res def test_get_flavor_by_id(self): pass diff --git a/nova/tests/api/rackspace/images.py b/nova/tests/api/rackspace/images.py index 560d8c898..4c9987e8b 100644 --- a/nova/tests/api/rackspace/images.py +++ b/nova/tests/api/rackspace/images.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +import stubout import unittest from nova.api.rackspace import images diff --git a/nova/tests/api/rackspace/servers.py b/nova/tests/api/rackspace/servers.py index 9fd8e5e88..69ad2c1d3 100644 --- a/nova/tests/api/rackspace/servers.py +++ b/nova/tests/api/rackspace/servers.py @@ -26,6 +26,7 @@ import nova.api.rackspace from nova.api.rackspace import servers import nova.db.api from nova.db.sqlalchemy.models import Instance +import nova.rpc from nova.tests.api.test_helper import * from nova.tests.api.rackspace import test_helper @@ -52,8 +53,11 @@ class ServersTest(unittest.TestCase): test_helper.stub_for_testing(self.stubs) test_helper.stub_out_rate_limiting(self.stubs) test_helper.stub_out_auth(self.stubs) + test_helper.stub_out_id_translator(self.stubs) + test_helper.stub_out_key_pair_funcs(self.stubs) + test_helper.stub_out_image_service(self.stubs) self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) - self.stubs.Set(nova.db.api, 'instance_get', return_server) + self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server) self.stubs.Set(nova.db.api, 'instance_get_all_by_user', return_servers) @@ -67,9 +71,6 @@ class ServersTest(unittest.TestCase): self.assertEqual(res_dict['server']['id'], '1') self.assertEqual(res_dict['server']['name'], 'server1') - def test_get_backup_schedule(self): - pass - def test_get_server_list(self): req = webob.Request.blank('/v1.0/servers') res = req.get_response(nova.api.API()) @@ -82,24 +83,86 @@ class ServersTest(unittest.TestCase): self.assertEqual(s.get('imageId', None), None) i += 1 - #def test_create_instance(self): - # test_helper.stub_out_image_translator(self.stubs) - # body = dict(server=dict( - # name='server_test', imageId=2, flavorId=2, metadata={}, - # personality = {} - # )) - # req = webob.Request.blank('/v1.0/servers') - # req.method = 'POST' - # req.body = json.dumps(body) + def test_create_instance(self): + def server_update(context, id, params): + pass + + def instance_create(context, inst): + class Foo(object): + ec2_id = 1 + return Foo() + + def fake_method(*args, **kwargs): + pass + + def project_get_network(context, user_id): + return dict(id='1', host='localhost') + + def queue_get_for(context, *args): + return 'network_topic' + + self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) + self.stubs.Set(nova.db.api, 'instance_create', instance_create) + self.stubs.Set(nova.rpc, 'cast', fake_method) + self.stubs.Set(nova.rpc, 'call', fake_method) + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) + self.stubs.Set(nova.network.manager.FlatManager, 'allocate_fixed_ip', + fake_method) + + test_helper.stub_out_id_translator(self.stubs) + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers') + req.method = 'POST' + req.body = json.dumps(body) + + res = req.get_response(nova.api.API()) + + self.assertEqual(res.status_int, 200) + + def test_update_no_body(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status_int, 422) + + def test_update_bad_params(self): + """ Confirm that update is filtering params """ + inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) - # res = req.get_response(nova.api.API()) + def server_update(context, id, params): + self.update_called = True + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) - # print res - def test_update_server_password(self): - pass + self.stubs.Set(nova.db.api, 'instance_update', + server_update) - def test_update_server_name(self): - pass + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_update_server(self): + inst_dict = dict(name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) def test_create_backup_schedules(self): req = webob.Request.blank('/v1.0/servers/1/backup_schedules') diff --git a/nova/tests/api/rackspace/sharedipgroups.py b/nova/tests/api/rackspace/sharedipgroups.py index b4b281db7..1906b54f5 100644 --- a/nova/tests/api/rackspace/sharedipgroups.py +++ b/nova/tests/api/rackspace/sharedipgroups.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +import stubout import unittest from nova.api.rackspace import sharedipgroups diff --git a/nova/tests/api/rackspace/test_helper.py b/nova/tests/api/rackspace/test_helper.py index aa7fb382c..2cf154f63 100644 --- a/nova/tests/api/rackspace/test_helper.py +++ b/nova/tests/api/rackspace/test_helper.py @@ -9,6 +9,7 @@ from nova import utils from nova import flags import nova.api.rackspace.auth import nova.api.rackspace._id_translator +from nova.image import service from nova.wsgi import Router FLAGS = flags.FLAGS @@ -40,7 +41,19 @@ def fake_wsgi(self, req): req.environ['inst_dict'] = json.loads(req.body) return self.application -def stub_out_image_translator(stubs): +def stub_out_key_pair_funcs(stubs): + def key_pair(context, user_id): + return [dict(name='key', public_key='public_key')] + stubs.Set(nova.db.api, 'key_pair_get_all_by_user', + key_pair) + +def stub_out_image_service(stubs): + def fake_image_show(meh, id): + return dict(kernelId=1, ramdiskId=1) + + stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) + +def stub_out_id_translator(stubs): class FakeTranslator(object): def __init__(self, id_type, service_name): pass diff --git a/nova/tests/api/rackspace/testfaults.py b/nova/tests/api/rackspace/testfaults.py new file mode 100644 index 000000000..b2931bc98 --- /dev/null +++ b/nova/tests/api/rackspace/testfaults.py @@ -0,0 +1,40 @@ +import unittest +import webob +import webob.dec +import webob.exc + +from nova.api.rackspace import faults + +class TestFaults(unittest.TestCase): + + def test_fault_parts(self): + req = webob.Request.blank('/.xml') + f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) + resp = req.get_response(f) + + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['<badRequest', 'code="400">']) + body_without_spaces = ''.join(resp.body.split()) + self.assertTrue('<message>scram</message>' in body_without_spaces) + + def test_retry_header(self): + req = webob.Request.blank('/.xml') + exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', + headers={'Retry-After': 4}) + f = faults.Fault(exc) + resp = req.get_response(f) + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['<overLimit', 'code="413">']) + body_sans_spaces = ''.join(resp.body.split()) + self.assertTrue('<message>sorry</message>' in body_sans_spaces) + self.assertTrue('<retryAfter>4</retryAfter>' in body_sans_spaces) + self.assertEqual(resp.headers['Retry-After'], 4) + + def test_raise(self): + @webob.dec.wsgify + def raiser(req): + raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) + req = webob.Request.blank('/.xml') + resp = req.get_response(raiser) + self.assertEqual(resp.status_int, 404) + self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index da65b50a2..5370966d2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,12 +56,12 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: @@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase): address = self._create_address(0) self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) lease_ip(address) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) # Doesn't go away until it's dhcp released self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) @@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py index 01da0eb8a..06f80e82c 100644 --- a/nova/tests/service_unittest.py +++ b/nova/tests/service_unittest.py @@ -65,15 +65,20 @@ class ServiceTestCase(test.BaseTestCase): proxy=mox.IsA(service.Service)).AndReturn( rpc.AdapterConsumer) + rpc.AdapterConsumer.attach_to_twisted() + rpc.AdapterConsumer.attach_to_twisted() + # Stub out looping call a bit needlessly since we don't have an easy # way to cancel it (yet) when the tests finishes service.task.LoopingCall(mox.IgnoreArg()).AndReturn( service.task.LoopingCall) service.task.LoopingCall.start(interval=mox.IgnoreArg(), now=mox.IgnoreArg()) + service.task.LoopingCall(mox.IgnoreArg()).AndReturn( + service.task.LoopingCall) + service.task.LoopingCall.start(interval=mox.IgnoreArg(), + now=mox.IgnoreArg()) - rpc.AdapterConsumer.attach_to_twisted() - rpc.AdapterConsumer.attach_to_twisted() service_create = {'host': host, 'binary': binary, 'topic': topic, diff --git a/nova/virt/xenapi.py b/nova/virt/xenapi.py index 1c6de4403..0d06b1fce 100644 --- a/nova/virt/xenapi.py +++ b/nova/virt/xenapi.py @@ -103,8 +103,8 @@ class XenAPIConnection(object): self._conn.login_with_password(user, pw) def list_instances(self): - result = [self._conn.xenapi.VM.get_name_label(vm) \ - for vm in self._conn.xenapi.VM.get_all()] + return [self._conn.xenapi.VM.get_name_label(vm) \ + for vm in self._conn.xenapi.VM.get_all()] @defer.inlineCallbacks def spawn(self, instance): diff --git a/nova/wsgi.py b/nova/wsgi.py index da9374542..b91d91121 100644 --- a/nova/wsgi.py +++ b/nova/wsgi.py @@ -230,6 +230,15 @@ class Controller(object): serializer = Serializer(request.environ, _metadata) return serializer.to_content_type(data) + def _deserialize(self, data, request): + """ + Deserialize the request body to the response type requested in request. + Uses self._serialization_metadata if it exists, which is a dict mapping + MIME types to information needed to serialize to that type. + """ + _metadata = getattr(type(self), "_serialization_metadata", {}) + serializer = Serializer(request.environ, _metadata) + return serializer.deserialize(data) class Serializer(object): """ @@ -272,10 +281,13 @@ class Serializer(object): The string must be in the format of a supported MIME type. """ datastring = datastring.strip() - is_xml = (datastring[0] == '<') - if not is_xml: - return json.loads(datastring) - return self._from_xml(datastring) + try: + is_xml = (datastring[0] == '<') + if not is_xml: + return json.loads(datastring) + return self._from_xml(datastring) + except: + return None def _from_xml(self, datastring): xmldata = self.metadata.get('application/xml', {}) |
