diff options
Diffstat (limited to 'tests/test_sql_core.py')
-rw-r--r-- | tests/test_sql_core.py | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/tests/test_sql_core.py b/tests/test_sql_core.py index d8f2a4f7..e60005f5 100644 --- a/tests/test_sql_core.py +++ b/tests/test_sql_core.py @@ -17,6 +17,130 @@ from keystone.common import sql from keystone import test +class CallbackMonitor: + def __init__(self, expect_called=True, raise_=False): + self.expect_called = expect_called + self.called = False + self._complete = False + self._raise = raise_ + + def call_this(self): + if self._complete: + return + + if not self.expect_called: + raise Exception("Did not expect callback.") + + if self.called: + raise Exception("Callback already called.") + + self.called = True + + if self._raise: + raise Exception("When called, raises.") + + def check(self): + if self.expect_called: + if not self.called: + raise Exception("Expected function to be called.") + self._complete = True + + +class TestGlobalEngine(test.TestCase): + + def tearDown(self): + sql.set_global_engine(None) + super(TestGlobalEngine, self).tearDown() + + def test_notify_on_set(self): + # If call sql.set_global_engine(), notify callbacks get called. + + cb_mon = CallbackMonitor() + + sql.register_global_engine_callback(cb_mon.call_this) + fake_engine = object() + sql.set_global_engine(fake_engine) + + cb_mon.check() + + def test_multi_notify(self): + # You can also set multiple notify callbacks and they each get called. + + cb_mon1 = CallbackMonitor() + cb_mon2 = CallbackMonitor() + + sql.register_global_engine_callback(cb_mon1.call_this) + sql.register_global_engine_callback(cb_mon2.call_this) + + fake_engine = object() + sql.set_global_engine(fake_engine) + + cb_mon1.check() + cb_mon2.check() + + def test_notify_once(self): + # After a callback is called, it's not called again if set global + # engine again. + + cb_mon = CallbackMonitor() + + sql.register_global_engine_callback(cb_mon.call_this) + fake_engine = object() + sql.set_global_engine(fake_engine) + + fake_engine = object() + # Note that cb_mon.call_this would raise if it's called again. + sql.set_global_engine(fake_engine) + + cb_mon.check() + + def test_set_same_engine(self): + # If you set the global engine to the same engine, callbacks don't get + # called. + + fake_engine = object() + + sql.set_global_engine(fake_engine) + + cb_mon = CallbackMonitor(expect_called=False) + sql.register_global_engine_callback(cb_mon.call_this) + + # Note that cb_mon.call_this would raise if it's called. + sql.set_global_engine(fake_engine) + + cb_mon.check() + + def test_notify_register_same(self): + # If you register the same callback twice, only gets called once. + cb_mon = CallbackMonitor() + + sql.register_global_engine_callback(cb_mon.call_this) + sql.register_global_engine_callback(cb_mon.call_this) + + fake_engine = object() + # Note that cb_mon.call_this would raise if it's called twice. + sql.set_global_engine(fake_engine) + + cb_mon.check() + + def test_callback_throws(self): + # If a callback function raises, + # a) the caller doesn't know about it, + # b) other callbacks are still called + + cb_mon1 = CallbackMonitor(raise_=True) + cb_mon2 = CallbackMonitor() + + sql.register_global_engine_callback(cb_mon1.call_this) + sql.register_global_engine_callback(cb_mon2.call_this) + + fake_engine = object() + sql.set_global_engine(fake_engine) + + cb_mon1.check() + cb_mon2.check() + + class TestBase(test.TestCase): def tearDown(self): @@ -38,3 +162,21 @@ class TestBase(test.TestCase): engine1 = base.get_engine() engine2 = base.get_engine(allow_global_engine=False) self.assertIsNot(engine1, engine2) + + def test_get_session(self): + # autocommit and expire_on_commit flags to get_session() are passed on + # to the session created. + + base = sql.Base() + session = base.get_session(autocommit=False, expire_on_commit=True) + + self.assertFalse(session.autocommit) + self.assertTrue(session.expire_on_commit) + + def test_get_session_invalidated(self): + # If clear the global engine, a new engine is used for get_session(). + base = sql.Base() + session1 = base.get_session() + sql.set_global_engine(None) + session2 = base.get_session() + self.assertIsNot(session1.bind, session2.bind) |