summaryrefslogtreecommitdiffstats
path: root/tests/unit/db/sqlalchemy/test_sqlalchemy.py
blob: ea01bc8b2f05605c84e19a7b24635680a6be22f1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2012 Rackspace Hosting
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Unit tests for SQLAlchemy specific code."""

from eventlet import db_pool
try:
    import MySQLdb
    HAS_MYSQLDB = True
except ImportError:
    HAS_MYSQLDB = False

from sqlalchemy import Column, MetaData, Table, UniqueConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import DateTime, Integer

from openstack.common import context
from openstack.common import exception
from openstack.common.db.sqlalchemy import models
from openstack.common.db.sqlalchemy import session
from tests import utils as test_utils


class TestException(exception.OpenstackException):
    """Marker error raised by the fake pool's ``connect``.

    ``test_db_pool_option`` expects ``session.create_engine`` to surface
    this exception, which proves the fake pool was the one used.
    """


class DbPoolTestCase(test_utils.BaseTestCase):
    """Check that the ``sql_*`` pool options are handed to eventlet's db_pool.

    Every test here needs the MySQLdb driver and is skipped without it.
    """

    def setUp(self):
        super(DbPoolTestCase, self).setUp()
        if not HAS_MYSQLDB:
            self.skipTest("Required module MySQLdb missing.")
        self.config(sql_dbpool_enable=True)
        self.user_id = 'fake'
        self.project_id = 'fake'
        self.context = context.RequestContext(self.user_id, self.project_id)

    def test_db_pool_option(self):
        """create_engine() forwards the idle/min/max pool options."""
        self.config(sql_idle_timeout=11, sql_min_pool_size=21,
                    sql_max_pool_size=42)

        # Captures the driver module and keyword arguments the engine
        # factory passes when it constructs the connection pool.
        info = {}

        class FakeConnectionPool(db_pool.ConnectionPool):
            def __init__(self, mod_name, **kwargs):
                info['module'] = mod_name
                info['kwargs'] = kwargs
                super(FakeConnectionPool, self).__init__(mod_name,
                                                         **kwargs)

            def connect(self, *args, **kwargs):
                # Never open a real MySQL connection; the test expects
                # create_engine() to propagate this marker exception.
                raise TestException()

        self.stubs.Set(db_pool, 'ConnectionPool',
                       FakeConnectionPool)

        sql_connection = 'mysql://user:pass@127.0.0.1/nova'
        self.assertRaises(TestException, session.create_engine,
                          sql_connection)

        # testtools' assertEqual is (expected, observed); keeping that
        # order makes failure messages label the values correctly.
        self.assertEqual(MySQLdb, info['module'])
        self.assertEqual(11, info['kwargs']['max_idle'])
        self.assertEqual(21, info['kwargs']['min_size'])
        self.assertEqual(42, info['kwargs']['max_size'])


# Name of the scratch table that SessionErrorWrapperTestCase creates for
# each test and drops afterwards.
_TABLE_NAME = '__tmp__test__tmp__'
# Declarative base for the ORM model (TmpTable) mapped onto that table.
BASE = declarative_base()


class TmpTable(BASE, models.ModelBase):
    """ORM mapping onto the scratch table used by the session tests.

    Only ``id`` and ``foo`` are mapped here. The unique constraint on
    ``foo`` is declared by the test's own ``Table`` definition, not by
    this model.
    """

    __tablename__ = _TABLE_NAME

    id = Column('id', Integer, primary_key=True)
    foo = Column('foo', Integer)


class SessionErrorWrapperTestCase(test_utils.BaseTestCase):
    """Unique-key violations surface as ``session.DBDuplicateEntry``."""

    def setUp(self):
        super(SessionErrorWrapperTestCase, self).setUp()
        meta = MetaData()
        meta.bind = session.get_engine()
        test_table = Table(_TABLE_NAME, meta,
                           Column('id', Integer, primary_key=True,
                                  nullable=False),
                           Column('deleted', Integer, default=0),
                           Column('deleted_at', DateTime),
                           Column('updated_at', DateTime),
                           Column('created_at', DateTime),
                           Column('foo', Integer),
                           UniqueConstraint('foo', name='uniq_foo'))
        test_table.create()
        # Register the drop right after creating the table. Cleanups run
        # whether or not the test passes, in LIFO order (before anything
        # the base class registered earlier), and they reuse this Table
        # object instead of re-reflecting the schema with autoload.
        self.addCleanup(test_table.drop)

    def test_flush_wrapper(self):
        """Saving a second row with a duplicate ``foo`` is rejected."""
        tbl = TmpTable()
        tbl.update({'foo': 10})
        tbl.save()

        tbl2 = TmpTable()
        tbl2.update({'foo': 10})
        self.assertRaises(session.DBDuplicateEntry, tbl2.save)

    def test_execute_wrapper(self):
        """A query-level UPDATE hitting ``uniq_foo`` is rejected."""
        _session = session.get_session()
        with _session.begin():
            for value in (10, 20):
                tbl = TmpTable()
                tbl.update({'foo': value})
                tbl.save(session=_session)

            query = _session.query(TmpTable).filter_by(foo=10)
            # Moving the foo=10 row onto 20 collides with the existing row.
            self.assertRaises(session.DBDuplicateEntry,
                              query.update, {'foo': 20})