diff options
-rw-r--r-- | nova/db/sqlalchemy/api.py | 50 | ||||
-rw-r--r-- | nova/db/sqlalchemy/migrate_repo/versions/177_add_floating_ip_uc.py | 40 | ||||
-rw-r--r-- | nova/tests/test_db_api.py | 9 | ||||
-rw-r--r-- | nova/tests/test_migrations.py | 30 |
4 files changed, 96 insertions, 33 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 12ecd0af1..72ade9857 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -685,21 +685,18 @@ def floating_ip_allocate_address(context, project_id, pool): @require_context def floating_ip_bulk_create(context, ips): - existing_ips = {} - for floating in _floating_ip_get_all(context).all(): - existing_ips[floating['address']] = floating - session = get_session() with session.begin(): for ip in ips: - addr = ip['address'] - if (addr in existing_ips and - ip.get('id') != existing_ips[addr]['id']): - raise exception.FloatingIpExists(**dict(existing_ips[addr])) - model = models.FloatingIp() model.update(ip) - session.add(model) + try: + # NOTE(boris-42): To get existing address we have to do each + # time session.flush().. + session.add(model) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.FloatingIpExists(address=ip['address']) def _ip_range_splitter(ips, block_size=256): @@ -731,25 +728,12 @@ def floating_ip_bulk_destroy(context, ips): @require_context def floating_ip_create(context, values, session=None): - if not session: - session = get_session() - floating_ip_ref = models.FloatingIp() floating_ip_ref.update(values) - - # check uniqueness for not deleted addresses - if not floating_ip_ref.deleted: - try: - floating_ip = _floating_ip_get_by_address(context, - floating_ip_ref.address, - session) - except exception.FloatingIpNotFoundForAddress: - pass - else: - if floating_ip.id != floating_ip_ref.id: - raise exception.FloatingIpExists(**dict(floating_ip_ref)) - - floating_ip_ref.save(session=session) + try: + floating_ip_ref.save() + except db_exc.DBDuplicateEntry: + raise exception.FloatingIpExists(address=values['address']) return floating_ip_ref @@ -916,12 +900,12 @@ def floating_ip_get_by_fixed_ip_id(context, fixed_ip_id, session=None): def floating_ip_update(context, address, values): session = get_session() with session.begin(): - floating_ip_ref = _floating_ip_get_by_address(context, - address, - session) - for (key, value) in values.iteritems(): - floating_ip_ref[key] = value - floating_ip_ref.save(session=session) + float_ip_ref = _floating_ip_get_by_address(context, address, session) + float_ip_ref.update(values) + try: + float_ip_ref.save(session=session) + except db_exc.DBDuplicateEntry: + raise exception.FloatingIpExists(address=values['address']) @require_context diff --git a/nova/db/sqlalchemy/migrate_repo/versions/177_add_floating_ip_uc.py b/nova/db/sqlalchemy/migrate_repo/versions/177_add_floating_ip_uc.py new file mode 100644 index 000000000..c0dd7c91d --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/177_add_floating_ip_uc.py @@ -0,0 +1,40 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me). +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from migrate.changeset import UniqueConstraint +from sqlalchemy import MetaData, Table + +from nova.db.sqlalchemy import utils + + +UC_NAME = "uniq_address_x_deleted" +COLUMNS = ('address', 'deleted') +TABLE_NAME = 'floating_ips' + + +def upgrade(migrate_engine): + meta = MetaData(bind=migrate_engine) + t = Table(TABLE_NAME, meta, autoload=True) + + utils.drop_old_duplicate_entries_from_table(migrate_engine, TABLE_NAME, + True, *COLUMNS) + uc = UniqueConstraint(*COLUMNS, table=t, name=UC_NAME) + uc.create() + + +def downgrade(migrate_engine): + utils.drop_unique_constraint(migrate_engine, TABLE_NAME, UC_NAME, *COLUMNS) diff --git a/nova/tests/test_db_api.py b/nova/tests/test_db_api.py index fcba2aefa..a3c281b49 100644 --- a/nova/tests/test_db_api.py +++ b/nova/tests/test_db_api.py @@ -2989,6 +2989,15 @@ class FloatingIpTestCase(test.TestCase, ModelsObjectComparatorMixin): 'deleted', 'fixed_ip_id', 'fixed_ip']) + def test_floating_ip_update_to_duplicate(self): + float_ip1 = self._create_floating_ip({'address': '1.1.1.1'}) + float_ip2 = self._create_floating_ip({'address': '1.1.1.2'}) + + self.assertRaises(exception.FloatingIpExists, + db.floating_ip_update, + self.ctxt, float_ip2['address'], + {'address': float_ip1['address']}) + class InstanceDestroyConstraints(test.TestCase): diff --git a/nova/tests/test_migrations.py b/nova/tests/test_migrations.py index 8c2f04f21..a0f71b25a 100644 --- a/nova/tests/test_migrations.py +++ b/nova/tests/test_migrations.py @@ -1285,6 +1285,36 @@ class TestNovaMigrations(BaseMigrationTestCase, CommonTestsMixIn): self.assertFalse('availability_zone' in rows[0]) + def _pre_upgrade_177(self, engine): + floating_ips = get_table(engine, 'floating_ips') + data = [ + {'address': '128.128.128.128', 'deleted': 0}, + {'address': '128.128.128.128', 'deleted': 0}, + {'address': '128.128.128.129', 'deleted': 0}, + ] + + for item in data: + floating_ips.insert().values(item).execute() + return data + + def _check_177(self, engine, data): + floating_ips = get_table(engine, 'floating_ips') + + def get_(address, deleted): + deleted_value = 0 if not deleted else floating_ips.c.id + return floating_ips.select().\ + where(floating_ips.c.address == address).\ + where(floating_ips.c.deleted == deleted_value).\ + execute().\ + fetchall() + + self.assertEqual(1, len(get_('128.128.128.128', False))) + self.assertEqual(1, len(get_('128.128.128.128', True))) + self.assertEqual(1, len(get_('128.128.128.129', False))) + self.assertRaises(sqlalchemy.exc.IntegrityError, + floating_ips.insert().execute, + dict(address='128.128.128.129', deleted=0)) + class TestBaremetalMigrations(BaseMigrationTestCase, CommonTestsMixIn): """Test sqlalchemy-migrate migrations.""" |