summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--nova/db/sqlalchemy/session.py4
-rw-r--r--nova/tests/test_sqlalchemy.py63
2 files changed, 67 insertions, 0 deletions
diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py
index 727f79eec..28ec613c5 100644
--- a/nova/db/sqlalchemy/session.py
+++ b/nova/db/sqlalchemy/session.py
@@ -564,6 +564,10 @@ class Session(sqlalchemy.orm.session.Session):
def flush(self, *args, **kwargs):
return super(Session, self).flush(*args, **kwargs)
+ @wrap_db_error
+ def execute(self, *args, **kwargs):
+ return super(Session, self).execute(*args, **kwargs)
+
def get_maker(engine, autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy sessionmaker using the given engine."""
diff --git a/nova/tests/test_sqlalchemy.py b/nova/tests/test_sqlalchemy.py
index f79d607f8..5c7f4450b 100644
--- a/nova/tests/test_sqlalchemy.py
+++ b/nova/tests/test_sqlalchemy.py
@@ -22,8 +22,14 @@ try:
except ImportError:
MySQLdb = None
+from sqlalchemy import Column, MetaData, Table, UniqueConstraint
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import DateTime, Integer
+
from nova import context
+from nova.db.sqlalchemy import models
from nova.db.sqlalchemy import session
+from nova import exception
from nova import test
@@ -64,3 +70,60 @@ class DbPoolTestCase(test.TestCase):
self.assertEqual(info['kwargs']['max_idle'], 11)
self.assertEqual(info['kwargs']['min_size'], 21)
self.assertEqual(info['kwargs']['max_size'], 42)
+
+
+BASE = declarative_base()
+_TABLE_NAME = '__tmp__test__tmp__'
+
+
+class TmpTable(BASE, models.NovaBase):
+ __tablename__ = _TABLE_NAME
+ id = Column(Integer, primary_key=True)
+ foo = Column(Integer)
+
+
+class SessionErrorWrapperTestCase(test.TestCase):
+ def setUp(self):
+ super(SessionErrorWrapperTestCase, self).setUp()
+ meta = MetaData()
+ meta.bind = session.get_engine()
+ test_table = Table(_TABLE_NAME, meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False),
+ Column('deleted', Integer, default=0),
+ Column('deleted_at', DateTime),
+ Column('updated_at', DateTime),
+ Column('created_at', DateTime),
+ Column('foo', Integer),
+ UniqueConstraint('foo', name='uniq_foo'))
+ test_table.create()
+
+ def tearDown(self):
+ super(SessionErrorWrapperTestCase, self).tearDown()
+ meta = MetaData()
+ meta.bind = session.get_engine()
+ test_table = Table(_TABLE_NAME, meta, autoload=True)
+ test_table.drop()
+
+ def test_flush_wrapper(self):
+ tbl = TmpTable()
+ tbl.update({'foo': 10})
+ tbl.save()
+
+ tbl2 = TmpTable()
+ tbl2.update({'foo': 10})
+ self.assertRaises(exception.DBDuplicateEntry, tbl2.save)
+
+ def test_execute_wrapper(self):
+ _session = session.get_session()
+ with _session.begin():
+ for i in [10, 20]:
+ tbl = TmpTable()
+ tbl.update({'foo': i})
+ tbl.save(session=_session)
+
+ method = _session.query(TmpTable).\
+ filter_by(foo=10).\
+ update
+ self.assertRaises(exception.DBDuplicateEntry,
+ method, {'foo': 20})