From 1d4e35be694884a0ea8e586ffb2d06ecd6c48685 Mon Sep 17 00:00:00 2001 From: "Kevin L. Mitchell" Date: Fri, 13 Jan 2012 14:13:59 -0600 Subject: Refactor request and action extensions. The goal of this refactoring is to eventually eliminate ExtensionMiddleware and LazySerializationMiddleware completely, by executing extensions directly within the processing done by Resource.__call__(). This patch implements the infrastructure required to perform this extension processing. Implements blueprint extension-refactor. Change-Id: I23398fc906a9a105de354a8133337ecfc69a3ad3 --- nova/api/openstack/compute/__init__.py | 56 ++++-- nova/api/openstack/compute/servers.py | 101 ++++++----- nova/api/openstack/extensions.py | 35 ++++ nova/api/openstack/volume/__init__.py | 43 ++++- nova/api/openstack/wsgi.py | 311 ++++++++++++++++++++++++++++++--- 5 files changed, 458 insertions(+), 88 deletions(-) (limited to 'nova/api') diff --git a/nova/api/openstack/compute/__init__.py b/nova/api/openstack/compute/__init__.py index 2f6e92a42..69b8daab7 100644 --- a/nova/api/openstack/compute/__init__.py +++ b/nova/api/openstack/compute/__init__.py @@ -67,8 +67,10 @@ class APIRouter(base_wsgi.Router): ext_mgr = extensions.ExtensionManager() mapper = nova.api.openstack.ProjectMapper() + self.resources = {} self._setup_routes(mapper) self._setup_ext_routes(mapper, ext_mgr) + self._setup_extensions(ext_mgr) super(APIRouter, self).__init__(mapper) def _setup_ext_routes(self, mapper, ext_mgr): @@ -76,10 +78,12 @@ class APIRouter(base_wsgi.Router): LOG.debug(_('Extended resource: %s'), resource.collection) + wsgi_resource = wsgi.Resource( + resource.controller, resource.deserializer, + resource.serializer) + self.resources[resource.collection] = wsgi_resource kargs = dict( - controller=wsgi.Resource( - resource.controller, resource.deserializer, - resource.serializer), + controller=wsgi_resource, collection=resource.collection_actions, member=resource.member_actions) @@ -88,39 +92,66 @@ class APIRouter(base_wsgi.Router): mapper.resource(resource.collection, resource.collection, **kargs) + def _setup_extensions(self, ext_mgr): + for extension in ext_mgr.get_controller_extensions(): + ext_name = extension.extension.name + collection = extension.collection + controller = extension.controller + + if collection not in self.resources: + LOG.warning(_('Extension %(ext_name)s: Cannot extend ' + 'resource %(collection)s: No such resource') % + locals()) + continue + + LOG.debug(_('Extension %(ext_name)s extending resource: ' + '%(collection)s') % locals()) + + resource = self.resources[collection] + resource.register_actions(controller) + resource.register_extensions(controller) + def _setup_routes(self, mapper): + self.resources['versions'] = versions.create_resource() mapper.connect("versions", "/", - controller=versions.create_resource(), + controller=self.resources['versions'], action='show') mapper.redirect("", "/") + self.resources['consoles'] = consoles.create_resource() mapper.resource("console", "consoles", - controller=consoles.create_resource(), + controller=self.resources['consoles'], parent_resource=dict(member_name='server', collection_name='servers')) + self.resources['servers'] = servers.create_resource() mapper.resource("server", "servers", - controller=servers.create_resource(), + controller=self.resources['servers'], collection={'detail': 'GET'}, member={'action': 'POST'}) - mapper.resource("ip", "ips", controller=ips.create_resource(), + self.resources['ips'] = ips.create_resource() + mapper.resource("ip", "ips", controller=self.resources['ips'], parent_resource=dict(member_name='server', collection_name='servers')) + self.resources['images'] = images.create_resource() mapper.resource("image", "images", - controller=images.create_resource(), + controller=self.resources['images'], collection={'detail': 'GET'}) + self.resources['limits'] = limits.create_resource() mapper.resource("limit", "limits", - controller=limits.create_resource()) + controller=self.resources['limits']) + self.resources['flavors'] = flavors.create_resource() mapper.resource("flavor", "flavors", - controller=flavors.create_resource(), + controller=self.resources['flavors'], collection={'detail': 'GET'}) - image_metadata_controller = image_metadata.create_resource() + self.resources['image_metadata'] = image_metadata.create_resource() + image_metadata_controller = self.resources['image_metadata'] mapper.resource("image_meta", "metadata", controller=image_metadata_controller, @@ -132,7 +163,8 @@ class APIRouter(base_wsgi.Router): action='update_all', conditions={"method": ['PUT']}) - server_metadata_controller = server_metadata.create_resource() + self.resources['server_metadata'] = server_metadata.create_resource() + server_metadata_controller = self.resources['server_metadata'] mapper.resource("server_meta", "metadata", controller=server_metadata_controller, diff --git a/nova/api/openstack/compute/servers.py b/nova/api/openstack/compute/servers.py index 60b85f591..71dc1e2ae 100644 --- a/nova/api/openstack/compute/servers.py +++ b/nova/api/openstack/compute/servers.py @@ -792,31 +792,10 @@ class Controller(wsgi.Controller): @wsgi.response(202) @wsgi.serializers(xml=FullServerTemplate) @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('confirmResize') @exception.novaclient_converter @scheduler_api.redirect_handler - def action(self, req, id, body): - """Multi-purpose method used to take actions on a server""" - _actions = { - 'changePassword': self._action_change_password, - 'reboot': self._action_reboot, - 'resize': self._action_resize, - 'confirmResize': self._action_confirm_resize, - 'revertResize': self._action_revert_resize, - 'rebuild': self._action_rebuild, - 'createImage': self._action_create_image, - } - - for key in body: - if key in _actions: - return _actions[key](body, req, id) - else: - msg = _("There is no such server action: %s") % (key,) - raise exc.HTTPBadRequest(explanation=msg) - - msg = _("Invalid request body") - raise exc.HTTPBadRequest(explanation=msg) - - def _action_confirm_resize(self, input_dict, req, id): + def _action_confirm_resize(self, req, id, body): context = req.environ['nova.context'] instance = self._get_server(context, id) try: @@ -832,7 +811,13 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest() return exc.HTTPNoContent() - def _action_revert_resize(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('revertResize') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_revert_resize(self, req, id, body): context = req.environ['nova.context'] instance = self._get_server(context, id) try: @@ -848,10 +833,16 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest() return webob.Response(status_int=202) - def _action_reboot(self, input_dict, req, id): - if 'reboot' in input_dict and 'type' in input_dict['reboot']: + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('reboot') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_reboot(self, req, id, body): + if 'reboot' in body and 'type' in body['reboot']: valid_reboot_types = ['HARD', 'SOFT'] - reboot_type = input_dict['reboot']['type'].upper() + reboot_type = body['reboot']['type'].upper() if not valid_reboot_types.count(reboot_type): msg = _("Argument 'type' for reboot is not HARD or SOFT") LOG.exception(msg) @@ -930,13 +921,19 @@ class Controller(wsgi.Controller): return common.get_id_from_href(flavor_ref) - def _action_change_password(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('changePassword') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_change_password(self, req, id, body): context = req.environ['nova.context'] - if (not 'changePassword' in input_dict - or not 'adminPass' in input_dict['changePassword']): + if (not 'changePassword' in body + or not 'adminPass' in body['changePassword']): msg = _("No adminPass was specified") raise exc.HTTPBadRequest(explanation=msg) - password = input_dict['changePassword']['adminPass'] + password = body['changePassword']['adminPass'] if not isinstance(password, basestring) or password == '': msg = _("Invalid adminPass") raise exc.HTTPBadRequest(explanation=msg) @@ -956,10 +953,16 @@ class Controller(wsgi.Controller): LOG.debug(msg) raise exc.HTTPBadRequest(explanation=msg) - def _action_resize(self, input_dict, req, id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('resize') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_resize(self, req, id, body): """ Resizes a given instance to the flavor size requested """ try: - flavor_ref = input_dict["resize"]["flavorRef"] + flavor_ref = body["resize"]["flavorRef"] if not flavor_ref: msg = _("Resize request has invalid 'flavorRef' attribute.") raise exc.HTTPBadRequest(explanation=msg) @@ -969,10 +972,16 @@ class Controller(wsgi.Controller): return self._resize(req, id, flavor_ref) - def _action_rebuild(self, info, request, instance_id): + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('rebuild') + @exception.novaclient_converter + @scheduler_api.redirect_handler + def _action_rebuild(self, req, id, body): """Rebuild an instance with the given attributes""" try: - body = info['rebuild'] + body = body['rebuild'] except (KeyError, TypeError): raise exc.HTTPBadRequest(_("Invalid request body")) @@ -987,8 +996,8 @@ class Controller(wsgi.Controller): except (KeyError, TypeError): password = utils.generate_password(FLAGS.password_length) - context = request.environ['nova.context'] - instance = self._get_server(context, instance_id) + context = req.environ['nova.context'] + instance = self._get_server(context, id) attr_map = { 'personality': 'files_to_inject', @@ -1025,10 +1034,10 @@ class Controller(wsgi.Controller): msg = _("Instance could not be found") raise exc.HTTPNotFound(explanation=msg) - instance = self._get_server(context, instance_id) + instance = self._get_server(context, id) self._add_instance_faults(context, [instance]) - view = self._view_builder.show(request, instance) + view = self._view_builder.show(req, instance) # Add on the adminPass attribute since the view doesn't do it view['server']['adminPass'] = password @@ -1036,11 +1045,17 @@ class Controller(wsgi.Controller): robj = wsgi.ResponseObject(view) return self._add_location(robj) + @wsgi.response(202) + @wsgi.serializers(xml=FullServerTemplate) + @wsgi.deserializers(xml=ActionDeserializer) + @wsgi.action('createImage') + @exception.novaclient_converter + @scheduler_api.redirect_handler @common.check_snapshots_enabled - def _action_create_image(self, input_dict, req, instance_id): + def _action_create_image(self, req, id, body): """Snapshot a server instance.""" context = req.environ['nova.context'] - entity = input_dict.get("createImage", {}) + entity = body.get("createImage", {}) try: image_name = entity["name"] @@ -1054,7 +1069,7 @@ class Controller(wsgi.Controller): raise exc.HTTPBadRequest(explanation=msg) # preserve link to server in image properties - server_ref = os.path.join(req.application_url, 'servers', instance_id) + server_ref = os.path.join(req.application_url, 'servers', id) props = {'instance_ref': server_ref} metadata = entity.get('metadata', {}) @@ -1065,7 +1080,7 @@ class Controller(wsgi.Controller): msg = _("Invalid metadata") raise exc.HTTPBadRequest(explanation=msg) - instance = self._get_server(context, instance_id) + instance = self._get_server(context, id) try: image = self.compute_api.snapshot(context, diff --git a/nova/api/openstack/extensions.py b/nova/api/openstack/extensions.py index 6c49e8ace..669bc699a 100644 --- a/nova/api/openstack/extensions.py +++ b/nova/api/openstack/extensions.py @@ -97,6 +97,14 @@ class ExtensionDescriptor(object): request_exts = [] return request_exts + def get_controller_extensions(self): + """List of extensions.ControllerExtension extension objects. + + Controller extensions are used to extend existing controllers. + """ + controller_exts = [] + return controller_exts + @classmethod def nsmap(cls): """Synthesize a namespace map from extension.""" @@ -441,6 +449,18 @@ class ExtensionManager(object): pass return request_exts + def get_controller_extensions(self): + """Returns a list of ControllerExtension objects.""" + controller_exts = [] + for ext in self.extensions.values(): + try: + controller_exts.extend(ext.get_controller_extensions()) + except AttributeError: + # NOTE(Vek): Extensions aren't required to have + # controller extensions + pass + return controller_exts + def _check_extension(self, extension): """Checks for required methods in extension objects.""" try: @@ -492,6 +512,20 @@ class ExtensionManager(object): '%(exc)s') % locals()) +class ControllerExtension(object): + """Extend core controllers of nova OpenStack API. + + Provide a way to extend existing nova OpenStack API core + controllers. + """ + + def __init__(self, extension, collection, controller): + self.extension = extension + self.collection = collection + self.controller = controller + + +@utils.deprecated("Superseded by ControllerExtension") class RequestExtension(object): """Extend requests and responses of core nova OpenStack API resources. @@ -507,6 +541,7 @@ class RequestExtension(object): self.pre_handler = pre_handler +@utils.deprecated("Superseded by ControllerExtension") class ActionExtension(object): """Add custom actions to core nova OpenStack API resources.""" diff --git a/nova/api/openstack/volume/__init__.py b/nova/api/openstack/volume/__init__.py index 075b53c29..d83725a1a 100644 --- a/nova/api/openstack/volume/__init__.py +++ b/nova/api/openstack/volume/__init__.py @@ -56,8 +56,10 @@ class APIRouter(base_wsgi.Router): ext_mgr = extensions.ExtensionManager() mapper = nova.api.openstack.ProjectMapper() + self.resources = {} self._setup_routes(mapper) self._setup_ext_routes(mapper, ext_mgr) + self._setup_extensions(ext_mgr) super(APIRouter, self).__init__(mapper) def _setup_ext_routes(self, mapper, ext_mgr): @@ -66,13 +68,13 @@ class APIRouter(base_wsgi.Router): for resource in ext_mgr.get_resources(): LOG.debug(_('Extended resource: %s'), resource.collection) - if resource.serializer is None: - resource.serializer = serializer + wsgi_resource = wsgi.Resource( + resource.controller, resource.deserializer, + resource.serializer) + self.resources[resource.collection] = wsgi_resource kargs = dict( - controller=wsgi.Resource( - resource.controller, resource.deserializer, - resource.serializer), + controller=wsgi_resource, collection=resource.collection_actions, member=resource.member_actions) @@ -81,19 +83,42 @@ class APIRouter(base_wsgi.Router): mapper.resource(resource.collection, resource.collection, **kargs) + def _setup_extensions(self, ext_mgr): + for extension in ext_mgr.get_controller_extensions(): + ext_name = extension.extension.name + collection = extension.collection + controller = extension.controller + + if collection not in self.resources: + LOG.warning(_('Extension %(ext_name)s: Cannot extend ' + 'resource %(collection)s: No such resource') % + locals()) + continue + + LOG.debug(_('Extension %(ext_name)s extending resource: ' + '%(collection)s') % locals()) + + resource = self.resources[collection] + resource.register_actions(controller) + resource.register_extensions(controller) + def _setup_routes(self, mapper): + self.resources['versions'] = versions.create_resource() mapper.connect("versions", "/", - controller=versions.create_resource(), + controller=self.resources['versions'], action='show') mapper.redirect("", "/") + self.resources['volumes'] = volumes.create_resource() mapper.resource("volume", "volumes", - controller=volumes.create_resource(), + controller=self.resources['volumes'], collection={'detail': 'GET'}) + self.resources['types'] = types.create_resource() mapper.resource("type", "types", - controller=types.create_resource()) + controller=self.resources['types']) + self.resources['snapshots'] = snapshots.create_resource() mapper.resource("snapshot", "snapshots", - controller=snapshots.create_resource()) + controller=self.resources['snapshots']) diff --git a/nova/api/openstack/wsgi.py b/nova/api/openstack/wsgi.py index f8790d9f0..defb26e6e 100644 --- a/nova/api/openstack/wsgi.py +++ b/nova/api/openstack/wsgi.py @@ -15,11 +15,13 @@ # License for the specific language governing permissions and limitations # under the License. +import inspect from xml.dom import minidom from xml.parsers import expat from lxml import etree import webob +from webob import exc from nova import exception from nova import log as logging @@ -700,6 +702,65 @@ class ResponseObject(object): return self._headers.copy() +def action_peek_json(body): + """Determine action to invoke.""" + + try: + decoded = utils.loads(body) + except ValueError: + msg = _("cannot understand JSON") + raise exception.MalformedRequestBody(reason=msg) + + # Make sure there's exactly one key... + if len(decoded) != 1: + msg = _("too many body keys") + raise exception.MalformedRequestBody(reason=msg) + + # Return the action and the decoded body... + return decoded.keys()[0] + + +def action_peek_xml(body): + """Determine action to invoke.""" + + dom = minidom.parseString(body) + action_node = dom.childNodes[0] + + return action_node.tagName + + +class ResourceExceptionHandler(object): + """Context manager to handle Resource exceptions. + + Used when processing exceptions generated by API implementation + methods (or their extensions). Converts most exceptions to Fault + exceptions, with the appropriate logging. + """ + + def __enter__(self): + return None + + def __exit__(self, ex_type, ex_value, ex_traceback): + if not ex_value: + return True + + if isinstance(ex_value, exception.NotAuthorized): + msg = unicode(ex_value) + raise Fault(webob.exc.HTTPUnauthorized(explanation=msg)) + elif isinstance(ex_value, TypeError): + LOG.exception(ex_value) + raise Fault(webob.exc.HTTPBadRequest()) + elif isinstance(ex_value, Fault): + LOG.info(_("Fault thrown: %s"), unicode(ex_value)) + raise ex_value + elif isinstance(ex_value, webob.exc.HTTPException): + LOG.info(_("HTTP exception thrown: %s"), unicode(ex_value)) + raise Fault(ex_value) + + # We didn't handle the exception + return False + + class Resource(wsgi.Application): """WSGI app that handles (de)serialization and controller dispatch. @@ -717,15 +778,17 @@ class Resource(wsgi.Application): """ def __init__(self, controller, deserializer=None, serializer=None, - **deserializers): + action_peek=None, **deserializers): """ :param controller: object that implement methods created by routes lib :param deserializer: object that can serialize the output of a controller into a webob response :param serializer: object that can deserialize a webob request into necessary pieces - + :param action_peek: dictionary of routines for peeking into an action + request body to determine the desired action """ + self.controller = controller self.deserializer = deserializer self.serializer = serializer @@ -738,6 +801,45 @@ class Resource(wsgi.Application): self.default_serializers = dict(xml=XMLDictSerializer, json=JSONDictSerializer) + self.action_peek = dict(xml=action_peek_xml, + json=action_peek_json) + self.action_peek.update(action_peek or {}) + + # Copy over the actions dictionary + self.wsgi_actions = {} + if controller: + self.register_actions(controller) + + # Save a mapping of extensions + self.wsgi_extensions = {} + self.wsgi_action_extensions = {} + + def register_actions(self, controller): + """Registers controller actions with this resource.""" + + actions = getattr(controller, 'wsgi_actions', {}) + for key, method_name in actions.items(): + self.wsgi_actions[key] = getattr(controller, method_name) + + def register_extensions(self, controller): + """Registers controller extensions with this resource.""" + + extensions = getattr(controller, 'wsgi_extensions', []) + for method_name, action_name in extensions: + # Look up the extending method + extension = getattr(controller, method_name) + + if action_name: + # Extending an action... + if action_name not in self.wsgi_action_extensions: + self.wsgi_action_extensions[action_name] = [] + self.wsgi_action_extensions[action_name].append(extension) + else: + # Extending a regular method + if method_name not in self.wsgi_extensions: + self.wsgi_extensions[method_name] = [] + self.wsgi_extensions[method_name].append(extension) + def get_action_args(self, request_environment): """Parse dictionary created by routes library.""" @@ -793,6 +895,66 @@ class Resource(wsgi.Application): return deserializer().deserialize(body) + def pre_process_extensions(self, extensions, request, action_args): + # List of callables for post-processing extensions + post = [] + + for ext in extensions: + if inspect.isgeneratorfunction(ext): + response = None + + # If it's a generator function, the part before the + # yield is the preprocessing stage + try: + with ResourceExceptionHandler(): + gen = ext(req=request, **action_args) + response = gen.next() + except Fault as ex: + response = ex + + # We had a response... + if response: + return response, [] + + # No response, queue up generator for post-processing + post.append(gen) + else: + # Regular functions only perform post-processing + post.append(ext) + + # Run post-processing in the reverse order + return None, reversed(post) + + def post_process_extensions(self, extensions, resp_obj, request, + action_args): + for ext in extensions: + response = None + if inspect.isgenerator(ext): + # If it's a generator, run the second half of + # processing + try: + with ResourceExceptionHandler(): + response = ext.send(resp_obj) + except StopIteration: + # Normal exit of generator + continue + except Fault as ex: + response = ex + else: + # Regular functions get post-processing... + try: + with ResourceExceptionHandler(): + response = ext(req=request, resp_obj=resp_obj, + **action_args) + except Fault as ex: + response = ex + + # We had a response... + if response: + return response + + return None + @webob.dec.wsgify(RequestClass=Request) def __call__(self, request): """WSGI method that controls (de)serialization and method dispatch.""" @@ -809,9 +971,16 @@ class Resource(wsgi.Application): # Get the implementing method try: - meth = self.get_method(request, action) + meth, extensions = self.get_method(request, action, + content_type, body) except (AttributeError, TypeError): return Fault(webob.exc.HTTPNotFound()) + except KeyError as ex: + msg = _("There is no such action: %s") % ex.args[0] + return Fault(webob.exc.HTTPBadRequest(explanation=msg)) + except exception.MalformedRequestBody: + msg = _("Malformed request body") + return Fault(webob.exc.HTTPBadRequest(explanation=msg)) # Now, deserialize the request body... try: @@ -837,21 +1006,16 @@ class Resource(wsgi.Application): msg = _("Malformed request url") return Fault(webob.exc.HTTPBadRequest(explanation=msg)) - response = None - try: - action_result = self.dispatch(meth, request, action_args) - except exception.NotAuthorized as ex: - msg = unicode(ex) - response = Fault(webob.exc.HTTPUnauthorized(explanation=msg)) - except TypeError as ex: - LOG.exception(ex) - response = Fault(webob.exc.HTTPBadRequest()) - except Fault as ex: - LOG.info(_("Fault thrown: %s"), unicode(ex)) - response = ex - except webob.exc.HTTPException as ex: - LOG.info(_("HTTP exception thrown: %s"), unicode(ex)) - response = Fault(ex) + # Run pre-processing extensions + response, post = self.pre_process_extensions(extensions, + request, action_args) + + if not response: + try: + with ResourceExceptionHandler(): + action_result = self.dispatch(meth, request, action_args) + except Fault as ex: + response = ex if not response: # No exceptions; convert action_result into a @@ -864,7 +1028,12 @@ class Resource(wsgi.Application): else: response = action_result + # Run post-processing extensions if resp_obj: + response = self.post_process_extensions(post, resp_obj, + request, action_args) + + if resp_obj and not response: if self.serializer: response = self.serializer.serialize(request, resp_obj.obj, @@ -890,12 +1059,29 @@ class Resource(wsgi.Application): return response - def get_method(self, request, action): - """Look up the action-specific method.""" + def get_method(self, request, action, content_type, body): + """Look up the action-specific method and its extensions.""" + + # Look up the method + try: + if not self.controller: + meth = getattr(self, action) + else: + meth = getattr(self.controller, action) + except AttributeError as ex: + if action != 'action' or not self.wsgi_actions: + # Propagate the error + raise + else: + return meth, self.wsgi_extensions.get(action, []) + + # OK, it's an action; figure out which action... + mtype = _MEDIA_TYPE_MAP.get(content_type) + action_name = self.action_peek[mtype](body) - if self.controller is None: - return getattr(self, action) - return getattr(self.controller, action) + # Look up the action method + return (self.wsgi_actions[action_name], + self.wsgi_action_extensions.get(action_name, [])) def dispatch(self, method, request, action_args): """Dispatch a call to the action-specific method.""" @@ -903,14 +1089,91 @@ class Resource(wsgi.Application): return method(req=request, **action_args) +def action(name): + """Mark a function as an action. + + The given name will be taken as the action key in the body. + """ + + def decorator(func): + func.wsgi_action = name + return func + return decorator + + +def extends(*args, **kwargs): + """Indicate a function extends an operation. + + Can be used as either:: + + @extends + def index(...): + pass + + or as:: + + @extends(action='resize') + def _action_resize(...): + pass + """ + + def decorator(func): + # Store enough information to find what we're extending + func.wsgi_extends = (func.__name__, kwargs.get('action')) + return func + + # If we have positional arguments, call the decorator + if args: + return decorator(*args) + + # OK, return the decorator instead + return decorator + + +class ControllerMetaclass(type): + """Controller metaclass. + + This metaclass automates the task of assembling a dictionary + mapping action keys to method names. + """ + + def __new__(mcs, name, bases, cls_dict): + """Adds the wsgi_actions dictionary to the class.""" + + # Find all actions + actions = {} + extensions = [] + for key, value in cls_dict.items(): + if not callable(value): + continue + if getattr(value, 'wsgi_action', None): + actions[value.wsgi_action] = key + elif getattr(value, 'wsgi_extends', None): + extensions.append(value.wsgi_extends) + + # Add the actions and extensions to the class dict + cls_dict['wsgi_actions'] = actions + cls_dict['wsgi_extensions'] = extensions + + return super(ControllerMetaclass, mcs).__new__(mcs, name, bases, + cls_dict) + + class Controller(object): """Default controller.""" + __metaclass__ = ControllerMetaclass + _view_builder_class = None def __init__(self, view_builder=None): """Initialize controller with a view builder instance.""" - self._view_builder = view_builder or self._view_builder_class() + if view_builder: + self._view_builder = view_builder + elif self._view_builder_class: + self._view_builder = self._view_builder_class() + else: + self._view_builder = None class Fault(webob.exc.HTTPException): -- cgit