summaryrefslogtreecommitdiffstats
path: root/keystone/common/sql/migration.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone/common/sql/migration.py')
-rw-r--r--keystone/common/sql/migration.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/keystone/common/sql/migration.py b/keystone/common/sql/migration.py
index 86e0254c..3cb9cd63 100644
--- a/keystone/common/sql/migration.py
+++ b/keystone/common/sql/migration.py
@@ -39,39 +39,51 @@ except ImportError:
sys.exit('python-migrate is not installed. Exiting.')
-def db_sync(version=None):
+def migrate_repository(version, current_version, repo_path):
+ if version is None or version > current_version:
+ result = versioning_api.upgrade(CONF.sql.connection,
+ repo_path, version)
+ else:
+ result = versioning_api.downgrade(
+ CONF.sql.connection, repo_path, version)
+ return result
+
+
+def db_sync(version=None, repo_path=None):
if version is not None:
try:
version = int(version)
except ValueError:
raise Exception(_('version should be an integer'))
+ if repo_path is None:
+ repo_path = find_migrate_repo()
+ current_version = db_version(repo_path=repo_path)
+ return migrate_repository(version, current_version, repo_path)
- current_version = db_version()
- repo_path = _find_migrate_repo()
- if version is None or version > current_version:
- return versioning_api.upgrade(CONF.sql.connection, repo_path, version)
- else:
- return versioning_api.downgrade(
- CONF.sql.connection, repo_path, version)
-
-def db_version():
- repo_path = _find_migrate_repo()
+def db_version(repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
try:
return versioning_api.db_version(CONF.sql.connection, repo_path)
except versioning_exceptions.DatabaseNotControlledError:
return db_version_control(0)
-def db_version_control(version=None):
- repo_path = _find_migrate_repo()
+def db_version_control(version=None, repo_path=None):
+ if repo_path is None:
+ repo_path = find_migrate_repo()
versioning_api.version_control(CONF.sql.connection, repo_path, version)
return version
-def _find_migrate_repo():
+def find_migrate_repo(package=None):
"""Get the path for the migrate repository."""
- path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
+ if package is None:
+ file = __file__
+ else:
+ file = package.__file__
+ path = os.path.join(os.path.abspath(os.path.dirname(file)),
'migrate_repo')
assert os.path.exists(path)
return path