diff options
| author | Kevin L. Mitchell <kevin.mitchell@rackspace.com> | 2012-01-13 14:13:59 -0600 |
|---|---|---|
| committer | Kevin L. Mitchell <kevin.mitchell@rackspace.com> | 2012-01-13 14:13:59 -0600 |
| commit | 1d4e35be694884a0ea8e586ffb2d06ecd6c48685 (patch) | |
| tree | 66e310d4ade4e53a66195d1e4c11cf1ab3168935 | |
| parent | 6c898e6abf44caa176790e9cd4505aeed145397c (diff) | |
| download | nova-1d4e35be694884a0ea8e586ffb2d06ecd6c48685.tar.gz nova-1d4e35be694884a0ea8e586ffb2d06ecd6c48685.tar.xz nova-1d4e35be694884a0ea8e586ffb2d06ecd6c48685.zip | |
Refactor request and action extensions.
The goal of this refactoring is to eventually eliminate
ExtensionMiddleware and LazySerializationMiddleware completely, by
executing extensions directly within the processing done by
Resource.__call__(). This patch implements the infrastructure
required to perform this extension processing.
Implements blueprint extension-refactor.
Change-Id: I23398fc906a9a105de354a8133337ecfc69a3ad3
| -rw-r--r-- | nova/api/openstack/compute/__init__.py | 56 | ||||
| -rw-r--r-- | nova/api/openstack/compute/servers.py | 101 | ||||
| -rw-r--r-- | nova/api/openstack/extensions.py | 35 | ||||
| -rw-r--r-- | nova/api/openstack/volume/__init__.py | 43 | ||||
| -rw-r--r-- | nova/api/openstack/wsgi.py | 311 | ||||
| -rw-r--r-- | nova/tests/api/openstack/compute/test_extensions.py | 135 | ||||
| -rw-r--r-- | nova/tests/api/openstack/compute/test_server_actions.py | 115 | ||||
| -rw-r--r-- | nova/tests/api/openstack/test_wsgi.py | 347 |
8 files changed, 998 insertions, 145 deletions
diff --git a/nova/api/openstack/compute/__init__.py b/nova/api/openstack/compute/__init__.py index 2f6e92a42..69b8daab7 100644 --- a/nova/api/openstack/compute/__init__.py +++ b/nova/api/openstack/compute/__init__.py @@ -67,8 +67,10 @@ class APIRouter(base_wsgi.Router): ext_mgr = extensions.ExtensionManager() mapper = nova.api.openstack.ProjectMapper() + self.resources = {} self._setup_routes(mapper) self._setup_ext_routes(mapper, ext_mgr) + self._setup_extensions(ext_mgr) super(APIRouter, self).__init__(mapper) def _setup_ext_routes(self, mapper, ext_mgr): @@ -76,10 +78,12 @@ class APIRouter(base_wsgi.Router): LOG.debug(_('Extended resource: %s'), resource.collection) + wsgi_resource = wsgi.Resource( + resource.controller, resource.deserializer, + resource.serializer) + self.resources[resource.collection] = wsgi_resource kargs = dict( - controller=wsgi.Resource( - resource.controller, resource.deserializer, - resource.serializer), + controller=wsgi_resource, collection=resource.collection_actions, member=resource.member_actions) @@ -88,39 +92,66 @@ class APIRouter(base_wsgi.Router): mapper.resource(resource.collection, resource.collection, **kargs) + def _setup_extensions(self, ext_mgr): + for extension in ext_mgr.get_controller_extensions(): + ext_name = extension.extension.name + collection = extension.collection + controller = extension.controller + + if collection not in self.resources: + LOG.warning(_('Extension %(ext_name)s: Cannot extend ' + 'resource %(collection)s: No such resource') % + locals()) + continue + + LOG.debug(_('Extension %(ext_name)s extending resource: ' + '%(collection)s') % locals()) + + resource = self.resources[collection] + resource.register_actions(controller) + resource.register_extensions(controller) + def _setup_routes(self, mapper): + self.resources['versions'] = versions.create_resource() mapper.connect("versions", "/", - controller=versions.create_resource(), + controller=self.resources['versions'], action='show') mapper.redirect("", "/") + self.resources['consoles'] = consoles.create_resource() mapper.resource("console", "consoles", - controller=consoles.create_resource(), + controller=self.resources['consoles'], parent_resource=dict(member_name='server', collection_name='servers')) + self.resources['servers'] = servers.create_resource() mapper.resource("server", "servers", - controller=servers.create_resource(), + controller=self.resources['servers'], collection={'detail': 'GET'}, member={'action': 'POST'}) - mapper.resource("ip", "ips", controller=ips.create_resource(), + self.resources['ips'] = ips.create_resource() + mapper.resource("ip", "ips", controller=self.resources['ips'], parent_resource=dict(member_name='server', collection_name='servers')) + self.resources['images'] = images.create_resource() mapper.resource("image", "images", - controller=images.create_resource(), + controller=self.resources['images'], collection={'detail': 'GET'}) + self.resources['limits'] = limits.create_resource() mapper.resource("limit", "limits", - controller=limits.create_resource()) + controller=self.resources['limits']) + self.resources['flavors'] = flavors.create_resource() mapper.resource("flavor", "flavors", - controller=flavors.create_resource(), + controller=self.resources['flavors'], collection={'detail': 'GET'}) - image_metadata_controller = image_metadata.create_resource() + self.resources['image_metadata'] = image_metadata.create_resource() + image_metadata_controller = self.resources['image_metadata'] mapper.resource("image_meta", "metadata", controller=image_metadata_controller, @@ -132,7 +163,8 @@ class APIRouter(base_wsgi.Router): action='update_all', conditions={"method": ['PUT']}) - server_metadata_controller = server_metadata.create_resource() + self.resources['server_metadata'] = server_metadata.create_resource() + server_metadata_controller = self.resources['server_metadata'] mapper.resource("server_meta", "metadata", controller=server_metadata_controller, diff --git a/nova/api/openstack/compute/servers.py b/nova/api/openstack/compute/servers.py index 60b85f591..71dc1e2ae 100644 --- a/nova/api/openstack/compute/servers.py +++ b/nova/api/openstack/compute/servers.py @@ -792,31 +792,10 @@ class Controller(wsgi.Controller): @wsgi.response(202) @wsgi.serializers(xml=FullServerTemplate) @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('confirmResize') @exception.novaclient_converter @scheduler_api.redirect_handler - def action(self, req, id, body): - """Multi-purpose method used to take actions on a server""" - _actions = { - 'changePassword': self._action_change_password, - 'reboot': self._action_reboot, - 'resize': self._action_resize, - 'confirmResize': self._action_confirm_resize, - 'revertResize': self._action_revert_resize, - 'rebuild': self._action_rebuild, - 'createImage': self._action_create_image, - } - - for key in body: - if key in _actions: - return _actions[key](body, req, id) - else: - msg = _("There is no such server action: %s") % (key,) - raise exc.HTTPBadRequest(explanation=msg) - - msg = _("Invalid request body") - raise exc.HTTPBadRequest(explanation=msg) - - def _action_confirm_resize(self, input_dict, req, id): + def _action_confirm_resize(self, req, id, body): context = req.environ['nova.context'] instance = self._get_server(context, id) try: @@ -832,7 +811,13 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest() return exc.HTTPNoContent() - def _action_revert_resize(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('revertResize') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_revert_resize(self, req, id, body): context = req.environ['nova.context'] instance = self._get_server(context, id) try: @@ -848,10 +833,16 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest() return webob.Response(status_int=202) - def _action_reboot(self, input_dict, req, id): - if 'reboot' in input_dict and 'type' in input_dict['reboot']: + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('reboot') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_reboot(self, req, id, body): + if 'reboot' in body and 'type' in body['reboot']: valid_reboot_types = ['HARD', 'SOFT'] - reboot_type = input_dict['reboot']['type'].upper() + reboot_type = body['reboot']['type'].upper() if not valid_reboot_types.count(reboot_type): msg = _("Argument 'type' for reboot is not HARD or SOFT") LOG.exception(msg) @@ -930,13 +921,19 @@ class Controller(wsgi.Controller): return common.get_id_from_href(flavor_ref) - def _action_change_password(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('changePassword') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_change_password(self, req, id, body): context = req.environ['nova.context'] - if (not 'changePassword' in input_dict - or not 'adminPass' in input_dict['changePassword']): + if (not 'changePassword' in body + or not 'adminPass' in body['changePassword']): msg = _("No adminPass was specified") raise exc.HTTPBadRequest(explanation=msg) - password = input_dict['changePassword']['adminPass'] + password = body['changePassword']['adminPass'] if not isinstance(password, basestring) or password == '': msg = _("Invalid adminPass") raise exc.HTTPBadRequest(explanation=msg) @@ -956,10 +953,16 @@ class Controller(wsgi.Controller): LOG.debug(msg) raise exc.HTTPBadRequest(explanation=msg) - def _action_resize(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('resize') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_resize(self, req, id, body): """ Resizes a given instance to the flavor size requested """ try: - flavor_ref = input_dict["resize"]["flavorRef"] + flavor_ref = body["resize"]["flavorRef"] if not flavor_ref: msg = _("Resize request has invalid 'flavorRef' attribute.") raise exc.HTTPBadRequest(explanation=msg) @@ -969,10 +972,16 @@ class Controller(wsgi.Controller): return self._resize(req, id, flavor_ref) - def _action_rebuild(self, info, request, instance_id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('rebuild') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_rebuild(self, req, id, body): """Rebuild an instance with the given attributes""" try: - body = info['rebuild'] + body = body['rebuild'] except (KeyError, TypeError): raise exc.HTTPBadRequest(_("Invalid request body")) @@ -987,8 +996,8 @@ class Controller(wsgi.Controller): except (KeyError, TypeError): password = utils.generate_password(FLAGS.password_length) - context = request.environ['nova.context'] - instance = self._get_server(context, instance_id) + context = req.environ['nova.context'] + instance = self._get_server(context, id) attr_map = { 'personality': 'files_to_inject', @@ -1025,10 +1034,10 @@ class Controller(wsgi.Controller): msg = _("Instance could not be found") raise exc.HTTPNotFound(explanation=msg) - instance = self._get_server(context, instance_id) + instance = self._get_server(context, id) self._add_instance_faults(context, [instance]) - view = self._view_builder.show(request, instance) + view = self._view_builder.show(req, instance) # Add on the adminPass attribute since the view doesn't do it view['server']['adminPass'] = password @@ -1036,11 +1045,17 @@ class Controller(wsgi.Controller): robj = wsgi.ResponseObject(view) return self._add_location(robj) + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('createImage') + @exception.novaclient_converter + @scheduler_api.redirect_handler @common.check_snapshots_enabled - def _action_create_image(self, input_dict, req, instance_id): + def _action_create_image(self, req, id, body): """Snapshot a server instance.""" context = req.environ['nova.context'] - entity = input_dict.get("createImage", {}) + entity = body.get("createImage", {}) try: image_name = entity["name"] @@ -1054,7 +1069,7 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest(explanation=msg) # preserve link to server in image properties - server_ref = os.path.join(req.application_url, 'servers', instance_id) + server_ref = os.path.join(req.application_url, 'servers', id) props = {'instance_ref': server_ref} metadata = entity.get('metadata', {}) @@ -1065,7 +1080,7 @@ class Controller(wsgi.Controller): msg = _("Invalid metadata") raise exc.HTTPBadRequest(explanation=msg) - instance = self._get_server(context, instance_id) + instance = self._get_server(context, id) try: image = self.compute_api.snapshot(context, diff --git a/nova/api/openstack/extensions.py b/nova/api/openstack/extensions.py index 6c49e8ace..669bc699a 100644 --- a/nova/api/openstack/extensions.py +++ b/nova/api/openstack/extensions.py @@ -97,6 +97,14 @@ class ExtensionDescriptor(object): request_exts = [] return request_exts + def get_controller_extensions(self): + """List of extensions.ControllerExtension extension objects. + + Controller extensions are used to extend existing controllers. + """ + controller_exts = [] + return controller_exts + @classmethod def nsmap(cls): """Synthesize a namespace map from extension.""" @@ -441,6 +449,18 @@ class ExtensionManager(object): pass return request_exts + def get_controller_extensions(self): + """Returns a list of ControllerExtension objects.""" + controller_exts = [] + for ext in self.extensions.values(): + try: + controller_exts.extend(ext.get_controller_extensions()) + except AttributeError: + # NOTE(Vek): Extensions aren't required to have + # controller extensions + pass + return controller_exts + def _check_extension(self, extension): """Checks for required methods in extension objects.""" try: @@ -492,6 +512,20 @@ class ExtensionManager(object): '%(exc)s') % locals()) +class ControllerExtension(object): + """Extend core controllers of nova OpenStack API. + + Provide a way to extend existing nova OpenStack API core + controllers. + """ + + def __init__(self, extension, collection, controller): + self.extension = extension + self.collection = collection + self.controller = controller + + +@utils.deprecated("Superseded by ControllerExtension") class RequestExtension(object): """Extend requests and responses of core nova OpenStack API resources. @@ -507,6 +541,7 @@ class RequestExtension(object): self.pre_handler = pre_handler +@utils.deprecated("Superseded by ControllerExtension") class ActionExtension(object): """Add custom actions to core nova OpenStack API resources.""" diff --git a/nova/api/openstack/volume/__init__.py b/nova/api/openstack/volume/__init__.py index 075b53c29..d83725a1a 100644 --- a/nova/api/openstack/volume/__init__.py +++ b/nova/api/openstack/volume/__init__.py @@ -56,8 +56,10 @@ class APIRouter(base_wsgi.Router): ext_mgr = extensions.ExtensionManager() mapper = nova.api.openstack.ProjectMapper() + self.resources = {} self._setup_routes(mapper) self._setup_ext_routes(mapper, ext_mgr) + self._setup_extensions(ext_mgr) super(APIRouter, self).__init__(mapper) def _setup_ext_routes(self, mapper, ext_mgr): @@ -66,13 +68,13 @@ class APIRouter(base_wsgi.Router): for resource in ext_mgr.get_resources(): LOG.debug(_('Extended resource: %s'), resource.collection) - if resource.serializer is None: - resource.serializer = serializer + wsgi_resource = wsgi.Resource( + resource.controller, resource.deserializer, + resource.serializer) + self.resources[resource.collection] = wsgi_resource kargs = dict( - controller=wsgi.Resource( - resource.controller, resource.deserializer, - resource.serializer), + controller=wsgi_resource, collection=resource.collection_actions, member=resource.member_actions) @@ -81,19 +83,42 @@ class APIRouter(base_wsgi.Router): mapper.resource(resource.collection, resource.collection, **kargs) + def _setup_extensions(self, ext_mgr): + for extension in ext_mgr.get_controller_extensions(): + ext_name = extension.extension.name + collection = extension.collection + controller = extension.controller + + if collection not in self.resources: + LOG.warning(_('Extension %(ext_name)s: Cannot extend ' + 'resource %(collection)s: No such resource') % + locals()) + continue + + LOG.debug(_('Extension %(ext_name)s extending resource: ' + '%(collection)s') % locals()) + + resource = self.resources[collection] + resource.register_actions(controller) + resource.register_extensions(controller) + def _setup_routes(self, mapper): + self.resources['versions'] = versions.create_resource() mapper.connect("versions", "/", - controller=versions.create_resource(), + controller=self.resources['versions'], action='show') mapper.redirect("", "/") + self.resources['volumes'] = volumes.create_resource() mapper.resource("volume", "volumes", - controller=volumes.create_resource(), + controller=self.resources['volumes'], collection={'detail': 'GET'}) + self.resources['types'] = types.create_resource() mapper.resource("type", "types", - controller=types.create_resource()) + controller=self.resources['types']) + self.resources['snapshots'] = snapshots.create_resource() mapper.resource("snapshot", "snapshots", - controller=snapshots.create_resource()) + controller=self.resources['snapshots']) diff --git a/nova/api/openstack/wsgi.py b/nova/api/openstack/wsgi.py index f8790d9f0..defb26e6e 100644 --- a/nova/api/openstack/wsgi.py +++ b/nova/api/openstack/wsgi.py @@ -15,11 +15,13 @@ # License for the specific language governing permissions and limitations # under the License. +import inspect from xml.dom import minidom from xml.parsers import expat from lxml import etree import webob +from webob import exc from nova import exception from nova import log as logging @@ -700,6 +702,65 @@ class ResponseObject(object): return self._headers.copy() +def action_peek_json(body): + """Determine action to invoke.""" + + try: + decoded = utils.loads(body) + except ValueError: + msg = _("cannot understand JSON") + raise exception.MalformedRequestBody(reason=msg) + + # Make sure there's exactly one key... + if len(decoded) != 1: + msg = _("too many body keys") + raise exception.MalformedRequestBody(reason=msg) + + # Return the action and the decoded body... + return decoded.keys()[0] + + +def action_peek_xml(body): + """Determine action to invoke.""" + + dom = minidom.parseString(body) + action_node = dom.childNodes[0] + + return action_node.tagName + + +class ResourceExceptionHandler(object): + """Context manager to handle Resource exceptions. + + Used when processing exceptions generated by API implementation + methods (or their extensions). Converts most exceptions to Fault + exceptions, with the appropriate logging. + """ + + def __enter__(self): + return None + + def __exit__(self, ex_type, ex_value, ex_traceback): + if not ex_value: + return True + + if isinstance(ex_value, exception.NotAuthorized): + msg = unicode(ex_value) + raise Fault(webob.exc.HTTPUnauthorized(explanation=msg)) + elif isinstance(ex_value, TypeError): + LOG.exception(ex_value) + raise Fault(webob.exc.HTTPBadRequest()) + elif isinstance(ex_value, Fault): + LOG.info(_("Fault thrown: %s"), unicode(ex_value)) + raise ex_value + elif isinstance(ex_value, webob.exc.HTTPException): + LOG.info(_("HTTP exception thrown: %s"), unicode(ex_value)) + raise Fault(ex_value) + + # We didn't handle the exception + return False + + class Resource(wsgi.Application): """WSGI app that handles (de)serialization and controller dispatch. @@ -717,15 +778,17 @@ class Resource(wsgi.Application): """ def __init__(self, controller, deserializer=None, serializer=None, - **deserializers): + action_peek=None, **deserializers): """ :param controller: object that implement methods created by routes lib :param deserializer: object that can serialize the output of a controller into a webob response :param serializer: object that can deserialize a webob request into necessary pieces - + :param action_peek: dictionary of routines for peeking into an action + request body to determine the desired action """ + self.controller = controller self.deserializer = deserializer self.serializer = serializer @@ -738,6 +801,45 @@ class Resource(wsgi.Application): self.default_serializers = dict(xml=XMLDictSerializer, json=JSONDictSerializer) + self.action_peek = dict(xml=action_peek_xml, + json=action_peek_json) + self.action_peek.update(action_peek or {}) + + # Copy over the actions dictionary + self.wsgi_actions = {} + if controller: + self.register_actions(controller) + + # Save a mapping of extensions + self.wsgi_extensions = {} + self.wsgi_action_extensions = {} + + def register_actions(self, controller): + """Registers controller actions with this resource.""" + + actions = getattr(controller, 'wsgi_actions', {}) + for key, method_name in actions.items(): + self.wsgi_actions[key] = getattr(controller, method_name) + + def register_extensions(self, controller): + """Registers controller extensions with this resource.""" + + extensions = getattr(controller, 'wsgi_extensions', []) + for method_name, action_name in extensions: + # Look up the extending method + extension = getattr(controller, method_name) + + if action_name: + # Extending an action... + if action_name not in self.wsgi_action_extensions: + self.wsgi_action_extensions[action_name] = [] + self.wsgi_action_extensions[action_name].append(extension) + else: + # Extending a regular method + if method_name not in self.wsgi_extensions: + self.wsgi_extensions[method_name] = [] + self.wsgi_extensions[method_name].append(extension) + def get_action_args(self, request_environment): """Parse dictionary created by routes library.""" @@ -793,6 +895,66 @@ class Resource(wsgi.Application): return deserializer().deserialize(body) + def pre_process_extensions(self, extensions, request, action_args): + # List of callables for post-processing extensions + post = [] + + for ext in extensions: + if inspect.isgeneratorfunction(ext): + response = None + + # If it's a generator function, the part before the + # yield is the preprocessing stage + try: + with ResourceExceptionHandler(): + gen = ext(req=request, **action_args) + response = gen.next() + except Fault as ex: + response = ex + + # We had a response... + if response: + return response, [] + + # No response, queue up generator for post-processing + post.append(gen) + else: + # Regular functions only perform post-processing + post.append(ext) + + # Run post-processing in the reverse order + return None, reversed(post) + + def post_process_extensions(self, extensions, resp_obj, request, + action_args): + for ext in extensions: + response = None + if inspect.isgenerator(ext): + # If it's a generator, run the second half of + # processing + try: + with ResourceExceptionHandler(): + response = ext.send(resp_obj) + except StopIteration: + # Normal exit of generator + continue + except Fault as ex: + response = ex + else: + # Regular functions get post-processing... + try: + with ResourceExceptionHandler(): + response = ext(req=request, resp_obj=resp_obj, + **action_args) + except Fault as ex: + response = ex + + # We had a response... + if response: + return response + + return None + @webob.dec.wsgify(RequestClass=Request) def __call__(self, request): """WSGI method that controls (de)serialization and method dispatch.""" @@ -809,9 +971,16 @@ class Resource(wsgi.Application): # Get the implementing method try: - meth = self.get_method(request, action) + meth, extensions = self.get_method(request, action, + content_type, body) except (AttributeError, TypeError): return Fault(webob.exc.HTTPNotFound()) + except KeyError as ex: + msg = _("There is no such action: %s") % ex.args[0] + return Fault(webob.exc.HTTPBadRequest(explanation=msg)) + except exception.MalformedRequestBody: + msg = _("Malformed request body") + return Fault(webob.exc.HTTPBadRequest(explanation=msg)) # Now, deserialize the request body... try: @@ -837,21 +1006,16 @@ class Resource(wsgi.Application): msg = _("Malformed request url") return Fault(webob.exc.HTTPBadRequest(explanation=msg)) - response = None - try: - action_result = self.dispatch(meth, request, action_args) - except exception.NotAuthorized as ex: - msg = unicode(ex) - response = Fault(webob.exc.HTTPUnauthorized(explanation=msg)) - except TypeError as ex: - LOG.exception(ex) - response = Fault(webob.exc.HTTPBadRequest()) - except Fault as ex: - LOG.info(_("Fault thrown: %s"), unicode(ex)) - response = ex - except webob.exc.HTTPException as ex: - LOG.info(_("HTTP exception thrown: %s"), unicode(ex)) - response = Fault(ex) + # Run pre-processing extensions + response, post = self.pre_process_extensions(extensions, + request, action_args) + + if not response: + try: + with ResourceExceptionHandler(): + action_result = self.dispatch(meth, request, action_args) + except Fault as ex: + response = ex if not response: # No exceptions; convert action_result into a @@ -864,7 +1028,12 @@ class Resource(wsgi.Application): else: response = action_result + # Run post-processing extensions if resp_obj: + response = self.post_process_extensions(post, resp_obj, + request, action_args) + + if resp_obj and not response: if self.serializer: response = self.serializer.serialize(request, resp_obj.obj, @@ -890,12 +1059,29 @@ class Resource(wsgi.Application): return response - def get_method(self, request, action): - """Look up the action-specific method.""" + def get_method(self, request, action, content_type, body): + """Look up the action-specific method and its extensions.""" + + # Look up the method + try: + if not self.controller: + meth = getattr(self, action) + else: + meth = getattr(self.controller, action) + except AttributeError as ex: + if action != 'action' or not self.wsgi_actions: + # Propagate the error + raise + else: + return meth, self.wsgi_extensions.get(action, []) + + # OK, it's an action; figure out which action... + mtype = _MEDIA_TYPE_MAP.get(content_type) + action_name = self.action_peek[mtype](body) - if self.controller is None: - return getattr(self, action) - return getattr(self.controller, action) + # Look up the action method + return (self.wsgi_actions[action_name], + self.wsgi_action_extensions.get(action_name, [])) def dispatch(self, method, request, action_args): """Dispatch a call to the action-specific method.""" @@ -903,14 +1089,91 @@ class Resource(wsgi.Application): return method(req=request, **action_args) +def action(name): + """Mark a function as an action. + + The given name will be taken as the action key in the body. + """ + + def decorator(func): + func.wsgi_action = name + return func + return decorator + + +def extends(*args, **kwargs): + """Indicate a function extends an operation. + + Can be used as either:: + + @extends + def index(...): + pass + + or as:: + + @extends(action='resize') + def _action_resize(...): + pass + """ + + def decorator(func): + # Store enough information to find what we're extending + func.wsgi_extends = (func.__name__, kwargs.get('action')) + return func + + # If we have positional arguments, call the decorator + if args: + return decorator(*args) + + # OK, return the decorator instead + return decorator + + +class ControllerMetaclass(type): + """Controller metaclass. + + This metaclass automates the task of assembling a dictionary + mapping action keys to method names. + """ + + def __new__(mcs, name, bases, cls_dict): + """Adds the wsgi_actions dictionary to the class.""" + + # Find all actions + actions = {} + extensions = [] + for key, value in cls_dict.items(): + if not callable(value): + continue + if getattr(value, 'wsgi_action', None): + actions[value.wsgi_action] = key + elif getattr(value, 'wsgi_extends', None): + extensions.append(value.wsgi_extends) + + # Add the actions and extensions to the class dict + cls_dict['wsgi_actions'] = actions + cls_dict['wsgi_extensions'] = extensions + + return super(ControllerMetaclass, mcs).__new__(mcs, name, bases, + cls_dict) + + class Controller(object): """Default controller.""" + __metaclass__ = ControllerMetaclass + _view_builder_class = None def __init__(self, view_builder=None): """Initialize controller with a view builder instance.""" - self._view_builder = view_builder or self._view_builder_class() + if view_builder: + self._view_builder = view_builder + elif self._view_builder_class: + self._view_builder = self._view_builder_class() + else: + self._view_builder = None class Fault(webob.exc.HTTPException): diff --git a/nova/tests/api/openstack/compute/test_extensions.py b/nova/tests/api/openstack/compute/test_extensions.py index a4585781b..4e41caa69 100644 --- a/nova/tests/api/openstack/compute/test_extensions.py +++ b/nova/tests/api/openstack/compute/test_extensions.py @@ -21,7 +21,7 @@ import json import webob from lxml import etree -from nova.api.openstack import compute +from nova.api.openstack import compute from nova.api.openstack import extensions as base_extensions from nova.api.openstack.compute import extensions as compute_extensions from nova.api.openstack import wsgi @@ -36,6 +36,7 @@ FLAGS = flags.FLAGS NS = "{http://docs.openstack.org/compute/api/v1.1}" ATOMNS = "{http://www.w3.org/2005/Atom}" response_body = "Try to say this Mr. Knox, sir..." +extension_body = "I am not a fox!" class StubController(object): @@ -54,16 +55,60 @@ class StubController(object): raise webob.exc.HTTPNotFound() +class StubActionController(wsgi.Controller): + def __init__(self, body): + self.body = body + + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return self.body + + +class StubControllerExtension(base_extensions.ExtensionDescriptor): + name = 'twaadle' + + def __init__(self): + pass + + +class StubEarlyExtensionController(wsgi.Controller): + def __init__(self, body): + self.body = body + + @wsgi.extends + def index(self, req): + yield self.body + + @wsgi.extends(action='fooAction') + def _action_foo(self, req, id, body): + yield self.body + + +class StubLateExtensionController(wsgi.Controller): + def __init__(self, body): + self.body = body + + @wsgi.extends + def index(self, req, resp_obj): + return self.body + + @wsgi.extends(action='fooAction') + def _action_foo(self, req, resp_obj, id, body): + return self.body + + class StubExtensionManager(object): """Provides access to Tweedle Beetles""" name = "Tweedle Beetle Extension" alias = "TWDLBETL" - def __init__(self, resource_ext=None, action_ext=None, request_ext=None): + def __init__(self, resource_ext=None, action_ext=None, request_ext=None, + controller_ext=None): self.resource_ext = resource_ext self.action_ext = action_ext self.request_ext = request_ext + self.controller_ext = controller_ext def get_resources(self): resource_exts = [] @@ -83,6 +128,12 @@ class StubExtensionManager(object): request_extensions.append(self.request_ext) return request_extensions + def get_controller_extensions(self): + controller_extensions = [] + if self.controller_ext: + controller_extensions.append(self.controller_ext) + return controller_extensions + class ExtensionTestCase(test.TestCase): def setUp(self): @@ -411,7 +462,7 @@ class ActionExtensionTest(ExtensionTestCase): body = json.loads(response.body) expected = { "badRequest": { - "message": "There is no such server action: blah", + "message": "There is no such action: blah", "code": 400 } } @@ -478,6 +529,84 @@ class RequestExtensionTest(ExtensionTestCase): self.assertEqual("Pig Bands!", response_data['big_bands']) +class ControllerExtensionTest(ExtensionTestCase): + def test_controller_extension_early(self): + controller = StubController(response_body) + res_ext = base_extensions.ResourceExtension('tweedles', controller) + ext_controller = StubEarlyExtensionController(extension_body) + extension = StubControllerExtension() + cont_ext = base_extensions.ControllerExtension(extension, 'tweedles', + ext_controller) + manager = StubExtensionManager(resource_ext=res_ext, + controller_ext=cont_ext) + app = compute.APIRouter(manager) + request = webob.Request.blank("/fake/tweedles") + response = request.get_response(app) + self.assertEqual(200, response.status_int) + self.assertEqual(extension_body, response.body) + + def test_controller_extension_late(self): + # Need a dict for the body to convert to a ResponseObject + controller = StubController(dict(foo=response_body)) + res_ext = base_extensions.ResourceExtension('tweedles', controller) + + ext_controller = StubLateExtensionController(extension_body) + extension = StubControllerExtension() + cont_ext = base_extensions.ControllerExtension(extension, 'tweedles', + ext_controller) + + manager = StubExtensionManager(resource_ext=res_ext, + controller_ext=cont_ext) + app = compute.APIRouter(manager) + request = webob.Request.blank("/fake/tweedles") + response = request.get_response(app) + self.assertEqual(200, response.status_int) + self.assertEqual(extension_body, response.body) + + def test_controller_action_extension_early(self): + controller = StubActionController(response_body) + actions = dict(action='POST') + res_ext = base_extensions.ResourceExtension('tweedles', controller, + member_actions=actions) + ext_controller = StubEarlyExtensionController(extension_body) + extension = StubControllerExtension() + cont_ext = base_extensions.ControllerExtension(extension, 'tweedles', + ext_controller) + manager = StubExtensionManager(resource_ext=res_ext, + controller_ext=cont_ext) + app = compute.APIRouter(manager) + request = webob.Request.blank("/fake/tweedles/foo/action") + request.method = 'POST' + request.headers['Content-Type'] = 'application/json' + request.body = json.dumps(dict(fooAction=True)) + response = request.get_response(app) + self.assertEqual(200, response.status_int) + self.assertEqual(extension_body, response.body) + + def test_controller_action_extension_late(self): + # Need a dict for the body to convert to a ResponseObject + controller = StubActionController(dict(foo=response_body)) + actions = dict(action='POST') + res_ext = base_extensions.ResourceExtension('tweedles', controller, + member_actions=actions) + + ext_controller = StubLateExtensionController(extension_body) + extension = StubControllerExtension() + cont_ext = base_extensions.ControllerExtension(extension, 'tweedles', + ext_controller) + + manager = StubExtensionManager(resource_ext=res_ext, + controller_ext=cont_ext) + app = compute.APIRouter(manager) + request = webob.Request.blank("/fake/tweedles/foo/action") + request.method = 'POST' + request.headers['Content-Type'] = 'application/json' + request.body = json.dumps(dict(fooAction=True)) + response = request.get_response(app) + self.assertEqual(200, response.status_int) + self.assertEqual(extension_body, response.body) + + class ExtensionsXMLSerializerTest(test.TestCase): def test_serialize_extension(self): diff --git a/nova/tests/api/openstack/compute/test_server_actions.py b/nova/tests/api/openstack/compute/test_server_actions.py index 2f3976375..e92816172 100644 --- a/nova/tests/api/openstack/compute/test_server_actions.py +++ b/nova/tests/api/openstack/compute/test_server_actions.py @@ -174,27 +174,13 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.UnsetAll() super(ServerActionsControllerTest, self).tearDown() - def test_server_bad_body(self): - body = {} - - req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) - - def test_server_unknown_action(self): - body = {'sockTheFox': {'fakekey': '1234'}} - - req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) - def test_server_change_password(self): mock_method = MockSetAdminPassword() self.stubs.Set(nova.compute.api.API, 'set_admin_password', mock_method) body = {'changePassword': {'adminPass': '1234pass'}} req = fakes.HTTPRequest.blank(self.url) - self.controller.action(req, FAKE_UUID, body) + self.controller._action_change_password(req, FAKE_UUID, body) self.assertEqual(mock_method.instance_id, self.uuid) self.assertEqual(mock_method.password, '1234pass') @@ -203,47 +189,53 @@ class ServerActionsControllerTest(test.TestCase): body = {'changePassword': {'adminPass': 1234}} req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_change_password, + req, FAKE_UUID, body) def test_server_change_password_bad_request(self): body = {'changePassword': {'pass': '12345'}} req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_change_password, + req, FAKE_UUID, body) def test_server_change_password_empty_string(self): body = {'changePassword': {'adminPass': ''}} req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_change_password, + req, FAKE_UUID, body) def test_server_change_password_none(self): body = {'changePassword': {'adminPass': None}} req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_change_password, + req, FAKE_UUID, body) def test_reboot_hard(self): body = dict(reboot=dict(type="HARD")) req = fakes.HTTPRequest.blank(self.url) - self.controller.action(req, FAKE_UUID, body) + self.controller._action_reboot(req, FAKE_UUID, body) def test_reboot_soft(self): body = dict(reboot=dict(type="SOFT")) req = fakes.HTTPRequest.blank(self.url) - self.controller.action(req, FAKE_UUID, body) + self.controller._action_reboot(req, FAKE_UUID, body) def test_reboot_incorrect_type(self): body = dict(reboot=dict(type="NOT_A_TYPE")) req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_reboot, + req, FAKE_UUID, body) def test_reboot_missing_type(self): body = dict(reboot=dict()) req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_reboot, + req, FAKE_UUID, body) def test_reboot_not_found(self): self.stubs.Set(nova.db, 'instance_get_by_uuid', @@ -251,7 +243,8 @@ class ServerActionsControllerTest(test.TestCase): body = dict(reboot=dict(type="HARD")) req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPNotFound, self.controller.action, + self.assertRaises(webob.exc.HTTPNotFound, + self.controller._action_reboot, req, str(utils.gen_uuid()), body) def test_reboot_raises_conflict_on_invalid_state(self): @@ -263,7 +256,8 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.Set(nova.compute.api.API, 'reboot', fake_reboot) req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPConflict, self.controller.action, + self.assertRaises(webob.exc.HTTPConflict, + self.controller._action_reboot, req, FAKE_UUID, body) def test_rebuild_accepted_minimum(self): @@ -278,7 +272,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - robj = self.controller.action(req, FAKE_UUID, body) + robj = self.controller._action_rebuild(req, FAKE_UUID, body) body = robj.obj self.assertEqual(body['server']['image']['id'], '2') @@ -300,7 +294,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPConflict, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_rebuild, + req, FAKE_UUID, body) def test_rebuild_accepted_with_metadata(self): metadata = {'new': 'metadata'} @@ -316,7 +311,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body).obj + body = self.controller._action_rebuild(req, FAKE_UUID, body).obj self.assertEqual(body['server']['metadata'], metadata) @@ -330,7 +325,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_rebuild, + req, FAKE_UUID, body) def test_rebuild_bad_entity(self): body = { @@ -341,7 +337,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_rebuild, + req, FAKE_UUID, body) def test_rebuild_bad_personality(self): body = { @@ -356,7 +353,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_rebuild, + req, FAKE_UUID, body) def test_rebuild_personality(self): body = { @@ -370,7 +368,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body).obj + body = self.controller._action_rebuild(req, FAKE_UUID, body).obj self.assertTrue('personality' not in body['server']) @@ -386,7 +384,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body).obj + body = self.controller._action_rebuild(req, FAKE_UUID, body).obj self.assertEqual(body['server']['image']['id'], '2') self.assertEqual(body['server']['adminPass'], 'asdf') @@ -404,7 +402,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPNotFound, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_rebuild, + req, FAKE_UUID, body) def test_rebuild_accessIP(self): attributes = { @@ -430,7 +429,7 @@ class ServerActionsControllerTest(test.TestCase): task_state=None, progress=0, **attributes).AndReturn(None) self.mox.ReplayAll() - self.controller.action(req, FAKE_UUID, body) + self.controller._action_rebuild(req, FAKE_UUID, body) self.mox.VerifyAll() def test_resize_server(self): @@ -445,7 +444,7 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.Set(nova.compute.api.API, 'resize', resize_mock) req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body) + body = self.controller._action_resize(req, FAKE_UUID, body) self.assertEqual(self.resize_called, True) @@ -454,14 +453,16 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_resize, + req, FAKE_UUID, body) def test_resize_server_no_flavor_ref(self): body = dict(resize=dict(flavorRef=None)) req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_resize, + req, FAKE_UUID, body) def test_resize_raises_conflict_on_invalid_state(self): body = dict(resize=dict(flavorRef="http://localhost/3")) @@ -472,7 +473,8 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.Set(nova.compute.api.API, 'resize', fake_resize) req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPConflict, self.controller.action, + self.assertRaises(webob.exc.HTTPConflict, + self.controller._action_resize, req, FAKE_UUID, body) def test_confirm_resize_server(self): @@ -486,7 +488,7 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.Set(nova.compute.api.API, 'confirm_resize', cr_mock) req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body) + body = self.controller._action_confirm_resize(req, FAKE_UUID, body) self.assertEqual(self.confirm_resize_called, True) @@ -503,7 +505,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_confirm_resize, + req, FAKE_UUID, body) def test_confirm_resize_raises_conflict_on_invalid_state(self): body = dict(confirmResize=None) @@ -515,7 +518,8 @@ class ServerActionsControllerTest(test.TestCase): fake_confirm_resize) req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPConflict, self.controller.action, + self.assertRaises(webob.exc.HTTPConflict, + self.controller._action_confirm_resize, req, FAKE_UUID, body) def test_revert_resize_migration_not_found(self): @@ -531,7 +535,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_revert_resize, + req, FAKE_UUID, body) def test_revert_resize_server(self): body = dict(revertResize=None) @@ -544,7 +549,7 @@ class ServerActionsControllerTest(test.TestCase): self.stubs.Set(nova.compute.api.API, 'revert_resize', revert_mock) req = fakes.HTTPRequest.blank(self.url) - body = self.controller.action(req, FAKE_UUID, body) + body = self.controller._action_revert_resize(req, FAKE_UUID, body) self.assertEqual(self.revert_resize_called, True) @@ -558,7 +563,8 @@ class ServerActionsControllerTest(test.TestCase): fake_revert_resize) req = fakes.HTTPRequest.blank(self.url) - self.assertRaises(webob.exc.HTTPConflict, self.controller.action, + self.assertRaises(webob.exc.HTTPConflict, + self.controller._action_revert_resize, req, FAKE_UUID, body) def test_create_image(self): @@ -569,7 +575,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - response = self.controller.action(req, FAKE_UUID, body) + response = self.controller._action_create_image(req, FAKE_UUID, body) location = response.headers['Location'] self.assertEqual('http://localhost/v2/fake/images/123', location) @@ -589,7 +595,8 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_create_image, + req, FAKE_UUID, body) def test_create_image_with_metadata(self): body = { @@ -600,7 +607,7 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) - response = self.controller.action(req, FAKE_UUID, body) + response = self.controller._action_create_image(req, FAKE_UUID, body) location = response.headers['Location'] self.assertEqual('http://localhost/v2/fake/images/123', location) @@ -617,7 +624,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPRequestEntityTooLarge, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_create_image, + req, FAKE_UUID, body) def test_create_image_no_name(self): body = { @@ -625,7 +633,8 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_create_image, + req, FAKE_UUID, body) def test_create_image_bad_metadata(self): body = { @@ -636,7 +645,8 @@ class ServerActionsControllerTest(test.TestCase): } req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPBadRequest, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_create_image, + req, FAKE_UUID, body) def test_create_image_raises_conflict_on_invalid_state(self): def snapshot(*args, **kwargs): @@ -651,7 +661,8 @@ class ServerActionsControllerTest(test.TestCase): req = fakes.HTTPRequest.blank(self.url) self.assertRaises(webob.exc.HTTPConflict, - self.controller.action, req, FAKE_UUID, body) + self.controller._action_create_image, + req, FAKE_UUID, body) class TestServerActionXMLDeserializer(test.TestCase): diff --git a/nova/tests/api/openstack/test_wsgi.py b/nova/tests/api/openstack/test_wsgi.py index 2d5f33eee..534ae7e2c 100644 --- a/nova/tests/api/openstack/test_wsgi.py +++ b/nova/tests/api/openstack/test_wsgi.py @@ -1,5 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 +import inspect import json import webob @@ -445,7 +446,7 @@ class ResourceTest(test.TestCase): controller = Controller() resource = wsgi.Resource(controller) - method = resource.get_method(None, 'index') + method, extensions = resource.get_method(None, 'index', None, '') actual = resource.dispatch(method, None, {'pants': 'off'}) expected = 'off' self.assertEqual(actual, expected) @@ -458,7 +459,68 @@ class ResourceTest(test.TestCase): controller = Controller() resource = wsgi.Resource(controller) self.assertRaises(AttributeError, resource.get_method, - None, 'create') + None, 'create', None, '') + + def test_get_method_action_json(self): + class Controller(wsgi.Controller): + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + controller = Controller() + resource = wsgi.Resource(controller) + method, extensions = resource.get_method(None, 'action', + 'application/json', + '{"fooAction": true}') + self.assertEqual(controller._action_foo, method) + + def test_get_method_action_xml(self): + class Controller(wsgi.Controller): + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + controller = Controller() + resource = wsgi.Resource(controller) + method, extensions = resource.get_method(None, 'action', + 'application/xml', + '<fooAction>true</fooAction>') + self.assertEqual(controller._action_foo, method) + + def test_get_method_action_bad_body(self): + class Controller(wsgi.Controller): + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + controller = Controller() + resource = wsgi.Resource(controller) + self.assertRaises(exception.MalformedRequestBody, resource.get_method, + None, 'action', 'application/json', '{}') + + def test_get_method_unknown_controller_action(self): + class Controller(wsgi.Controller): + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + controller = Controller() + resource = wsgi.Resource(controller) + self.assertRaises(KeyError, resource.get_method, + None, 'action', 'application/json', + '{"barAction": true}') + + def test_get_method_action_method(self): + class Controller(): + def action(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + method, extensions = resource.get_method(None, 'action', + 'application/xml', + '<fooAction>true</fooAction') + self.assertEqual(controller.action, method) def test_get_action_args(self): class Controller(object): @@ -595,6 +657,287 @@ class ResourceTest(test.TestCase): obj = resource.deserialize(controller.index, 'application/xml', 'foo') self.assertEqual(obj, 'xml') + def test_register_actions(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + class ControllerExtended(wsgi.Controller): + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + @wsgi.action('barAction') + def _action_bar(self, req, id, body): + return body + + controller = Controller() + resource = wsgi.Resource(controller) + self.assertEqual({}, resource.wsgi_actions) + + extended = ControllerExtended() + resource.register_actions(extended) + self.assertEqual({ + 'fooAction': extended._action_foo, + 'barAction': extended._action_bar, + }, resource.wsgi_actions) + + def test_register_extensions(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + class ControllerExtended(wsgi.Controller): + @wsgi.extends + def index(self, req, resp_obj, pants=None): + return None + + @wsgi.extends(action='fooAction') + def _action_foo(self, req, resp, id, body): + return None + + controller = Controller() + resource = wsgi.Resource(controller) + self.assertEqual({}, resource.wsgi_extensions) + self.assertEqual({}, resource.wsgi_action_extensions) + + extended = ControllerExtended() + resource.register_extensions(extended) + self.assertEqual({'index': [extended.index]}, resource.wsgi_extensions) + self.assertEqual({'fooAction': [extended._action_foo]}, + resource.wsgi_action_extensions) + + def test_get_method_extensions(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + class ControllerExtended(wsgi.Controller): + @wsgi.extends + def index(self, req, resp_obj, pants=None): + return None + + controller = Controller() + extended = ControllerExtended() + resource = wsgi.Resource(controller) + resource.register_extensions(extended) + method, extensions = resource.get_method(None, 'index', None, '') + self.assertEqual(method, controller.index) + self.assertEqual(extensions, [extended.index]) + + def test_get_method_action_extensions(self): + class Controller(wsgi.Controller): + def index(self, req, pants=None): + return pants + + @wsgi.action('fooAction') + def _action_foo(self, req, id, body): + return body + + class ControllerExtended(wsgi.Controller): + @wsgi.extends(action='fooAction') + def _action_foo(self, req, resp_obj, id, body): + return None + + controller = Controller() + extended = ControllerExtended() + resource = wsgi.Resource(controller) + resource.register_extensions(extended) + method, extensions = resource.get_method(None, 'action', + 'application/json', + '{"fooAction": true}') + self.assertEqual(method, controller._action_foo) + self.assertEqual(extensions, [extended._action_foo]) + + def test_pre_process_extensions_regular(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req, resp_obj): + called.append(1) + return None + + def extension2(req, resp_obj): + called.append(2) + return None + + extensions = [extension1, extension2] + response, post = resource.pre_process_extensions(extensions, None, {}) + self.assertEqual(called, []) + self.assertEqual(response, None) + self.assertEqual(list(post), [extension2, extension1]) + + def test_pre_process_extensions_generator(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req): + called.append('pre1') + resp_obj = yield + called.append('post1') + + def extension2(req): + called.append('pre2') + resp_obj = yield + called.append('post2') + + extensions = [extension1, extension2] + response, post = resource.pre_process_extensions(extensions, None, {}) + post = list(post) + self.assertEqual(called, ['pre1', 'pre2']) + self.assertEqual(response, None) + self.assertEqual(len(post), 2) + self.assertTrue(inspect.isgenerator(post[0])) + self.assertTrue(inspect.isgenerator(post[1])) + + for gen in post: + try: + gen.send(None) + except StopIteration: + continue + + self.assertEqual(called, ['pre1', 'pre2', 'post2', 'post1']) + + def test_pre_process_extensions_generator_response(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req): + called.append('pre1') + yield 'foo' + + def extension2(req): + called.append('pre2') + + extensions = [extension1, extension2] + response, post = resource.pre_process_extensions(extensions, None, {}) + self.assertEqual(called, ['pre1']) + self.assertEqual(response, 'foo') + self.assertEqual(post, []) + + def test_post_process_extensions_regular(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req, resp_obj): + called.append(1) + return None + + def extension2(req, resp_obj): + called.append(2) + return None + + response = resource.post_process_extensions([extension2, extension1], + None, None, {}) + self.assertEqual(called, [2, 1]) + self.assertEqual(response, None) + + def test_post_process_extensions_regular_response(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req, resp_obj): + called.append(1) + return None + + def extension2(req, resp_obj): + called.append(2) + return 'foo' + + response = resource.post_process_extensions([extension2, extension1], + None, None, {}) + self.assertEqual(called, [2]) + self.assertEqual(response, 'foo') + + def test_post_process_extensions_generator(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req): + resp_obj = yield + called.append(1) + + def extension2(req): + resp_obj = yield + called.append(2) + + ext1 = extension1(None) + ext1.next() + ext2 = extension2(None) + ext2.next() + + response = resource.post_process_extensions([ext2, ext1], + None, None, {}) + + self.assertEqual(called, [2, 1]) + self.assertEqual(response, None) + + def test_post_process_extensions_generator_response(self): + class Controller(object): + def index(self, req, pants=None): + return pants + + controller = Controller() + resource = wsgi.Resource(controller) + + called = [] + + def extension1(req): + resp_obj = yield + called.append(1) + + def extension2(req): + resp_obj = yield + called.append(2) + yield 'foo' + + ext1 = extension1(None) + ext1.next() + ext2 = extension2(None) + ext2.next() + + response = resource.post_process_extensions([ext2, ext1], + None, None, {}) + + self.assertEqual(called, [2]) + self.assertEqual(response, 'foo') + class ResponseObjectTest(test.TestCase): def test_default_code(self): |
