diff options
| -rw-r--r-- | openstack/common/db/sqlalchemy/session.py | 4 | ||||
| -rw-r--r-- | tests/unit/db/sqlalchemy/test_sqlalchemy.py | 63 |
2 files changed, 67 insertions, 0 deletions
diff --git a/openstack/common/db/sqlalchemy/session.py b/openstack/common/db/sqlalchemy/session.py index 700273a..06b198e 100644 --- a/openstack/common/db/sqlalchemy/session.py +++ b/openstack/common/db/sqlalchemy/session.py @@ -593,6 +593,10 @@ class Session(sqlalchemy.orm.session.Session): def flush(self, *args, **kwargs): return super(Session, self).flush(*args, **kwargs) + @wrap_db_error + def execute(self, *args, **kwargs): + return super(Session, self).execute(*args, **kwargs) + def get_maker(engine, autocommit=True, expire_on_commit=False): """Return a SQLAlchemy sessionmaker using the given engine.""" diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py index d4894ba..10f7e41 100644 --- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py +++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py @@ -23,8 +23,14 @@ try: except ImportError: HAS_MYSQLDB = False +from sqlalchemy import Column, MetaData, Table, UniqueConstraint +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import DateTime, Integer + from openstack.common import context from openstack.common import exception +from openstack.common.db import common as common_db +from openstack.common.db.sqlalchemy import models from openstack.common.db.sqlalchemy import session from tests import utils as test_utils @@ -70,3 +76,60 @@ class DbPoolTestCase(test_utils.BaseTestCase): self.assertEqual(info['kwargs']['max_idle'], 11) self.assertEqual(info['kwargs']['min_size'], 21) self.assertEqual(info['kwargs']['max_size'], 42) + + +BASE = declarative_base() +_TABLE_NAME = '__tmp__test__tmp__' + + +class TmpTable(BASE, models.ModelBase): + __tablename__ = _TABLE_NAME + id = Column(Integer, primary_key=True) + foo = Column(Integer) + + +class SessionErrorWrapperTestCase(test_utils.BaseTestCase): + def setUp(self): + super(SessionErrorWrapperTestCase, self).setUp() + meta = MetaData() + meta.bind = session.get_engine() + test_table = Table(_TABLE_NAME, meta, + Column('id', Integer, primary_key=True, + nullable=False), + Column('deleted', Integer, default=0), + Column('deleted_at', DateTime), + Column('updated_at', DateTime), + Column('created_at', DateTime), + Column('foo', Integer), + UniqueConstraint('foo', name='uniq_foo')) + test_table.create() + + def tearDown(self): + super(SessionErrorWrapperTestCase, self).tearDown() + meta = MetaData() + meta.bind = session.get_engine() + test_table = Table(_TABLE_NAME, meta, autoload=True) + test_table.drop() + + def test_flush_wrapper(self): + tbl = TmpTable() + tbl.update({'foo': 10}) + tbl.save() + + tbl2 = TmpTable() + tbl2.update({'foo': 10}) + self.assertRaises(common_db.DBDuplicateEntry, tbl2.save) + + def test_execute_wrapper(self): + _session = session.get_session() + with _session.begin(): + for i in [10, 20]: + tbl = TmpTable() + tbl.update({'foo': i}) + tbl.save(session=_session) + + method = _session.query(TmpTable).\ + filter_by(foo=10).\ + update + self.assertRaises(common_db.DBDuplicateEntry, + method, {'foo': 20}) |
