summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTony NIU <niuwl586@gmail.com>2013-01-09 20:09:40 +0800
committerTony NIU <niuwl586@gmail.com>2013-01-15 08:43:28 +0800
commit9c2c4ece645119fd450783217c359b38584553c8 (patch)
treeb6f13baf986da692864ae621629730cc697e6063
parent9460ff5c35809f4911cb5a1ee5f68d6351e797f4 (diff)
add database string field length check
Added database string field length check, so when insert to a table, if the length of string field exceed the limit of column when, it will return a 400 error instead of truncating the string. Change-Id: I7216fe736ea6e5a23b5647b107fcb2699f1fa99d Fixes: bug #1090247
-rw-r--r--keystone/common/sql/core.py40
-rw-r--r--keystone/exception.py5
-rw-r--r--tests/test_backend.py20
-rw-r--r--tests/test_backend_sql.py20
-rw-r--r--tests/test_v3.py2
-rw-r--r--tests/test_v3_catalog.py9
6 files changed, 94 insertions, 2 deletions
diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py
index 10629fc3..3634c75a 100644
--- a/keystone/common/sql/core.py
+++ b/keystone/common/sql/core.py
@@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
import sqlalchemy.orm
import sqlalchemy.pool
from sqlalchemy import types as sql_types
+from sqlalchemy.orm.attributes import InstrumentedAttribute
from keystone.common import logging
from keystone import config
from keystone.openstack.common import jsonutils
+from keystone import exception
CONF = config.CONF
@@ -49,6 +51,44 @@ Boolean = sql.Boolean
Text = sql.Text
+def initialize_decorator(init):
+ """Ensure that the length of string field do not exceed the limit.
+
+ This decorator check the initialize arguments, to make sure the
+ length of string field do not exceed the length limit, or raise a
+ 'StringLengthExceeded' exception.
+
+ Use decorator instead of inheritance, because the metaclass will
+ check the __tablename__, primary key columns, etc. at the class
+ definition.
+
+ """
+ def initialize(self, *args, **kwargs):
+ cls = type(self)
+ for k, v in kwargs.items():
+ if hasattr(cls, k):
+ attr = getattr(cls, k)
+ if isinstance(attr, InstrumentedAttribute):
+ column = attr.property.columns[0]
+ if isinstance(column.type, String):
+ if column.type.length and \
+ column.type.length < len(str(v)):
+ #if signing.token_format == 'PKI', the id will
+ #store it's public key which is very long.
+ if config.CONF.signing.token_format == 'PKI' and \
+ self.__tablename__ == 'token' and \
+ k == 'id':
+ continue
+
+ raise exception.StringLengthExceeded(
+ string=v, type=k, length=column.type.length)
+
+ init(self, *args, **kwargs)
+ return initialize
+
+ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
+
+
def set_global_engine(engine):
global GLOBAL_ENGINE
GLOBAL_ENGINE = engine
diff --git a/keystone/exception.py b/keystone/exception.py
index 96caf322..9923f578 100644
--- a/keystone/exception.py
+++ b/keystone/exception.py
@@ -73,6 +73,11 @@ class ValidationError(Error):
title = 'Bad Request'
+class StringLengthExceeded(ValidationError):
+ """The length of string "%(string)s" exceeded the limit of column
+ %(type)s(CHAR(%(length)d))."""
+
+
class SecurityError(Error):
"""Avoids exposing details of security failures, unless in debug mode."""
diff --git a/tests/test_backend.py b/tests/test_backend.py
index 5bcdfbe3..af4d7472 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -886,7 +886,7 @@ class CatalogTests(object):
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
- 'interface': uuid.uuid4().hex,
+ 'interface': uuid.uuid4().hex[:8],
'url': uuid.uuid4().hex,
'service_id': service['id'],
}
@@ -934,6 +934,24 @@ class CatalogTests(object):
{},
uuid.uuid4().hex)
+ def test_create_endpoint(self):
+ service = {
+ 'id': uuid.uuid4().hex,
+ 'type': uuid.uuid4().hex,
+ 'name': uuid.uuid4().hex,
+ 'description': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_service(service['id'], service.copy())
+
+ endpoint = {
+ 'id': uuid.uuid4().hex,
+ 'region': "0" * 255,
+ 'service_id': service['id'],
+ 'interface': 'public',
+ 'url': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
+
class PolicyTests(object):
def _new_policy_ref(self):
diff --git a/tests/test_backend_sql.py b/tests/test_backend_sql.py
index 8306c65e..fb8cafa1 100644
--- a/tests/test_backend_sql.py
+++ b/tests/test_backend_sql.py
@@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL'))
+ def test_create_endpoint_400(self):
+ service = {
+ 'id': uuid.uuid4().hex,
+ 'type': uuid.uuid4().hex,
+ 'name': uuid.uuid4().hex,
+ 'description': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_service(service['id'], service.copy())
+
+ endpoint = {
+ 'id': uuid.uuid4().hex,
+ 'region': "0" * 256,
+ 'service_id': service['id'],
+ 'interface': 'public',
+ 'url': uuid.uuid4().hex,
+ }
+
+ with self.assertRaises(exception.StringLengthExceeded):
+ self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
+
class SqlPolicy(SqlTests, test_backend.PolicyTests):
pass
diff --git a/tests/test_v3.py b/tests/test_v3.py
index 9a999585..ed7b5e66 100644
--- a/tests/test_v3.py
+++ b/tests/test_v3.py
@@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
def new_endpoint_ref(self, service_id):
ref = self.new_ref()
- ref['interface'] = uuid.uuid4().hex
+ ref['interface'] = uuid.uuid4().hex[:8]
ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex
return ref
diff --git a/tests/test_v3_catalog.py b/tests/test_v3_catalog.py
index 9f5bf913..3a901709 100644
--- a/tests/test_v3_catalog.py
+++ b/tests/test_v3_catalog.py
@@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
body={'endpoint': ref})
self.assertValidEndpointResponse(r, ref)
+ def assertValidErrorResponse(self, response):
+ self.assertTrue(response.status in [400])
+
+ def test_create_endpoint_400(self):
+ """POST /endpoints"""
+ ref = self.new_endpoint_ref(service_id=self.service_id)
+ ref["region"] = "0" * 256
+ self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
+
def test_get_endpoint(self):
"""GET /endpoints/{endpoint_id}"""
r = self.get(