From 6daf6d30dfeab459a0b672d909b115b1a5ce86c3 Mon Sep 17 00:00:00 2001 From: "Kevin L. Mitchell" Date: Wed, 13 Jul 2011 14:41:13 -0500 Subject: Augment rate limiting to allow greater flexibility through the api-paste.ini configuration --- nova/api/openstack/limits.py | 96 +++++++++++++++++++++++++++++++-- nova/tests/api/openstack/test_limits.py | 30 +++++++++-- 2 files changed, 116 insertions(+), 10 deletions(-) diff --git a/nova/api/openstack/limits.py b/nova/api/openstack/limits.py index d08287f6b..8642a28f9 100644 --- a/nova/api/openstack/limits.py +++ b/nova/api/openstack/limits.py @@ -31,8 +31,8 @@ from collections import defaultdict from webob.dec import wsgify from nova import quota +from nova import utils from nova import wsgi as base_wsgi -from nova import wsgi from nova.api.openstack import common from nova.api.openstack import faults from nova.api.openstack.views import limits as limits_views @@ -119,6 +119,8 @@ class Limit(object): 60 * 60 * 24: "DAY", } + UNIT_MAP = dict([(v, k) for k, v in UNITS.items()]) + def __init__(self, verb, uri, regex, value, unit): """ Initialize a new `Limit`. @@ -224,16 +226,30 @@ class RateLimitingMiddleware(base_wsgi.Middleware): is stored in memory for this implementation. """ - def __init__(self, application, limits=None): + def __init__(self, application, limits=None, limiter=None, **kwargs): """ Initialize new `RateLimitingMiddleware`, which wraps the given WSGI application and sets up the given limits. @param application: WSGI application to wrap - @param limits: List of dictionaries describing limits + @param limits: String describing limits + @param limiter: String identifying class for representing limits + + Other parameters are passed to the constructor for the limiter. """ base_wsgi.Middleware.__init__(self, application) - self._limiter = Limiter(limits or DEFAULT_LIMITS) + + # Select the limiter class + if limiter is None: + limiter = Limiter + else: + limiter = utils.import_class(limiter) + + # Parse the limits, if any are provided + if limits is not None: + limits = limiter.parse_limits(limits) + + self._limiter = limiter(limits or DEFAULT_LIMITS, **kwargs) @wsgify(RequestClass=wsgi.Request) def __call__(self, req): @@ -271,7 +287,7 @@ class Limiter(object): Rate-limit checking class which handles limits in memory. """ - def __init__(self, limits): + def __init__(self, limits, **kwargs): """ Initialize the new `Limiter`. @@ -280,6 +296,12 @@ class Limiter(object): self.limits = copy.deepcopy(limits) self.levels = defaultdict(lambda: copy.deepcopy(limits)) + # Pick up any per-user limit information + for key, value in kwargs.items(): + if key.startswith('user:'): + username = key[5:] + self.levels[username] = self.parse_limits(value) + def get_limits(self, username=None): """ Return the limits for a given user. @@ -305,6 +327,59 @@ class Limiter(object): return None, None + @staticmethod + def parse_limits(limits): + """ + Convert a string into a list of Limit instances. This + implementation expects a semicolon-separated sequence of + parenthesized groups, where each group contains a + comma-separated sequence consisting of HTTP method, + user-readable URI, a URI reg-exp, an integer number of + requests which can be made, and a unit of measure. Valid + values for the latter are "SECOND", "MINUTE", "HOUR", and + "DAY". + + @return: List of Limit instances. + """ + + # Handle empty limit strings + limits = limits.strip() + if not limits: + return [] + + # Split up the limits by semicolon + result = [] + for group in limits.split(';'): + group = group.strip() + if group[0] != '(' or group[-1] != ')': + raise ValueError("Groups must be surrounded by parentheses") + group = group[1:-1] + + # Extract the Limit arguments + args = [a.strip() for a in group.split(',')] + if len(args) != 5: + raise ValueError("Groups must contain exactly 5 elements") + + # Pull out the arguments + verb, uri, regex, value, unit = args + + # Upper-case the verb + verb = verb.upper() + + # Convert value--raises ValueError if it's not integer + value = int(value) + + # Convert unit + unit = unit.upper() + if unit not in Limit.UNIT_MAP: + raise ValueError("Invalid units specified") + unit = Limit.UNIT_MAP[unit] + + # Build a limit + result.append(Limit(verb, uri, regex, value, unit)) + + return result + class WsgiLimiter(object): """ @@ -388,3 +463,14 @@ class WsgiLimiterProxy(object): return None, None return resp.getheader("X-Wait-Seconds"), resp.read() or None + + @staticmethod + def parse_limits(limits): + """ + Ignore a limits string--simply doesn't apply for the limit + proxy. + + @return: Empty list. + """ + + return [] diff --git a/nova/tests/api/openstack/test_limits.py b/nova/tests/api/openstack/test_limits.py index 38c959fae..eeacd2fca 100644 --- a/nova/tests/api/openstack/test_limits.py +++ b/nova/tests/api/openstack/test_limits.py @@ -400,6 +400,10 @@ class LimitsControllerV11Test(BaseLimitTestSuite): self._test_index_absolute_limits_json(expected) +class TestLimiter(limits.Limiter): + pass + + class LimitMiddlewareTest(BaseLimitTestSuite): """ Tests for the `limits.RateLimitingMiddleware` class. @@ -413,10 +417,14 @@ class LimitMiddlewareTest(BaseLimitTestSuite): def setUp(self): """Prepare middleware for use through fake WSGI app.""" BaseLimitTestSuite.setUp(self) - _limits = [ - limits.Limit("GET", "*", ".*", 1, 60), - ] - self.app = limits.RateLimitingMiddleware(self._empty_app, _limits) + _limits = '(GET, *, .*, 1, MINUTE)' + self.app = limits.RateLimitingMiddleware(self._empty_app, _limits, + "%s.TestLimiter" % + self.__class__.__module__) + + def test_limit_class(self): + """Test that middleware selected correct limiter class.""" + assert isinstance(self.app._limiter, TestLimiter) def test_good_request(self): """Test successful GET request through middleware.""" @@ -500,7 +508,8 @@ class LimiterTest(BaseLimitTestSuite): def setUp(self): """Run before each test.""" BaseLimitTestSuite.setUp(self) - self.limiter = limits.Limiter(TEST_LIMITS) + userlimits = {'user:user3': ''} + self.limiter = limits.Limiter(TEST_LIMITS, **userlimits) def _check(self, num, verb, url, username=None): """Check and yield results from checks.""" @@ -605,6 +614,12 @@ class LimiterTest(BaseLimitTestSuite): results = list(self._check(10, "PUT", "/anything")) self.assertEqual(expected, results) + def test_user_limit(self): + """ + Test user-specific limits. + """ + self.assertEqual(self.limiter.levels['user3'], []) + def test_multiple_users(self): """ Tests involving multiple users. @@ -619,6 +634,11 @@ class LimiterTest(BaseLimitTestSuite): results = list(self._check(15, "PUT", "/anything", "user2")) self.assertEqual(expected, results) + # User3 + expected = [None] * 20 + results = list(self._check(20, "PUT", "/anything", "user3")) + self.assertEqual(expected, results) + self.time += 1.0 # User1 again -- cgit