summaryrefslogtreecommitdiffstats
path: root/tests/unit/db/sqlalchemy/test_sqlalchemy.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unit/db/sqlalchemy/test_sqlalchemy.py')
-rw-r--r--tests/unit/db/sqlalchemy/test_sqlalchemy.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/tests/unit/db/sqlalchemy/test_sqlalchemy.py b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
index d4894ba..10f7e41 100644
--- a/tests/unit/db/sqlalchemy/test_sqlalchemy.py
+++ b/tests/unit/db/sqlalchemy/test_sqlalchemy.py
@@ -23,8 +23,14 @@ try:
except ImportError:
HAS_MYSQLDB = False
+from sqlalchemy import Column, MetaData, Table, UniqueConstraint
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import DateTime, Integer
+
from openstack.common import context
from openstack.common import exception
+from openstack.common.db import common as common_db
+from openstack.common.db.sqlalchemy import models
from openstack.common.db.sqlalchemy import session
from tests import utils as test_utils
@@ -70,3 +76,60 @@ class DbPoolTestCase(test_utils.BaseTestCase):
self.assertEqual(info['kwargs']['max_idle'], 11)
self.assertEqual(info['kwargs']['min_size'], 21)
self.assertEqual(info['kwargs']['max_size'], 42)
+
+
+BASE = declarative_base()
+_TABLE_NAME = '__tmp__test__tmp__'
+
+
+class TmpTable(BASE, models.ModelBase):
+ __tablename__ = _TABLE_NAME
+ id = Column(Integer, primary_key=True)
+ foo = Column(Integer)
+
+
+class SessionErrorWrapperTestCase(test_utils.BaseTestCase):
+ def setUp(self):
+ super(SessionErrorWrapperTestCase, self).setUp()
+ meta = MetaData()
+ meta.bind = session.get_engine()
+ test_table = Table(_TABLE_NAME, meta,
+ Column('id', Integer, primary_key=True,
+ nullable=False),
+ Column('deleted', Integer, default=0),
+ Column('deleted_at', DateTime),
+ Column('updated_at', DateTime),
+ Column('created_at', DateTime),
+ Column('foo', Integer),
+ UniqueConstraint('foo', name='uniq_foo'))
+ test_table.create()
+
+ def tearDown(self):
+ super(SessionErrorWrapperTestCase, self).tearDown()
+ meta = MetaData()
+ meta.bind = session.get_engine()
+ test_table = Table(_TABLE_NAME, meta, autoload=True)
+ test_table.drop()
+
+ def test_flush_wrapper(self):
+ tbl = TmpTable()
+ tbl.update({'foo': 10})
+ tbl.save()
+
+ tbl2 = TmpTable()
+ tbl2.update({'foo': 10})
+ self.assertRaises(common_db.DBDuplicateEntry, tbl2.save)
+
+ def test_execute_wrapper(self):
+ _session = session.get_session()
+ with _session.begin():
+ for i in [10, 20]:
+ tbl = TmpTable()
+ tbl.update({'foo': i})
+ tbl.save(session=_session)
+
+ method = _session.query(TmpTable).\
+ filter_by(foo=10).\
+ update
+ self.assertRaises(common_db.DBDuplicateEntry,
+ method, {'foo': 20})