summaryrefslogtreecommitdiffstats
path: root/nova/db
diff options
context:
space:
mode:
authorJenkins <jenkins@review.openstack.org>2012-03-22 20:27:03 +0000
committerGerrit Code Review <review@openstack.org>2012-03-22 20:27:03 +0000
commit1bfa451b4c9f131bd0d30f522897361087dfaabe (patch)
tree2bf08c4c5f4e76d95414cdc1d4f3e50198457348 /nova/db
parent79807e11a1e066e79edd8460d9306824ce83b0e5 (diff)
parent155ef7daab08d7f3fb8f7838df1d715bf1dc2f3f (diff)
downloadnova-1bfa451b4c9f131bd0d30f522897361087dfaabe.tar.gz
nova-1bfa451b4c9f131bd0d30f522897361087dfaabe.tar.xz
nova-1bfa451b4c9f131bd0d30f522897361087dfaabe.zip
Merge "Make sqlite in-memory-db usable to unittest"
Diffstat (limited to 'nova/db')
-rw-r--r--nova/db/sqlalchemy/migration.py55
-rw-r--r--nova/db/sqlalchemy/session.py45
2 files changed, 74 insertions, 26 deletions
diff --git a/nova/db/sqlalchemy/migration.py b/nova/db/sqlalchemy/migration.py
index 16177cbcf..14a111f9c 100644
--- a/nova/db/sqlalchemy/migration.py
+++ b/nova/db/sqlalchemy/migration.py
@@ -16,14 +16,46 @@
# License for the specific language governing permissions and limitations
# under the License.
+import distutils.version as dist_version
import os
import sys
+from nova.db.sqlalchemy.session import get_engine
from nova import exception
from nova import flags
import sqlalchemy
+import migrate
+from migrate.versioning import util as migrate_util
+
+
+MIGRATE_PKG_VER = dist_version.StrictVersion(migrate.__version__)
+USE_MIGRATE_PATCH = MIGRATE_PKG_VER < dist_version.StrictVersion('0.7.3')
+
+
+@migrate_util.decorator
+def patched_with_engine(f, *a, **kw):
+ url = a[0]
+ engine = migrate_util.construct_engine(url, **kw)
+
+ try:
+ kw['engine'] = engine
+ return f(*a, **kw)
+ finally:
+ if isinstance(engine, migrate_util.Engine) and engine is not url:
+ migrate_util.log.debug('Disposing SQLAlchemy engine %s', engine)
+ engine.dispose()
+
+
+# TODO(jkoelker) When migrate 0.7.3 is released and nova depends
+# on that version or higher, this can be removed
+if USE_MIGRATE_PATCH:
+ migrate_util.with_engine = patched_with_engine
+
+
+# NOTE(jkoelker) Delay importing migrate until we are patched
from migrate.versioning import api as versioning_api
+from migrate.versioning.repository import Repository
try:
from migrate.versioning import exceptions as versioning_exceptions
@@ -37,6 +69,8 @@ except ImportError:
FLAGS = flags.FLAGS
+_REPOSITORY = None
+
def db_sync(version=None):
if version is not None:
@@ -46,24 +80,24 @@ def db_sync(version=None):
raise exception.Error(_("version should be an integer"))
current_version = db_version()
- repo_path = _find_migrate_repo()
+ repository = _find_migrate_repo()
if version is None or version > current_version:
- return versioning_api.upgrade(FLAGS.sql_connection, repo_path, version)
+ return versioning_api.upgrade(get_engine(), repository, version)
else:
- return versioning_api.downgrade(FLAGS.sql_connection, repo_path,
+ return versioning_api.downgrade(get_engine(), repository,
version)
def db_version():
- repo_path = _find_migrate_repo()
+ repository = _find_migrate_repo()
try:
- return versioning_api.db_version(FLAGS.sql_connection, repo_path)
+ return versioning_api.db_version(get_engine(), repository)
except versioning_exceptions.DatabaseNotControlledError:
# If we aren't version controlled we may already have the database
# in the state from before we started version control, check for that
# and set up version_control appropriately
meta = sqlalchemy.MetaData()
- engine = sqlalchemy.create_engine(FLAGS.sql_connection, echo=False)
+ engine = get_engine()
meta.reflect(bind=engine)
try:
for table in ('auth_tokens', 'zones', 'export_devices',
@@ -85,14 +119,17 @@ def db_version():
def db_version_control(version=None):
- repo_path = _find_migrate_repo()
- versioning_api.version_control(FLAGS.sql_connection, repo_path, version)
+ repository = _find_migrate_repo()
+ versioning_api.version_control(get_engine(), repository, version)
return version
def _find_migrate_repo():
"""Get the path for the migrate repository."""
+ global _REPOSITORY
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'migrate_repo')
assert os.path.exists(path)
- return path
+ if _REPOSITORY is None:
+ _REPOSITORY = Repository(path)
+ return _REPOSITORY
diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py
index 52983134a..fe1b44c41 100644
--- a/nova/db/sqlalchemy/session.py
+++ b/nova/db/sqlalchemy/session.py
@@ -23,6 +23,8 @@ import time
import sqlalchemy.interfaces
import sqlalchemy.orm
from sqlalchemy.exc import DisconnectionError
+from sqlalchemy.pool import NullPool, StaticPool
+import time
import nova.exception
import nova.flags as flags
@@ -38,11 +40,11 @@ _MAKER = None
def get_session(autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy session."""
- global _ENGINE, _MAKER
+ global _MAKER
- if _MAKER is None or _ENGINE is None:
- _ENGINE = get_engine()
- _MAKER = get_maker(_ENGINE, autocommit, expire_on_commit)
+ if _MAKER is None:
+ engine = get_engine()
+ _MAKER = get_maker(engine, autocommit, expire_on_commit)
session = _MAKER()
session.query = nova.exception.wrap_db_error(session.query)
@@ -81,23 +83,32 @@ class MySQLPingListener(object):
def get_engine():
"""Return a SQLAlchemy engine."""
- connection_dict = sqlalchemy.engine.url.make_url(FLAGS.sql_connection)
+ global _ENGINE
+ if _ENGINE is None:
+ connection_dict = sqlalchemy.engine.url.make_url(FLAGS.sql_connection)
+
+ engine_args = {
+ "pool_recycle": FLAGS.sql_idle_timeout,
+ "echo": False,
+ 'convert_unicode': True,
+ }
+
+ if "sqlite" in connection_dict.drivername:
+ engine_args["poolclass"] = NullPool
+
+ if FLAGS.sql_connection == "sqlite://":
+ engine_args["poolclass"] = StaticPool
+ engine_args["connect_args"] = {'check_same_thread': False}
- engine_args = {
- "pool_recycle": FLAGS.sql_idle_timeout,
- "echo": False,
- 'convert_unicode': True,
- }
+ if not FLAGS.sqlite_synchronous:
+ engine_args["listeners"] = [SynchronousSwitchListener()]
- if "sqlite" in connection_dict.drivername:
- engine_args["poolclass"] = sqlalchemy.pool.NullPool
- if not FLAGS.sqlite_synchronous:
- engine_args["listeners"] = [SynchronousSwitchListener()]
+ if 'mysql' in connection_dict.drivername:
+ engine_args['listeners'] = [MySQLPingListener()]
- if 'mysql' in connection_dict.drivername:
- engine_args['listeners'] = [MySQLPingListener()]
+ _ENGINE = sqlalchemy.create_engine(FLAGS.sql_connection, **engine_args)
- return sqlalchemy.create_engine(FLAGS.sql_connection, **engine_args)
+ return _ENGINE
def get_maker(engine, autocommit=True, expire_on_commit=False):