diff options
| author | Alex Meade <alex.meade@rackspace.com> | 2011-05-25 17:36:51 -0400 |
|---|---|---|
| committer | Alex Meade <alex.meade@rackspace.com> | 2011-05-25 17:36:51 -0400 |
| commit | 719dfcd62cec0f89b6e86b202f84ea79f448d4c7 (patch) | |
| tree | 99530b1b4ab7ae2c76196dd6cab286c30836438c | |
| parent | c440aecaaacf3caa8683234022bc10836d232971 (diff) | |
| parent | db18a792414240cbdb1221d0e79e8a63313f103e (diff) | |
| download | nova-719dfcd62cec0f89b6e86b202f84ea79f448d4c7.tar.gz nova-719dfcd62cec0f89b6e86b202f84ea79f448d4c7.tar.xz nova-719dfcd62cec0f89b6e86b202f84ea79f448d4c7.zip | |
Merged trunk
31 files changed, 1242 insertions, 868 deletions
@@ -17,6 +17,7 @@ Christian Berendt <berendt@b1-systems.de> Chuck Short <zulcss@ubuntu.com> Cory Wright <corywright@gmail.com> Dan Prince <dan.prince@rackspace.com> +Dave Walker <DaveWalker@ubuntu.com> David Pravec <David.Pravec@danix.org> Dean Troyer <dtroyer@gmail.com> Devin Carlen <devin.carlen@gmail.com> diff --git a/bin/nova-dhcpbridge b/bin/nova-dhcpbridge index f42dfd6b5..5926b97de 100755 --- a/bin/nova-dhcpbridge +++ b/bin/nova-dhcpbridge @@ -108,6 +108,13 @@ def main(): interface = os.environ.get('DNSMASQ_INTERFACE', FLAGS.dnsmasq_interface) if int(os.environ.get('TESTING', '0')): from nova.tests import fake_flags + + #if FLAGS.fake_rabbit: + # LOG.debug(_("leasing ip")) + # network_manager = utils.import_object(FLAGS.network_manager) + ## reload(fake_flags) + # from nova.tests import fake_flags + action = argv[1] if action in ['add', 'del', 'old']: mac = argv[2] diff --git a/bin/nova-manage b/bin/nova-manage index e3ed7b9d0..26c0d776c 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -97,7 +97,7 @@ flags.DECLARE('vlan_start', 'nova.network.manager') flags.DECLARE('vpn_start', 'nova.network.manager') flags.DECLARE('fixed_range_v6', 'nova.network.manager') flags.DECLARE('images_path', 'nova.image.local') -flags.DECLARE('libvirt_type', 'nova.virt.libvirt_conn') +flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection') flags.DEFINE_flag(flags.HelpFlag()) flags.DEFINE_flag(flags.HelpshortFlag()) flags.DEFINE_flag(flags.HelpXMLFlag()) @@ -417,12 +417,16 @@ class ProjectCommands(object): arguments: project_id [key] [value]""" ctxt = context.get_admin_context() if key: + if value.lower() == 'unlimited': + value = None try: db.quota_update(ctxt, project_id, key, value) except exception.ProjectQuotaNotFound: db.quota_create(ctxt, project_id, key, value) - project_quota = quota.get_quota(ctxt, project_id) + project_quota = quota.get_project_quotas(ctxt, project_id) for key, value in project_quota.iteritems(): + if value is None: + value = 'unlimited' print '%s: %s' % (key, value) def remove(self, project_id, user_id): diff --git a/doc/source/devref/index.rst b/doc/source/devref/index.rst index 9613ba990..0a5a7a4d6 100644 --- a/doc/source/devref/index.rst +++ b/doc/source/devref/index.rst @@ -35,6 +35,7 @@ Programming Concepts .. toctree:: :maxdepth: 3 + zone rabbit API Reference diff --git a/doc/source/devref/zone.rst b/doc/source/devref/zone.rst index 3dd9d37d3..263560ee2 100644 --- a/doc/source/devref/zone.rst +++ b/doc/source/devref/zone.rst @@ -17,7 +17,7 @@ Zones ===== -A Nova deployment is called a Zone. At the very least a Zone requires an API node, a Scheduler node, a database and RabbitMQ. Pushed further a Zone may contain many API nodes, many Scheduler, Volume, Network and Compute nodes as well as a cluster of databases and RabbitMQ servers. A Zone allows you to partition your deployments into logical groups for load balancing and instance distribution. +A Nova deployment is called a Zone. A Zone allows you to partition your deployments into logical groups for load balancing and instance distribution. At the very least a Zone requires an API node, a Scheduler node, a database and RabbitMQ. Pushed further a Zone may contain many API nodes, many Scheduler, Volume, Network and Compute nodes as well as a cluster of databases and RabbitMQ servers. The idea behind Zones is, if a particular deployment is not capable of servicing a particular request, the request may be forwarded to (child) Zones for possible processing. Zones may be nested in a tree fashion. @@ -34,7 +34,7 @@ Routing between Zones is based on the Capabilities of that Zone. Capabilities ar key=value;value;value, key=value;value;value -Zones have Capabilities which are general to the Zone and are set via `--zone-capabilities` flag. Zones also have dynamic per-service Capabilities. Services derived from `nova.manager.SchedulerDependentManager` (such as Compute, Volume and Network) can set these capabilities by calling the `update_service_capabilities()` method on their `Manager` base class. These capabilities will be periodically sent to the Scheduler service automatically. The rate at which these updates are sent is controlled by the `--periodic_interval` flag. +Zones have Capabilities which are general to the Zone and are set via `--zone_capabilities` flag. Zones also have dynamic per-service Capabilities. Services derived from `nova.manager.SchedulerDependentManager` (such as Compute, Volume and Network) can set these capabilities by calling the `update_service_capabilities()` method on their `Manager` base class. These capabilities will be periodically sent to the Scheduler service automatically. The rate at which these updates are sent is controlled by the `--periodic_interval` flag. Flow within a Zone ------------------ @@ -47,7 +47,7 @@ Inter-service communication within a Zone is done with RabbitMQ. Each class of S These capability messages are received by the Scheduler services and stored in the `ZoneManager` object. The SchedulerManager object has a reference to the `ZoneManager` it can use for load balancing. -The `ZoneManager` also polls the child Zones periodically to gather their capabilities to aid in decision making. This is done via the OpenStack API `/v1.0/zones/info` REST call. This also captures the name of each child Zone. The Zone name is set via the `--zone-name` flag (and defaults to "nova"). +The `ZoneManager` also polls the child Zones periodically to gather their capabilities to aid in decision making. This is done via the OpenStack API `/v1.0/zones/info` REST call. This also captures the name of each child Zone. The Zone name is set via the `--zone_name` flag (and defaults to "nova"). Zone administrative functions ----------------------------- diff --git a/doc/source/man/novamanage.rst b/doc/source/man/novamanage.rst index 9c54f3608..397cc8e80 100644 --- a/doc/source/man/novamanage.rst +++ b/doc/source/man/novamanage.rst @@ -6,7 +6,7 @@ nova-manage control and manage cloud computer instances and images ------------------------------------------------------ -:Author: nova@lists.launchpad.net +:Author: openstack@lists.launchpad.net :Date: 2010-11-16 :Copyright: OpenStack LLC :Version: 0.1 @@ -121,7 +121,7 @@ Nova Role nova-manage role <action> [<argument>] ``nova-manage role add <username> <rolename> <(optional) projectname>`` - Add a user to either a global or project-based role with the indicated <rolename> assigned to the named user. Role names can be one of the following five roles: admin, itsec, projectmanager, netadmin, developer. If you add the project name as the last argument then the role is assigned just for that project, otherwise the user is assigned the named role for all projects. + Add a user to either a global or project-based role with the indicated <rolename> assigned to the named user. Role names can be one of the following five roles: cloudadmin, itsec, sysadmin, netadmin, developer. If you add the project name as the last argument then the role is assigned just for that project, otherwise the user is assigned the named role for all projects. ``nova-manage role has <username> <projectname>`` Checks the user or project and responds with True if the user has a global role with a particular project. diff --git a/doc/source/runnova/managing.users.rst b/doc/source/runnova/managing.users.rst index 392142e86..d3442bed9 100644 --- a/doc/source/runnova/managing.users.rst +++ b/doc/source/runnova/managing.users.rst @@ -38,11 +38,11 @@ Role-based access control (RBAC) is an approach to restricting system access to Nova’s rights management system employs the RBAC model and currently supports the following five roles: -* **Cloud Administrator.** (admin) Users of this class enjoy complete system access. +* **Cloud Administrator.** (cloudadmin) Users of this class enjoy complete system access. * **IT Security.** (itsec) This role is limited to IT security personnel. It permits role holders to quarantine instances. -* **Project Manager.** (projectmanager)The default for project owners, this role affords users the ability to add other users to a project, interact with project images, and launch and terminate instances. +* **System Administrator.** (sysadmin) The default for project owners, this role affords users the ability to add other users to a project, interact with project images, and launch and terminate instances. * **Network Administrator.** (netadmin) Users with this role are permitted to allocate and assign publicly accessible IP addresses as well as create and modify firewall rules. -* **Developer.** This is a general purpose role that is assigned to users by default. +* **Developer.** (developer) This is a general purpose role that is assigned to users by default. RBAC management is exposed through the dashboard for simplified user management. diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index cd59340bd..c13993dd3 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -338,6 +338,10 @@ class Executor(wsgi.Application): else: return self._error(req, context, type(ex).__name__, unicode(ex)) + except exception.KeyPairExists as ex: + LOG.debug(_('KeyPairExists raised: %s'), unicode(ex), + context=context) + return self._error(req, context, type(ex).__name__, unicode(ex)) except Exception as ex: extra = {'environment': req.environ} LOG.exception(_('Unexpected error raised: %s'), unicode(ex), diff --git a/nova/api/openstack/limits.py b/nova/api/openstack/limits.py index 47bc238f1..bd0250a7f 100644 --- a/nova/api/openstack/limits.py +++ b/nova/api/openstack/limits.py @@ -30,6 +30,7 @@ from collections import defaultdict from webob.dec import wsgify +from nova import quota from nova import wsgi from nova.api.openstack import common from nova.api.openstack import faults @@ -64,7 +65,8 @@ class LimitsController(common.OpenstackController): """ Return all global and rate limit information. """ - abs_limits = {} + context = req.environ['nova.context'] + abs_limits = quota.get_project_quotas(context, context.project_id) rate_limits = req.environ.get("nova.limits", []) builder = self._get_view_builder(req) diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py index 8f2de2afe..5c10fc916 100644 --- a/nova/api/openstack/servers.py +++ b/nova/api/openstack/servers.py @@ -180,7 +180,8 @@ class Controller(common.OpenstackController): key_name=key_name, key_data=key_data, metadata=env['server'].get('metadata', {}), - injected_files=injected_files) + injected_files=injected_files, + admin_password=password) except quota.QuotaError as error: self._handle_quota_error(error) @@ -190,8 +191,6 @@ class Controller(common.OpenstackController): builder = self._get_view_builder(req) server = builder.build(inst, is_detail=True) server['server']['adminPass'] = password - self.compute_api.set_admin_password(context, server['server']['id'], - password) return server def _deserialize_create(self, request): @@ -608,8 +607,8 @@ class ControllerV10(Controller): def _parse_update(self, context, server_id, inst_dict, update_dict): if 'adminPass' in inst_dict['server']: - update_dict['admin_pass'] = inst_dict['server']['adminPass'] - self.compute_api.set_admin_password(context, server_id) + self.compute_api.set_admin_password(context, server_id, + inst_dict['server']['adminPass']) def _action_rebuild(self, info, request, instance_id): context = request.environ['nova.context'] diff --git a/nova/api/openstack/views/limits.py b/nova/api/openstack/views/limits.py index 22d1c260d..e21c9f2fd 100644 --- a/nova/api/openstack/views/limits.py +++ b/nova/api/openstack/views/limits.py @@ -45,6 +45,34 @@ class ViewBuilder(object): return output + def _build_absolute_limits(self, absolute_limits): + """Builder for absolute limits + + absolute_limits should be given as a dict of limits. + For example: {"ram": 512, "gigabytes": 1024}. + + """ + limit_names = { + "ram": ["maxTotalRAMSize"], + "instances": ["maxTotalInstances"], + "cores": ["maxTotalCores"], + "metadata_items": ["maxServerMeta", "maxImageMeta"], + "injected_files": ["maxPersonality"], + "injected_file_content_bytes": ["maxPersonalitySize"], + } + limits = {} + for name, value in absolute_limits.iteritems(): + if name in limit_names and value is not None: + for name in limit_names[name]: + limits[name] = value + return limits + + def _build_rate_limits(self, rate_limits): + raise NotImplementedError() + + def _build_rate_limit(self, rate_limit): + raise NotImplementedError() + class ViewBuilderV10(ViewBuilder): """Openstack API v1.0 limits view builder.""" @@ -63,9 +91,6 @@ class ViewBuilderV10(ViewBuilder): "resetTime": rate_limit["resetTime"], } - def _build_absolute_limits(self, absolute_limit): - return {} - class ViewBuilderV11(ViewBuilder): """Openstack API v1.1 limits view builder.""" @@ -79,7 +104,7 @@ class ViewBuilderV11(ViewBuilder): # check for existing key for limit in limits: if limit["uri"] == rate_limit["URI"] and \ - limit["regex"] == limit["regex"]: + limit["regex"] == rate_limit["regex"]: _rate_limit_key = limit break @@ -104,6 +129,3 @@ class ViewBuilderV11(ViewBuilder): "unit": rate_limit["unit"], "next-available": rate_limit["resetTime"], } - - def _build_absolute_limits(self, absolute_limit): - return {} diff --git a/nova/compute/api.py b/nova/compute/api.py index 7e2494781..4f2363387 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -95,14 +95,15 @@ class API(base.Base): """ if injected_files is None: return - limit = quota.allowed_injected_files(context) + limit = quota.allowed_injected_files(context, len(injected_files)) if len(injected_files) > limit: raise quota.QuotaError(code="OnsetFileLimitExceeded") path_limit = quota.allowed_injected_file_path_bytes(context) - content_limit = quota.allowed_injected_file_content_bytes(context) for path, content in injected_files: if len(path) > path_limit: raise quota.QuotaError(code="OnsetFilePathLimitExceeded") + content_limit = quota.allowed_injected_file_content_bytes( + context, len(content)) if len(content) > content_limit: raise quota.QuotaError(code="OnsetFileContentLimitExceeded") @@ -134,7 +135,8 @@ class API(base.Base): display_name='', display_description='', key_name=None, key_data=None, security_group='default', availability_zone=None, user_data=None, metadata={}, - injected_files=None): + injected_files=None, + admin_password=None): """Create the number and type of instances requested. Verifies that quota and other arguments are valid. @@ -149,9 +151,13 @@ class API(base.Base): pid = context.project_id LOG.warn(_("Quota exceeeded for %(pid)s," " tried to run %(min_count)s instances") % locals()) - raise quota.QuotaError(_("Instance quota exceeded. You can only " - "run %s more instances of this type.") % - num_instances, "InstanceLimitExceeded") + if num_instances <= 0: + message = _("Instance quota exceeded. You cannot run any " + "more instances of this type.") + else: + message = _("Instance quota exceeded. You can only run %s " + "more instances of this type.") % num_instances + raise quota.QuotaError(message, "InstanceLimitExceeded") self._check_metadata_properties_quota(context, metadata) self._check_injected_file_quota(context, injected_files) @@ -264,7 +270,8 @@ class API(base.Base): "instance_id": instance_id, "instance_type": instance_type, "availability_zone": availability_zone, - "injected_files": injected_files}}) + "injected_files": injected_files, + "admin_password": admin_password}}) for group_id in security_groups: self.trigger_security_group_members_refresh(elevated, group_id) @@ -503,15 +510,6 @@ class API(base.Base): raise exception.Error(_("Unable to find host for Instance %s") % instance_id) - def _set_admin_password(self, context, instance_id, password): - """Set the root/admin password for the given instance.""" - host = self._find_host(context, instance_id) - - rpc.cast(context, - self.db.queue_get_for(context, FLAGS.compute_topic, host), - {"method": "set_admin_password", - "args": {"instance_id": instance_id, "new_pass": password}}) - def snapshot(self, context, instance_id, name): """Snapshot the given instance. @@ -665,8 +663,12 @@ class API(base.Base): def set_admin_password(self, context, instance_id, password=None): """Set the root/admin password for the given instance.""" - eventlet.spawn_n(self._set_admin_password(context, instance_id, - password)) + host = self._find_host(context, instance_id) + + rpc.cast(context, + self.db.queue_get_for(context, FLAGS.compute_topic, host), + {"method": "set_admin_password", + "args": {"instance_id": instance_id, "new_pass": password}}) def inject_file(self, context, instance_id): """Write a file to the given instance.""" diff --git a/nova/compute/manager.py b/nova/compute/manager.py index 11565c25e..d1e01f275 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -221,6 +221,7 @@ class ComputeManager(manager.SchedulerDependentManager): context = context.elevated() instance_ref = self.db.instance_get(context, instance_id) instance_ref.injected_files = kwargs.get('injected_files', []) + instance_ref.admin_pass = kwargs.get('admin_password', None) if instance_ref['name'] in self.driver.list_instances(): raise exception.Error(_("Instance has already been created")) LOG.audit(_("instance %s: starting..."), instance_id, @@ -405,22 +406,28 @@ class ComputeManager(manager.SchedulerDependentManager): @exception.wrap_exception @checks_instance_lock def set_admin_password(self, context, instance_id, new_pass=None): - """Set the root/admin password for an instance on this host.""" + """Set the root/admin password for an instance on this host. + + This is generally only called by API password resets after an + image has been built. + """ + context = context.elevated() if new_pass is None: # Generate a random password new_pass = utils.generate_password(FLAGS.password_length) - while True: + max_tries = 10 + + for i in xrange(max_tries): instance_ref = self.db.instance_get(context, instance_id) instance_id = instance_ref["id"] instance_state = instance_ref["state"] expected_state = power_state.RUNNING if instance_state != expected_state: - time.sleep(5) - continue + raise exception.Error(_('Instance is not running')) else: try: self.driver.set_admin_password(instance_ref, new_pass) @@ -436,6 +443,12 @@ class ComputeManager(manager.SchedulerDependentManager): except Exception, e: # Catch all here because this could be anything. LOG.exception(e) + if i == max_tries - 1: + # At some point this exception may make it back + # to the API caller, and we don't want to reveal + # too much. The real exception is logged above + raise exception.Error(_('Internal error')) + time.sleep(1) continue @exception.wrap_exception diff --git a/nova/db/api.py b/nova/db/api.py index ef8aa1143..310c0bb09 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -403,7 +403,7 @@ def instance_create(context, values): def instance_data_get_for_project(context, project_id): - """Get (instance_count, core_count) for project.""" + """Get (instance_count, total_cores, total_ram) for project.""" return IMPL.instance_data_get_for_project(context, project_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b53e81053..e4dda5c12 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -803,12 +803,13 @@ def instance_create(context, values): def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), - func.sum(models.Instance.vcpus)).\ + func.sum(models.Instance.vcpus), + func.sum(models.Instance.memory_mb)).\ filter_by(project_id=project_id).\ filter_by(deleted=False).\ first() # NOTE(vish): convert None to 0 - return (result[0] or 0, result[1] or 0) + return (result[0] or 0, result[1] or 0, result[2] or 0) @require_context @@ -1499,7 +1500,7 @@ def auth_token_create(_context, token): ################### -@require_admin_context +@require_context def quota_get(context, project_id, resource, session=None): if not session: session = get_session() @@ -1513,7 +1514,7 @@ def quota_get(context, project_id, resource, session=None): return result -@require_admin_context +@require_context def quota_get_all_by_project(context, project_id): session = get_session() result = {'project_id': project_id} diff --git a/nova/flags.py b/nova/flags.py index 32cb6efa8..9eaac5596 100644 --- a/nova/flags.py +++ b/nova/flags.py @@ -110,7 +110,7 @@ class FlagValues(gflags.FlagValues): return name in self.__dict__['__dirty'] def ClearDirty(self): - self.__dict__['__is_dirty'] = [] + self.__dict__['__dirty'] = [] def WasAlreadyParsed(self): return self.__dict__['__was_already_parsed'] @@ -119,11 +119,12 @@ class FlagValues(gflags.FlagValues): if '__stored_argv' not in self.__dict__: return new_flags = FlagValues(self) - for k in self.__dict__['__dirty']: + for k in self.FlagDict().iterkeys(): new_flags[k] = gflags.FlagValues.__getitem__(self, k) + new_flags.Reset() new_flags(self.__dict__['__stored_argv']) - for k in self.__dict__['__dirty']: + for k in new_flags.FlagDict().iterkeys(): setattr(self, k, getattr(new_flags, k)) self.ClearDirty() diff --git a/nova/quota.py b/nova/quota.py index a93cd0766..58766e846 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -28,6 +28,8 @@ flags.DEFINE_integer('quota_instances', 10, 'number of instances allowed per project') flags.DEFINE_integer('quota_cores', 20, 'number of instance cores allowed per project') +flags.DEFINE_integer('quota_ram', 50 * 1024, + 'megabytes of instance ram allowed per project') flags.DEFINE_integer('quota_volumes', 10, 'number of volumes allowed per project') flags.DEFINE_integer('quota_gigabytes', 1000, @@ -44,14 +46,28 @@ flags.DEFINE_integer('quota_max_injected_file_path_bytes', 255, 'number of bytes allowed per injected file path') -def get_quota(context, project_id): - rval = {'instances': FLAGS.quota_instances, - 'cores': FLAGS.quota_cores, - 'volumes': FLAGS.quota_volumes, - 'gigabytes': FLAGS.quota_gigabytes, - 'floating_ips': FLAGS.quota_floating_ips, - 'metadata_items': FLAGS.quota_metadata_items} - +def _get_default_quotas(): + defaults = { + 'instances': FLAGS.quota_instances, + 'cores': FLAGS.quota_cores, + 'ram': FLAGS.quota_ram, + 'volumes': FLAGS.quota_volumes, + 'gigabytes': FLAGS.quota_gigabytes, + 'floating_ips': FLAGS.quota_floating_ips, + 'metadata_items': FLAGS.quota_metadata_items, + 'injected_files': FLAGS.quota_max_injected_files, + 'injected_file_content_bytes': + FLAGS.quota_max_injected_file_content_bytes, + } + # -1 in the quota flags means unlimited + for key in defaults.keys(): + if defaults[key] == -1: + defaults[key] = None + return defaults + + +def get_project_quotas(context, project_id): + rval = _get_default_quotas() quota = db.quota_get_all_by_project(context, project_id) for key in rval.keys(): if key in quota: @@ -65,71 +81,81 @@ def _get_request_allotment(requested, used, quota): return quota - used -def allowed_instances(context, num_instances, instance_type): - """Check quota and return min(num_instances, allowed_instances).""" +def allowed_instances(context, requested_instances, instance_type): + """Check quota and return min(requested_instances, allowed_instances).""" project_id = context.project_id context = context.elevated() - num_cores = num_instances * instance_type['vcpus'] - used_instances, used_cores = db.instance_data_get_for_project(context, - project_id) - quota = get_quota(context, project_id) - allowed_instances = _get_request_allotment(num_instances, used_instances, + requested_cores = requested_instances * instance_type['vcpus'] + requested_ram = requested_instances * instance_type['memory_mb'] + usage = db.instance_data_get_for_project(context, project_id) + used_instances, used_cores, used_ram = usage + quota = get_project_quotas(context, project_id) + allowed_instances = _get_request_allotment(requested_instances, + used_instances, quota['instances']) - allowed_cores = _get_request_allotment(num_cores, used_cores, + allowed_cores = _get_request_allotment(requested_cores, used_cores, quota['cores']) + allowed_ram = _get_request_allotment(requested_ram, used_ram, quota['ram']) allowed_instances = min(allowed_instances, - int(allowed_cores // instance_type['vcpus'])) - return min(num_instances, allowed_instances) + allowed_cores // instance_type['vcpus'], + allowed_ram // instance_type['memory_mb']) + return min(requested_instances, allowed_instances) -def allowed_volumes(context, num_volumes, size): - """Check quota and return min(num_volumes, allowed_volumes).""" +def allowed_volumes(context, requested_volumes, size): + """Check quota and return min(requested_volumes, allowed_volumes).""" project_id = context.project_id context = context.elevated() size = int(size) - num_gigabytes = num_volumes * size + requested_gigabytes = requested_volumes * size used_volumes, used_gigabytes = db.volume_data_get_for_project(context, project_id) - quota = get_quota(context, project_id) - allowed_volumes = _get_request_allotment(num_volumes, used_volumes, + quota = get_project_quotas(context, project_id) + allowed_volumes = _get_request_allotment(requested_volumes, used_volumes, quota['volumes']) - allowed_gigabytes = _get_request_allotment(num_gigabytes, used_gigabytes, + allowed_gigabytes = _get_request_allotment(requested_gigabytes, + used_gigabytes, quota['gigabytes']) allowed_volumes = min(allowed_volumes, int(allowed_gigabytes // size)) - return min(num_volumes, allowed_volumes) + return min(requested_volumes, allowed_volumes) -def allowed_floating_ips(context, num_floating_ips): - """Check quota and return min(num_floating_ips, allowed_floating_ips).""" +def allowed_floating_ips(context, requested_floating_ips): + """Check quota and return min(requested, allowed) floating ips.""" project_id = context.project_id context = context.elevated() used_floating_ips = db.floating_ip_count_by_project(context, project_id) - quota = get_quota(context, project_id) - allowed_floating_ips = _get_request_allotment(num_floating_ips, + quota = get_project_quotas(context, project_id) + allowed_floating_ips = _get_request_allotment(requested_floating_ips, used_floating_ips, quota['floating_ips']) - return min(num_floating_ips, allowed_floating_ips) + return min(requested_floating_ips, allowed_floating_ips) -def allowed_metadata_items(context, num_metadata_items): - """Check quota; return min(num_metadata_items,allowed_metadata_items).""" - project_id = context.project_id - context = context.elevated() - quota = get_quota(context, project_id) - allowed_metadata_items = _get_request_allotment(num_metadata_items, 0, - quota['metadata_items']) - return min(num_metadata_items, allowed_metadata_items) +def _calculate_simple_quota(context, resource, requested): + """Check quota for resource; return min(requested, allowed).""" + quota = get_project_quotas(context, context.project_id) + allowed = _get_request_allotment(requested, 0, quota[resource]) + return min(requested, allowed) + + +def allowed_metadata_items(context, requested_metadata_items): + """Return the number of metadata items allowed.""" + return _calculate_simple_quota(context, 'metadata_items', + requested_metadata_items) -def allowed_injected_files(context): +def allowed_injected_files(context, requested_injected_files): """Return the number of injected files allowed.""" - return FLAGS.quota_max_injected_files + return _calculate_simple_quota(context, 'injected_files', + requested_injected_files) -def allowed_injected_file_content_bytes(context): +def allowed_injected_file_content_bytes(context, requested_bytes): """Return the number of bytes allowed per injected file content.""" - return FLAGS.quota_max_injected_file_content_bytes + resource = 'injected_file_content_bytes' + return _calculate_simple_quota(context, resource, requested_bytes) def allowed_injected_file_path_bytes(context): diff --git a/nova/tests/api/openstack/test_limits.py b/nova/tests/api/openstack/test_limits.py index 45bd4d501..70f59eda6 100644 --- a/nova/tests/api/openstack/test_limits.py +++ b/nova/tests/api/openstack/test_limits.py @@ -27,6 +27,7 @@ import webob from xml.dom.minidom import parseString +import nova.context from nova.api.openstack import limits @@ -47,6 +48,13 @@ class BaseLimitTestSuite(unittest.TestCase): self.time = 0.0 self.stubs = stubout.StubOutForTesting() self.stubs.Set(limits.Limit, "_get_time", self._get_time) + self.absolute_limits = {} + + def stub_get_project_quotas(context, project_id): + return self.absolute_limits + + self.stubs.Set(nova.quota, "get_project_quotas", + stub_get_project_quotas) def tearDown(self): """Run after each test.""" @@ -75,6 +83,8 @@ class LimitsControllerV10Test(BaseLimitTestSuite): "action": "index", "controller": "", }) + context = nova.context.RequestContext('testuser', 'testproject') + request.environ["nova.context"] = context return request def _populate_limits(self, request): @@ -86,6 +96,18 @@ class LimitsControllerV10Test(BaseLimitTestSuite): request.environ["nova.limits"] = _limits return request + def _setup_absolute_limits(self): + self.absolute_limits = { + 'instances': 5, + 'cores': 8, + 'ram': 2 ** 13, + 'volumes': 21, + 'gigabytes': 34, + 'metadata_items': 55, + 'injected_files': 89, + 'injected_file_content_bytes': 144, + } + def test_empty_index_json(self): """Test getting empty limit details in JSON.""" request = self._get_index_request() @@ -103,6 +125,7 @@ class LimitsControllerV10Test(BaseLimitTestSuite): """Test getting limit details in JSON.""" request = self._get_index_request() request = self._populate_limits(request) + self._setup_absolute_limits() response = request.get_response(self.controller) expected = { "limits": { @@ -124,7 +147,15 @@ class LimitsControllerV10Test(BaseLimitTestSuite): "remaining": 5, "unit": "HOUR", }], - "absolute": {}, + "absolute": { + "maxTotalInstances": 5, + "maxTotalCores": 8, + "maxTotalRAMSize": 2 ** 13, + "maxServerMeta": 55, + "maxImageMeta": 55, + "maxPersonality": 89, + "maxPersonalitySize": 144, + }, }, } body = json.loads(response.body) @@ -188,6 +219,8 @@ class LimitsControllerV11Test(BaseLimitTestSuite): "action": "index", "controller": "", }) + context = nova.context.RequestContext('testuser', 'testproject') + request.environ["nova.context"] = context return request def _populate_limits(self, request): @@ -218,6 +251,11 @@ class LimitsControllerV11Test(BaseLimitTestSuite): """Test getting limit details in JSON.""" request = self._get_index_request() request = self._populate_limits(request) + self.absolute_limits = { + 'ram': 512, + 'instances': 5, + 'cores': 21, + } response = request.get_response(self.controller) expected = { "limits": { @@ -257,12 +295,110 @@ class LimitsControllerV11Test(BaseLimitTestSuite): }, ], + "absolute": { + "maxTotalRAMSize": 512, + "maxTotalInstances": 5, + "maxTotalCores": 21, + }, + }, + } + body = json.loads(response.body) + self.assertEqual(expected, body) + + def _populate_limits_diff_regex(self, request): + """Put limit info into a request.""" + _limits = [ + limits.Limit("GET", "*", ".*", 10, 60).display(), + limits.Limit("GET", "*", "*.*", 10, 60).display(), + ] + request.environ["nova.limits"] = _limits + return request + + def test_index_diff_regex(self): + """Test getting limit details in JSON.""" + request = self._get_index_request() + request = self._populate_limits_diff_regex(request) + response = request.get_response(self.controller) + expected = { + "limits": { + "rate": [ + { + "regex": ".*", + "uri": "*", + "limit": [ + { + "verb": "GET", + "next-available": 0, + "unit": "MINUTE", + "value": 10, + "remaining": 10, + }, + ], + }, + { + "regex": "*.*", + "uri": "*", + "limit": [ + { + "verb": "GET", + "next-available": 0, + "unit": "MINUTE", + "value": 10, + "remaining": 10, + }, + ], + }, + + ], "absolute": {}, }, } body = json.loads(response.body) self.assertEqual(expected, body) + def _test_index_absolute_limits_json(self, expected): + request = self._get_index_request() + response = request.get_response(self.controller) + body = json.loads(response.body) + self.assertEqual(expected, body['limits']['absolute']) + + def test_index_ignores_extra_absolute_limits_json(self): + self.absolute_limits = {'unknown_limit': 9001} + self._test_index_absolute_limits_json({}) + + def test_index_absolute_ram_json(self): + self.absolute_limits = {'ram': 1024} + self._test_index_absolute_limits_json({'maxTotalRAMSize': 1024}) + + def test_index_absolute_cores_json(self): + self.absolute_limits = {'cores': 17} + self._test_index_absolute_limits_json({'maxTotalCores': 17}) + + def test_index_absolute_instances_json(self): + self.absolute_limits = {'instances': 19} + self._test_index_absolute_limits_json({'maxTotalInstances': 19}) + + def test_index_absolute_metadata_json(self): + # NOTE: both server metadata and image metadata are overloaded + # into metadata_items + self.absolute_limits = {'metadata_items': 23} + expected = { + 'maxServerMeta': 23, + 'maxImageMeta': 23, + } + self._test_index_absolute_limits_json(expected) + + def test_index_absolute_injected_files(self): + self.absolute_limits = { + 'injected_files': 17, + 'injected_file_content_bytes': 86753, + } + expected = { + 'maxPersonality': 17, + 'maxPersonalitySize': 86753, + } + self._test_index_absolute_limits_json(expected) + class LimitMiddlewareTest(BaseLimitTestSuite): """ diff --git a/nova/tests/api/openstack/test_servers.py b/nova/tests/api/openstack/test_servers.py index e8182b6a9..fbde5c9ce 100644 --- a/nova/tests/api/openstack/test_servers.py +++ b/nova/tests/api/openstack/test_servers.py @@ -138,6 +138,16 @@ def find_host(self, context, instance_id): return "nova" +class MockSetAdminPassword(object): + def __init__(self): + self.instance_id = None + self.password = None + + def __call__(self, context, instance_id, password): + self.instance_id = instance_id + self.password = password + + class ServersTest(test.TestCase): def setUp(self): @@ -764,8 +774,7 @@ class ServersTest(test.TestCase): def server_update(context, id, params): filtered_dict = dict( - display_name='server_test', - admin_pass='bacon', + display_name='server_test' ) self.assertEqual(params, filtered_dict) return filtered_dict @@ -773,6 +782,8 @@ class ServersTest(test.TestCase): self.stubs.Set(nova.db.api, 'instance_update', server_update) self.stubs.Set(nova.compute.api.API, "_find_host", find_host) + mock_method = MockSetAdminPassword() + self.stubs.Set(nova.compute.api.API, 'set_admin_password', mock_method) req = webob.Request.blank('/v1.0/servers/1') req.method = 'PUT' @@ -780,6 +791,8 @@ class ServersTest(test.TestCase): req.body = self.body res = req.get_response(fakes.wsgi_app()) self.assertEqual(res.status_int, 204) + self.assertEqual(mock_method.instance_id, '1') + self.assertEqual(mock_method.password, 'bacon') def test_update_server_adminPass_ignored_v1_1(self): inst_dict = dict(name='server_test', adminPass='bacon') @@ -996,16 +1009,6 @@ class ServersTest(test.TestCase): self.assertEqual(res.status_int, 501) def test_server_change_password_v1_1(self): - - class MockSetAdminPassword(object): - def __init__(self): - self.instance_id = None - self.password = None - - def __call__(self, context, instance_id, password): - self.instance_id = instance_id - self.password = password - mock_method = MockSetAdminPassword() self.stubs.Set(nova.compute.api.API, 'set_admin_password', mock_method) body = {'changePassword': {'adminPass': '1234pass'}} diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py index 5d7ca98b5..ecefc464a 100644 --- a/nova/tests/fake_flags.py +++ b/nova/tests/fake_flags.py @@ -21,24 +21,24 @@ from nova import flags FLAGS = flags.FLAGS flags.DECLARE('volume_driver', 'nova.volume.manager') -FLAGS.volume_driver = 'nova.volume.driver.FakeISCSIDriver' -FLAGS.connection_type = 'fake' -FLAGS.fake_rabbit = True +FLAGS['volume_driver'].SetDefault('nova.volume.driver.FakeISCSIDriver') +FLAGS['connection_type'].SetDefault('fake') +FLAGS['fake_rabbit'].SetDefault(True) flags.DECLARE('auth_driver', 'nova.auth.manager') -FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver' +FLAGS['auth_driver'].SetDefault('nova.auth.dbdriver.DbDriver') flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('num_networks', 'nova.network.manager') flags.DECLARE('fake_network', 'nova.network.manager') -FLAGS.network_size = 8 -FLAGS.num_networks = 2 -FLAGS.fake_network = True -FLAGS.image_service = 'nova.image.local.LocalImageService' +FLAGS['network_size'].SetDefault(8) +FLAGS['num_networks'].SetDefault(2) +FLAGS['fake_network'].SetDefault(True) +FLAGS['image_service'].SetDefault('nova.image.local.LocalImageService') flags.DECLARE('num_shelves', 'nova.volume.driver') flags.DECLARE('blades_per_shelf', 'nova.volume.driver') flags.DECLARE('iscsi_num_targets', 'nova.volume.driver') -FLAGS.num_shelves = 2 -FLAGS.blades_per_shelf = 4 -FLAGS.iscsi_num_targets = 8 -FLAGS.verbose = True -FLAGS.sqlite_db = "tests.sqlite" -FLAGS.use_ipv6 = True +FLAGS['num_shelves'].SetDefault(2) +FLAGS['blades_per_shelf'].SetDefault(4) +FLAGS['iscsi_num_targets'].SetDefault(8) +FLAGS['verbose'].SetDefault(True) +FLAGS['sqlite_db'].SetDefault("tests.sqlite") +FLAGS['use_ipv6'].SetDefault(True) diff --git a/nova/tests/real_flags.py b/nova/tests/real_flags.py deleted file mode 100644 index 71da04992..000000000 --- a/nova/tests/real_flags.py +++ /dev/null @@ -1,26 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from nova import flags - -FLAGS = flags.FLAGS - -FLAGS.connection_type = 'libvirt' -FLAGS.fake_rabbit = False -FLAGS.fake_network = False -FLAGS.verbose = False diff --git a/nova/tests/test_api.py b/nova/tests/test_api.py index 97f401b87..7c0331eff 100644 --- a/nova/tests/test_api.py +++ b/nova/tests/test_api.py @@ -224,6 +224,29 @@ class ApiEc2TestCase(test.TestCase): self.manager.delete_project(project) self.manager.delete_user(user) + def test_create_duplicate_key_pair(self): + """Test that, after successfully generating a keypair, + requesting a second keypair with the same name fails sanely""" + self.expect_http() + self.mox.ReplayAll() + keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + user = self.manager.create_user('fake', 'fake', 'fake') + project = self.manager.create_project('fake', 'fake', 'fake') + # NOTE(vish): create depends on pool, so call helper directly + self.ec2.create_key_pair('test') + + try: + self.ec2.create_key_pair('test') + except EC2ResponseError, e: + if e.code == 'KeyPairExists': + pass + else: + self.fail("Unexpected EC2ResponseError: %s " + "(expected KeyPairExists)" % e.code) + else: + self.fail('Exception not raised.') + def test_get_all_security_groups(self): """Test that we can retrieve security groups""" self.expect_http() diff --git a/nova/tests/test_flags.py b/nova/tests/test_flags.py index 707300fcf..05319d91f 100644 --- a/nova/tests/test_flags.py +++ b/nova/tests/test_flags.py @@ -91,6 +91,20 @@ class FlagsTestCase(test.TestCase): self.assert_('runtime_answer' in self.global_FLAGS) self.assertEqual(self.global_FLAGS.runtime_answer, 60) + def test_long_vs_short_flags(self): + flags.DEFINE_string('duplicate_answer_long', 'val', 'desc', + flag_values=self.global_FLAGS) + argv = ['flags_test', '--duplicate_answer=60', 'extra_arg'] + args = self.global_FLAGS(argv) + + self.assert_('duplicate_answer' not in self.global_FLAGS) + self.assert_(self.global_FLAGS.duplicate_answer_long, 60) + + flags.DEFINE_integer('duplicate_answer', 60, 'desc', + flag_values=self.global_FLAGS) + self.assertEqual(self.global_FLAGS.duplicate_answer, 60) + self.assertEqual(self.global_FLAGS.duplicate_answer_long, 'val') + def test_flag_leak_left(self): self.assertEqual(FLAGS.flags_unittest, 'foo') FLAGS.flags_unittest = 'bar' diff --git a/nova/tests/test_virt.py b/nova/tests/test_libvirt.py index 1bec9caca..4efdd6ae9 100644 --- a/nova/tests/test_virt.py +++ b/nova/tests/test_libvirt.py @@ -32,7 +32,8 @@ from nova import utils from nova.api.ec2 import cloud from nova.auth import manager from nova.compute import power_state -from nova.virt import libvirt_conn +from nova.virt.libvirt import connection +from nova.virt.libvirt import firewall libvirt = None FLAGS = flags.FLAGS @@ -83,7 +84,7 @@ class CacheConcurrencyTestCase(test.TestCase): def test_same_fname_concurrency(self): """Ensures that the same fname cache runs at a sequentially""" - conn = libvirt_conn.LibvirtConnection + conn = connection.LibvirtConnection wait1 = eventlet.event.Event() done1 = eventlet.event.Event() eventlet.spawn(conn._cache_image, _concurrency, @@ -104,7 +105,7 @@ class CacheConcurrencyTestCase(test.TestCase): def test_different_fname_concurrency(self): """Ensures that two different fname caches are concurrent""" - conn = libvirt_conn.LibvirtConnection + conn = connection.LibvirtConnection wait1 = eventlet.event.Event() done1 = eventlet.event.Event() eventlet.spawn(conn._cache_image, _concurrency, @@ -125,7 +126,7 @@ class CacheConcurrencyTestCase(test.TestCase): class LibvirtConnTestCase(test.TestCase): def setUp(self): super(LibvirtConnTestCase, self).setUp() - libvirt_conn._late_load_cheetah() + connection._late_load_cheetah() self.flags(fake_call=True) self.manager = manager.AuthManager() @@ -171,8 +172,8 @@ class LibvirtConnTestCase(test.TestCase): return False global libvirt libvirt = __import__('libvirt') - libvirt_conn.libvirt = __import__('libvirt') - libvirt_conn.libxml2 = __import__('libxml2') + connection.libvirt = __import__('libvirt') + connection.libxml2 = __import__('libxml2') return True def create_fake_libvirt_mock(self, **kwargs): @@ -182,7 +183,7 @@ class LibvirtConnTestCase(test.TestCase): class FakeLibvirtConnection(object): pass - # A fake libvirt_conn.IptablesFirewallDriver + # A fake connection.IptablesFirewallDriver class FakeIptablesFirewallDriver(object): def __init__(self, **kwargs): @@ -198,11 +199,11 @@ class LibvirtConnTestCase(test.TestCase): for key, val in kwargs.items(): fake.__setattr__(key, val) - # Inevitable mocks for libvirt_conn.LibvirtConnection - self.mox.StubOutWithMock(libvirt_conn.utils, 'import_class') - libvirt_conn.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip) - self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, '_conn') - libvirt_conn.LibvirtConnection._conn = fake + # Inevitable mocks for connection.LibvirtConnection + self.mox.StubOutWithMock(connection.utils, 'import_class') + connection.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip) + self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn') + connection.LibvirtConnection._conn = fake def create_service(self, **kwargs): service_ref = {'host': kwargs.get('host', 'dummy'), @@ -214,7 +215,7 @@ class LibvirtConnTestCase(test.TestCase): return db.service_create(context.get_admin_context(), service_ref) def test_preparing_xml_info(self): - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) instance_ref = db.instance_create(self.context, self.test_instance) result = conn._prepare_xml_info(instance_ref, False) @@ -229,7 +230,7 @@ class LibvirtConnTestCase(test.TestCase): self.assertTrue(len(result['nics']) == 2) def test_get_nic_for_xml_v4(self): - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) network, mapping = _create_network_info()[0] self.flags(use_ipv6=False) params = conn._get_nic_for_xml(network, mapping)['extra_params'] @@ -237,7 +238,7 @@ class LibvirtConnTestCase(test.TestCase): self.assertTrue(params.find('PROJMASKV6') == -1) def test_get_nic_for_xml_v6(self): - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) network, mapping = _create_network_info()[0] self.flags(use_ipv6=True) params = conn._get_nic_for_xml(network, mapping)['extra_params'] @@ -282,7 +283,7 @@ class LibvirtConnTestCase(test.TestCase): def test_multi_nic(self): instance_data = dict(self.test_instance) network_info = _create_network_info(2) - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) instance_ref = db.instance_create(self.context, instance_data) xml = conn.to_xml(instance_ref, False, network_info) tree = xml_to_tree(xml) @@ -313,7 +314,7 @@ class LibvirtConnTestCase(test.TestCase): 'instance_id': instance_ref['id']}) self.flags(libvirt_type='lxc') - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) uri = conn.get_uri() self.assertEquals(uri, 'lxc:///') @@ -419,7 +420,7 @@ class LibvirtConnTestCase(test.TestCase): for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems(): FLAGS.libvirt_type = libvirt_type - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) uri = conn.get_uri() self.assertEquals(uri, expected_uri) @@ -446,7 +447,7 @@ class LibvirtConnTestCase(test.TestCase): FLAGS.libvirt_uri = testuri for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems(): FLAGS.libvirt_type = libvirt_type - conn = libvirt_conn.LibvirtConnection(True) + conn = connection.LibvirtConnection(True) uri = conn.get_uri() self.assertEquals(uri, testuri) db.instance_destroy(user_context, instance_ref['id']) @@ -470,13 +471,13 @@ class LibvirtConnTestCase(test.TestCase): self.create_fake_libvirt_mock(getVersion=getVersion, getType=getType, listDomainsID=listDomainsID) - self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, + self.mox.StubOutWithMock(connection.LibvirtConnection, 'get_cpu_info') - libvirt_conn.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo') + connection.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo') # Start test self.mox.ReplayAll() - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) conn.update_available_resource(self.context, 'dummy') service_ref = db.service_get(self.context, service_ref['id']) compute_node = service_ref['compute_node'][0] @@ -510,7 +511,7 @@ class LibvirtConnTestCase(test.TestCase): self.create_fake_libvirt_mock() self.mox.ReplayAll() - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) self.assertRaises(exception.ComputeServiceUnavailable, conn.update_available_resource, self.context, 'dummy') @@ -545,7 +546,7 @@ class LibvirtConnTestCase(test.TestCase): # Start test self.mox.ReplayAll() try: - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) conn.firewall_driver.setattr('setup_basic_filtering', fake_none) conn.firewall_driver.setattr('prepare_instance_filter', fake_none) conn.firewall_driver.setattr('instance_filter_exists', fake_none) @@ -594,7 +595,7 @@ class LibvirtConnTestCase(test.TestCase): # Start test self.mox.ReplayAll() - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) self.assertRaises(libvirt.libvirtError, conn._live_migration, self.context, instance_ref, 'dest', '', @@ -623,7 +624,7 @@ class LibvirtConnTestCase(test.TestCase): # Start test self.mox.ReplayAll() - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) conn.firewall_driver.setattr('setup_basic_filtering', fake_none) conn.firewall_driver.setattr('prepare_instance_filter', fake_none) @@ -647,7 +648,7 @@ class LibvirtConnTestCase(test.TestCase): self.assertTrue(count) def test_get_host_ip_addr(self): - conn = libvirt_conn.LibvirtConnection(False) + conn = connection.LibvirtConnection(False) ip = conn.get_host_ip_addr() self.assertEquals(ip, FLAGS.my_ip) @@ -671,7 +672,7 @@ class IptablesFirewallTestCase(test.TestCase): class FakeLibvirtConnection(object): pass self.fake_libvirt_connection = FakeLibvirtConnection() - self.fw = libvirt_conn.IptablesFirewallDriver( + self.fw = firewall.IptablesFirewallDriver( get_connection=lambda: self.fake_libvirt_connection) def tearDown(self): @@ -895,7 +896,7 @@ class NWFilterTestCase(test.TestCase): self.fake_libvirt_connection = Mock() - self.fw = libvirt_conn.NWFilterFirewall( + self.fw = firewall.NWFilterFirewall( lambda: self.fake_libvirt_connection) def tearDown(self): diff --git a/nova/tests/test_quota.py b/nova/tests/test_quota.py index 7ace2ad7d..916fca55e 100644 --- a/nova/tests/test_quota.py +++ b/nova/tests/test_quota.py @@ -104,6 +104,10 @@ class QuotaTestCase(test.TestCase): num_instances = quota.allowed_instances(self.context, 100, self._get_instance_type('m1.small')) self.assertEqual(num_instances, 10) + db.quota_create(self.context, self.project.id, 'ram', 3 * 2048) + num_instances = quota.allowed_instances(self.context, 100, + self._get_instance_type('m1.small')) + self.assertEqual(num_instances, 3) # metadata_items too_many_items = FLAGS.quota_metadata_items + 1000 @@ -120,7 +124,8 @@ class QuotaTestCase(test.TestCase): def test_unlimited_instances(self): FLAGS.quota_instances = 2 - FLAGS.quota_cores = 1000 + FLAGS.quota_ram = -1 + FLAGS.quota_cores = -1 instance_type = self._get_instance_type('m1.small') num_instances = quota.allowed_instances(self.context, 100, instance_type) @@ -133,8 +138,25 @@ class QuotaTestCase(test.TestCase): instance_type) self.assertEqual(num_instances, 101) + def test_unlimited_ram(self): + FLAGS.quota_instances = -1 + FLAGS.quota_ram = 2 * 2048 + FLAGS.quota_cores = -1 + instance_type = self._get_instance_type('m1.small') + num_instances = quota.allowed_instances(self.context, 100, + instance_type) + self.assertEqual(num_instances, 2) + db.quota_create(self.context, self.project.id, 'ram', None) + num_instances = quota.allowed_instances(self.context, 100, + instance_type) + self.assertEqual(num_instances, 100) + num_instances = quota.allowed_instances(self.context, 101, + instance_type) + self.assertEqual(num_instances, 101) + def test_unlimited_cores(self): - FLAGS.quota_instances = 1000 + FLAGS.quota_instances = -1 + FLAGS.quota_ram = -1 FLAGS.quota_cores = 2 instance_type = self._get_instance_type('m1.small') num_instances = quota.allowed_instances(self.context, 100, @@ -150,7 +172,7 @@ class QuotaTestCase(test.TestCase): def test_unlimited_volumes(self): FLAGS.quota_volumes = 10 - FLAGS.quota_gigabytes = 1000 + FLAGS.quota_gigabytes = -1 volumes = quota.allowed_volumes(self.context, 100, 1) self.assertEqual(volumes, 10) db.quota_create(self.context, self.project.id, 'volumes', None) @@ -160,7 +182,7 @@ class QuotaTestCase(test.TestCase): self.assertEqual(volumes, 101) def test_unlimited_gigabytes(self): - FLAGS.quota_volumes = 1000 + FLAGS.quota_volumes = -1 FLAGS.quota_gigabytes = 10 volumes = quota.allowed_volumes(self.context, 100, 1) self.assertEqual(volumes, 10) @@ -274,10 +296,47 @@ class QuotaTestCase(test.TestCase): image_id='fake', metadata=metadata) - def test_allowed_injected_files(self): - self.assertEqual( - quota.allowed_injected_files(self.context), - FLAGS.quota_max_injected_files) + def test_default_allowed_injected_files(self): + FLAGS.quota_max_injected_files = 55 + self.assertEqual(quota.allowed_injected_files(self.context, 100), 55) + + def test_overridden_allowed_injected_files(self): + FLAGS.quota_max_injected_files = 5 + db.quota_create(self.context, self.project.id, 'injected_files', 77) + self.assertEqual(quota.allowed_injected_files(self.context, 100), 77) + + def test_unlimited_default_allowed_injected_files(self): + FLAGS.quota_max_injected_files = -1 + self.assertEqual(quota.allowed_injected_files(self.context, 100), 100) + + def test_unlimited_db_allowed_injected_files(self): + FLAGS.quota_max_injected_files = 5 + db.quota_create(self.context, self.project.id, 'injected_files', None) + self.assertEqual(quota.allowed_injected_files(self.context, 100), 100) + + def test_default_allowed_injected_file_content_bytes(self): + FLAGS.quota_max_injected_file_content_bytes = 12345 + limit = quota.allowed_injected_file_content_bytes(self.context, 23456) + self.assertEqual(limit, 12345) + + def test_overridden_allowed_injected_file_content_bytes(self): + FLAGS.quota_max_injected_file_content_bytes = 12345 + db.quota_create(self.context, self.project.id, + 'injected_file_content_bytes', 5678) + limit = quota.allowed_injected_file_content_bytes(self.context, 23456) + self.assertEqual(limit, 5678) + + def test_unlimited_default_allowed_injected_file_content_bytes(self): + FLAGS.quota_max_injected_file_content_bytes = -1 + limit = quota.allowed_injected_file_content_bytes(self.context, 23456) + self.assertEqual(limit, 23456) + + def test_unlimited_db_allowed_injected_file_content_bytes(self): + FLAGS.quota_max_injected_file_content_bytes = 12345 + db.quota_create(self.context, self.project.id, + 'injected_file_content_bytes', None) + limit = quota.allowed_injected_file_content_bytes(self.context, 23456) + self.assertEqual(limit, 23456) def _create_with_injected_files(self, files): api = compute.API(image_service=self.StubImageService()) @@ -304,11 +363,6 @@ class QuotaTestCase(test.TestCase): self.assertRaises(quota.QuotaError, self._create_with_injected_files, files) - def test_allowed_injected_file_content_bytes(self): - self.assertEqual( - quota.allowed_injected_file_content_bytes(self.context), - FLAGS.quota_max_injected_file_content_bytes) - def test_max_injected_file_content_bytes(self): max = FLAGS.quota_max_injected_file_content_bytes content = ''.join(['a' for i in xrange(max)]) diff --git a/nova/virt/connection.py b/nova/virt/connection.py index 99a8849f1..aeec17c98 100644 --- a/nova/virt/connection.py +++ b/nova/virt/connection.py @@ -27,9 +27,9 @@ from nova import utils from nova.virt import driver from nova.virt import fake from nova.virt import hyperv -from nova.virt import libvirt_conn from nova.virt import vmwareapi_conn from nova.virt import xenapi_conn +from nova.virt.libvirt import connection as libvirt_conn LOG = logging.getLogger("nova.virt.connection") diff --git a/nova/virt/libvirt/__init__.py b/nova/virt/libvirt/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/nova/virt/libvirt/__init__.py diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt/connection.py index fa918b0a3..94a703954 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt/connection.py @@ -57,7 +57,6 @@ from nova import context from nova import db from nova import exception from nova import flags -from nova import ipv6 from nova import log as logging from nova import utils from nova import vnc @@ -67,20 +66,23 @@ from nova.compute import power_state from nova.virt import disk from nova.virt import driver from nova.virt import images +from nova.virt.libvirt import netutils + libvirt = None libxml2 = None Template = None + LOG = logging.getLogger('nova.virt.libvirt_conn') + FLAGS = flags.FLAGS flags.DECLARE('live_migration_retry_count', 'nova.compute.manager') # TODO(vish): These flags should probably go into a shared location flags.DEFINE_string('rescue_image_id', 'ami-rescue', 'Rescue ami image') flags.DEFINE_string('rescue_kernel_id', 'aki-rescue', 'Rescue aki image') flags.DEFINE_string('rescue_ramdisk_id', 'ari-rescue', 'Rescue ari image') - flags.DEFINE_string('libvirt_xml_template', utils.abspath('virt/libvirt.xml.template'), 'Libvirt XML Template') @@ -102,7 +104,7 @@ flags.DEFINE_string('ajaxterm_portrange', '10000-12000', 'Range of ports that ajaxterm should randomly try to bind') flags.DEFINE_string('firewall_driver', - 'nova.virt.libvirt_conn.IptablesFirewallDriver', + 'nova.virt.libvirt.firewall.IptablesFirewallDriver', 'Firewall driver (defaults to iptables)') flags.DEFINE_string('cpuinfo_xml_template', utils.abspath('virt/cpuinfo.xml.template'), @@ -144,70 +146,6 @@ def _late_load_cheetah(): Template = t.Template -def _get_net_and_mask(cidr): - net = IPy.IP(cidr) - return str(net.net()), str(net.netmask()) - - -def _get_net_and_prefixlen(cidr): - net = IPy.IP(cidr) - return str(net.net()), str(net.prefixlen()) - - -def _get_ip_version(cidr): - net = IPy.IP(cidr) - return int(net.version()) - - -def _get_network_info(instance): - # TODO(adiantum) If we will keep this function - # we should cache network_info - admin_context = context.get_admin_context() - - ip_addresses = db.fixed_ip_get_all_by_instance(admin_context, - instance['id']) - networks = db.network_get_all_by_instance(admin_context, - instance['id']) - flavor = db.instance_type_get_by_id(admin_context, - instance['instance_type_id']) - network_info = [] - - for network in networks: - network_ips = [ip for ip in ip_addresses - if ip['network_id'] == network['id']] - - def ip_dict(ip): - return { - 'ip': ip['address'], - 'netmask': network['netmask'], - 'enabled': '1'} - - def ip6_dict(): - prefix = network['cidr_v6'] - mac = instance['mac_address'] - project_id = instance['project_id'] - return { - 'ip': ipv6.to_global(prefix, mac, project_id), - 'netmask': network['netmask_v6'], - 'enabled': '1'} - - mapping = { - 'label': network['label'], - 'gateway': network['gateway'], - 'broadcast': network['broadcast'], - 'mac': instance['mac_address'], - 'rxtx_cap': flavor['rxtx_cap'], - 'dns': [network['dns']], - 'ips': [ip_dict(ip) for ip in network_ips]} - - if FLAGS.use_ipv6: - mapping['ip6s'] = [ip6_dict()] - mapping['gateway6'] = network['gateway_v6'] - - network_info.append((network, mapping)) - return network_info - - class LibvirtConnection(driver.ComputeDriver): def __init__(self, read_only): @@ -807,7 +745,7 @@ class LibvirtConnection(driver.ComputeDriver): def _create_image(self, inst, libvirt_xml, suffix='', disk_images=None, network_info=None): if not network_info: - network_info = _get_network_info(inst) + network_info = netutils.get_network_info(inst) if not suffix: suffix = '' @@ -966,10 +904,10 @@ class LibvirtConnection(driver.ComputeDriver): if FLAGS.allow_project_net_traffic: template = "<parameter name=\"%s\"value=\"%s\" />\n" - net, mask = _get_net_and_mask(network['cidr']) + net, mask = netutils.get_net_and_mask(network['cidr']) values = [("PROJNET", net), ("PROJMASK", mask)] if FLAGS.use_ipv6: - net_v6, prefixlen_v6 = _get_net_and_prefixlen( + net_v6, prefixlen_v6 = netutils.get_net_and_prefixlen( network['cidr_v6']) values.extend([("PROJNETV6", net_v6), ("PROJMASKV6", prefixlen_v6)]) @@ -996,7 +934,7 @@ class LibvirtConnection(driver.ComputeDriver): # TODO(adiantum) remove network_info creation code # when multinics will be completed if not network_info: - network_info = _get_network_info(instance) + network_info = netutils.get_network_info(instance) nics = [] for (network, mapping) in network_info: @@ -1591,606 +1529,3 @@ class LibvirtConnection(driver.ComputeDriver): def get_host_stats(self, refresh=False): """See xenapi_conn.py implementation.""" pass - - -class FirewallDriver(object): - def prepare_instance_filter(self, instance, network_info=None): - """Prepare filters for the instance. - - At this point, the instance isn't running yet.""" - raise NotImplementedError() - - def unfilter_instance(self, instance): - """Stop filtering instance""" - raise NotImplementedError() - - def apply_instance_filter(self, instance): - """Apply instance filter. - - Once this method returns, the instance should be firewalled - appropriately. This method should as far as possible be a - no-op. It's vastly preferred to get everything set up in - prepare_instance_filter. - """ - raise NotImplementedError() - - def refresh_security_group_rules(self, - security_group_id, - network_info=None): - """Refresh security group rules from data store - - Gets called when a rule has been added to or removed from - the security group.""" - raise NotImplementedError() - - def refresh_security_group_members(self, security_group_id): - """Refresh security group members from data store - - Gets called when an instance gets added to or removed from - the security group.""" - raise NotImplementedError() - - def setup_basic_filtering(self, instance, network_info=None): - """Create rules to block spoofing and allow dhcp. - - This gets called when spawning an instance, before - :method:`prepare_instance_filter`. - - """ - raise NotImplementedError() - - def instance_filter_exists(self, instance): - """Check nova-instance-instance-xxx exists""" - raise NotImplementedError() - - -class NWFilterFirewall(FirewallDriver): - """ - This class implements a network filtering mechanism versatile - enough for EC2 style Security Group filtering by leveraging - libvirt's nwfilter. - - First, all instances get a filter ("nova-base-filter") applied. - This filter provides some basic security such as protection against - MAC spoofing, IP spoofing, and ARP spoofing. - - This filter drops all incoming ipv4 and ipv6 connections. - Outgoing connections are never blocked. - - Second, every security group maps to a nwfilter filter(*). - NWFilters can be updated at runtime and changes are applied - immediately, so changes to security groups can be applied at - runtime (as mandated by the spec). - - Security group rules are named "nova-secgroup-<id>" where <id> - is the internal id of the security group. They're applied only on - hosts that have instances in the security group in question. - - Updates to security groups are done by updating the data model - (in response to API calls) followed by a request sent to all - the nodes with instances in the security group to refresh the - security group. - - Each instance has its own NWFilter, which references the above - mentioned security group NWFilters. This was done because - interfaces can only reference one filter while filters can - reference multiple other filters. This has the added benefit of - actually being able to add and remove security groups from an - instance at run time. This functionality is not exposed anywhere, - though. - - Outstanding questions: - - The name is unique, so would there be any good reason to sync - the uuid across the nodes (by assigning it from the datamodel)? - - - (*) This sentence brought to you by the redundancy department of - redundancy. - - """ - - def __init__(self, get_connection, **kwargs): - self._libvirt_get_connection = get_connection - self.static_filters_configured = False - self.handle_security_groups = False - - def apply_instance_filter(self, instance): - """No-op. Everything is done in prepare_instance_filter""" - pass - - def _get_connection(self): - return self._libvirt_get_connection() - _conn = property(_get_connection) - - def nova_dhcp_filter(self): - """The standard allow-dhcp-server filter is an <ip> one, so it uses - ebtables to allow traffic through. Without a corresponding rule in - iptables, it'll get blocked anyway.""" - - return '''<filter name='nova-allow-dhcp-server' chain='ipv4'> - <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid> - <rule action='accept' direction='out' - priority='100'> - <udp srcipaddr='0.0.0.0' - dstipaddr='255.255.255.255' - srcportstart='68' - dstportstart='67'/> - </rule> - <rule action='accept' direction='in' - priority='100'> - <udp srcipaddr='$DHCPSERVER' - srcportstart='67' - dstportstart='68'/> - </rule> - </filter>''' - - def nova_ra_filter(self): - return '''<filter name='nova-allow-ra-server' chain='root'> - <uuid>d707fa71-4fb5-4b27-9ab7-ba5ca19c8804</uuid> - <rule action='accept' direction='inout' - priority='100'> - <icmpv6 srcipaddr='$RASERVER'/> - </rule> - </filter>''' - - def setup_basic_filtering(self, instance, network_info=None): - """Set up basic filtering (MAC, IP, and ARP spoofing protection)""" - logging.info('called setup_basic_filtering in nwfilter') - - if not network_info: - network_info = _get_network_info(instance) - - if self.handle_security_groups: - # No point in setting up a filter set that we'll be overriding - # anyway. - return - - logging.info('ensuring static filters') - self._ensure_static_filters() - - if instance['image_id'] == str(FLAGS.vpn_image_id): - base_filter = 'nova-vpn' - else: - base_filter = 'nova-base' - - for (network, mapping) in network_info: - nic_id = mapping['mac'].replace(':', '') - instance_filter_name = self._instance_filter_name(instance, nic_id) - self._define_filter(self._filter_container(instance_filter_name, - [base_filter])) - - def _ensure_static_filters(self): - if self.static_filters_configured: - return - - self._define_filter(self._filter_container('nova-base', - ['no-mac-spoofing', - 'no-ip-spoofing', - 'no-arp-spoofing', - 'allow-dhcp-server'])) - self._define_filter(self._filter_container('nova-vpn', - ['allow-dhcp-server'])) - self._define_filter(self.nova_base_ipv4_filter) - self._define_filter(self.nova_base_ipv6_filter) - self._define_filter(self.nova_dhcp_filter) - self._define_filter(self.nova_ra_filter) - if FLAGS.allow_project_net_traffic: - self._define_filter(self.nova_project_filter) - if FLAGS.use_ipv6: - self._define_filter(self.nova_project_filter_v6) - - self.static_filters_configured = True - - def _filter_container(self, name, filters): - xml = '''<filter name='%s' chain='root'>%s</filter>''' % ( - name, - ''.join(["<filterref filter='%s'/>" % (f,) for f in filters])) - return xml - - def nova_base_ipv4_filter(self): - retval = "<filter name='nova-base-ipv4' chain='ipv4'>" - for protocol in ['tcp', 'udp', 'icmp']: - for direction, action, priority in [('out', 'accept', 399), - ('in', 'drop', 400)]: - retval += """<rule action='%s' direction='%s' priority='%d'> - <%s /> - </rule>""" % (action, direction, - priority, protocol) - retval += '</filter>' - return retval - - def nova_base_ipv6_filter(self): - retval = "<filter name='nova-base-ipv6' chain='ipv6'>" - for protocol in ['tcp-ipv6', 'udp-ipv6', 'icmpv6']: - for direction, action, priority in [('out', 'accept', 399), - ('in', 'drop', 400)]: - retval += """<rule action='%s' direction='%s' priority='%d'> - <%s /> - </rule>""" % (action, direction, - priority, protocol) - retval += '</filter>' - return retval - - def nova_project_filter(self): - retval = "<filter name='nova-project' chain='ipv4'>" - for protocol in ['tcp', 'udp', 'icmp']: - retval += """<rule action='accept' direction='in' priority='200'> - <%s srcipaddr='$PROJNET' srcipmask='$PROJMASK' /> - </rule>""" % protocol - retval += '</filter>' - return retval - - def nova_project_filter_v6(self): - retval = "<filter name='nova-project-v6' chain='ipv6'>" - for protocol in ['tcp-ipv6', 'udp-ipv6', 'icmpv6']: - retval += """<rule action='accept' direction='inout' - priority='200'> - <%s srcipaddr='$PROJNETV6' - srcipmask='$PROJMASKV6' /> - </rule>""" % (protocol) - retval += '</filter>' - return retval - - def _define_filter(self, xml): - if callable(xml): - xml = xml() - # execute in a native thread and block current greenthread until done - tpool.execute(self._conn.nwfilterDefineXML, xml) - - def unfilter_instance(self, instance): - # Nothing to do - pass - - def prepare_instance_filter(self, instance, network_info=None): - """ - Creates an NWFilter for the given instance. In the process, - it makes sure the filters for the security groups as well as - the base filter are all in place. - """ - if not network_info: - network_info = _get_network_info(instance) - - ctxt = context.get_admin_context() - - instance_secgroup_filter_name = \ - '%s-secgroup' % (self._instance_filter_name(instance)) - #% (instance_filter_name,) - - instance_secgroup_filter_children = ['nova-base-ipv4', - 'nova-base-ipv6', - 'nova-allow-dhcp-server'] - - if FLAGS.use_ipv6: - networks = [network for (network, _m) in network_info if - network['gateway_v6']] - - if networks: - instance_secgroup_filter_children.\ - append('nova-allow-ra-server') - - for security_group in \ - db.security_group_get_by_instance(ctxt, instance['id']): - - self.refresh_security_group_rules(security_group['id']) - - instance_secgroup_filter_children.append('nova-secgroup-%s' % - security_group['id']) - - self._define_filter( - self._filter_container(instance_secgroup_filter_name, - instance_secgroup_filter_children)) - - network_filters = self.\ - _create_network_filters(instance, network_info, - instance_secgroup_filter_name) - - for (name, children) in network_filters: - self._define_filters(name, children) - - def _create_network_filters(self, instance, network_info, - instance_secgroup_filter_name): - if instance['image_id'] == str(FLAGS.vpn_image_id): - base_filter = 'nova-vpn' - else: - base_filter = 'nova-base' - - result = [] - for (_n, mapping) in network_info: - nic_id = mapping['mac'].replace(':', '') - instance_filter_name = self._instance_filter_name(instance, nic_id) - instance_filter_children = [base_filter, - instance_secgroup_filter_name] - - if FLAGS.allow_project_net_traffic: - instance_filter_children.append('nova-project') - if FLAGS.use_ipv6: - instance_filter_children.append('nova-project-v6') - - result.append((instance_filter_name, instance_filter_children)) - - return result - - def _define_filters(self, filter_name, filter_children): - self._define_filter(self._filter_container(filter_name, - filter_children)) - - def refresh_security_group_rules(self, - security_group_id, - network_info=None): - return self._define_filter( - self.security_group_to_nwfilter_xml(security_group_id)) - - def security_group_to_nwfilter_xml(self, security_group_id): - security_group = db.security_group_get(context.get_admin_context(), - security_group_id) - rule_xml = "" - v6protocol = {'tcp': 'tcp-ipv6', 'udp': 'udp-ipv6', 'icmp': 'icmpv6'} - for rule in security_group.rules: - rule_xml += "<rule action='accept' direction='in' priority='300'>" - if rule.cidr: - version = _get_ip_version(rule.cidr) - if(FLAGS.use_ipv6 and version == 6): - net, prefixlen = _get_net_and_prefixlen(rule.cidr) - rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \ - (v6protocol[rule.protocol], net, prefixlen) - else: - net, mask = _get_net_and_mask(rule.cidr) - rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \ - (rule.protocol, net, mask) - if rule.protocol in ['tcp', 'udp']: - rule_xml += "dstportstart='%s' dstportend='%s' " % \ - (rule.from_port, rule.to_port) - elif rule.protocol == 'icmp': - LOG.info('rule.protocol: %r, rule.from_port: %r, ' - 'rule.to_port: %r', rule.protocol, - rule.from_port, rule.to_port) - if rule.from_port != -1: - rule_xml += "type='%s' " % rule.from_port - if rule.to_port != -1: - rule_xml += "code='%s' " % rule.to_port - - rule_xml += '/>\n' - rule_xml += "</rule>\n" - xml = "<filter name='nova-secgroup-%s' " % security_group_id - if(FLAGS.use_ipv6): - xml += "chain='root'>%s</filter>" % rule_xml - else: - xml += "chain='ipv4'>%s</filter>" % rule_xml - return xml - - def _instance_filter_name(self, instance, nic_id=None): - if not nic_id: - return 'nova-instance-%s' % (instance['name']) - return 'nova-instance-%s-%s' % (instance['name'], nic_id) - - def instance_filter_exists(self, instance): - """Check nova-instance-instance-xxx exists""" - network_info = _get_network_info(instance) - for (network, mapping) in network_info: - nic_id = mapping['mac'].replace(':', '') - instance_filter_name = self._instance_filter_name(instance, nic_id) - try: - self._conn.nwfilterLookupByName(instance_filter_name) - except libvirt.libvirtError: - name = instance.name - LOG.debug(_('The nwfilter(%(instance_filter_name)s) for' - '%(name)s is not found.') % locals()) - return False - return True - - -class IptablesFirewallDriver(FirewallDriver): - def __init__(self, execute=None, **kwargs): - from nova.network import linux_net - self.iptables = linux_net.iptables_manager - self.instances = {} - self.nwfilter = NWFilterFirewall(kwargs['get_connection']) - - self.iptables.ipv4['filter'].add_chain('sg-fallback') - self.iptables.ipv4['filter'].add_rule('sg-fallback', '-j DROP') - self.iptables.ipv6['filter'].add_chain('sg-fallback') - self.iptables.ipv6['filter'].add_rule('sg-fallback', '-j DROP') - - def setup_basic_filtering(self, instance, network_info=None): - """Use NWFilter from libvirt for this.""" - if not network_info: - network_info = _get_network_info(instance) - return self.nwfilter.setup_basic_filtering(instance, network_info) - - def apply_instance_filter(self, instance): - """No-op. Everything is done in prepare_instance_filter""" - pass - - def unfilter_instance(self, instance): - if self.instances.pop(instance['id'], None): - self.remove_filters_for_instance(instance) - self.iptables.apply() - else: - LOG.info(_('Attempted to unfilter instance %s which is not ' - 'filtered'), instance['id']) - - def prepare_instance_filter(self, instance, network_info=None): - if not network_info: - network_info = _get_network_info(instance) - self.instances[instance['id']] = instance - self.add_filters_for_instance(instance, network_info) - self.iptables.apply() - - def _create_filter(self, ips, chain_name): - return ['-d %s -j $%s' % (ip, chain_name) for ip in ips] - - def _filters_for_instance(self, chain_name, network_info): - ips_v4 = [ip['ip'] for (_n, mapping) in network_info - for ip in mapping['ips']] - ipv4_rules = self._create_filter(ips_v4, chain_name) - - ipv6_rules = [] - if FLAGS.use_ipv6: - ips_v6 = [ip['ip'] for (_n, mapping) in network_info - for ip in mapping['ip6s']] - ipv6_rules = self._create_filter(ips_v6, chain_name) - - return ipv4_rules, ipv6_rules - - def _add_filters(self, chain_name, ipv4_rules, ipv6_rules): - for rule in ipv4_rules: - self.iptables.ipv4['filter'].add_rule(chain_name, rule) - - if FLAGS.use_ipv6: - for rule in ipv6_rules: - self.iptables.ipv6['filter'].add_rule(chain_name, rule) - - def add_filters_for_instance(self, instance, network_info=None): - chain_name = self._instance_chain_name(instance) - if FLAGS.use_ipv6: - self.iptables.ipv6['filter'].add_chain(chain_name) - self.iptables.ipv4['filter'].add_chain(chain_name) - ipv4_rules, ipv6_rules = self._filters_for_instance(chain_name, - network_info) - self._add_filters('local', ipv4_rules, ipv6_rules) - ipv4_rules, ipv6_rules = self.instance_rules(instance, network_info) - self._add_filters(chain_name, ipv4_rules, ipv6_rules) - - def remove_filters_for_instance(self, instance): - chain_name = self._instance_chain_name(instance) - - self.iptables.ipv4['filter'].remove_chain(chain_name) - if FLAGS.use_ipv6: - self.iptables.ipv6['filter'].remove_chain(chain_name) - - def instance_rules(self, instance, network_info=None): - if not network_info: - network_info = _get_network_info(instance) - ctxt = context.get_admin_context() - - ipv4_rules = [] - ipv6_rules = [] - - # Always drop invalid packets - ipv4_rules += ['-m state --state ' 'INVALID -j DROP'] - ipv6_rules += ['-m state --state ' 'INVALID -j DROP'] - - # Allow established connections - ipv4_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] - ipv6_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] - - dhcp_servers = [network['gateway'] for (network, _m) in network_info] - - for dhcp_server in dhcp_servers: - ipv4_rules.append('-s %s -p udp --sport 67 --dport 68 ' - '-j ACCEPT' % (dhcp_server,)) - - #Allow project network traffic - if FLAGS.allow_project_net_traffic: - cidrs = [network['cidr'] for (network, _m) in network_info] - for cidr in cidrs: - ipv4_rules.append('-s %s -j ACCEPT' % (cidr,)) - - # We wrap these in FLAGS.use_ipv6 because they might cause - # a DB lookup. The other ones are just list operations, so - # they're not worth the clutter. - if FLAGS.use_ipv6: - # Allow RA responses - gateways_v6 = [network['gateway_v6'] for (network, _) in - network_info] - for gateway_v6 in gateways_v6: - ipv6_rules.append( - '-s %s/128 -p icmpv6 -j ACCEPT' % (gateway_v6,)) - - #Allow project network traffic - if FLAGS.allow_project_net_traffic: - cidrv6s = [network['cidr_v6'] for (network, _m) - in network_info] - - for cidrv6 in cidrv6s: - ipv6_rules.append('-s %s -j ACCEPT' % (cidrv6,)) - - security_groups = db.security_group_get_by_instance(ctxt, - instance['id']) - - # then, security group chains and rules - for security_group in security_groups: - rules = db.security_group_rule_get_by_security_group(ctxt, - security_group['id']) - - for rule in rules: - logging.info('%r', rule) - - if not rule.cidr: - # Eventually, a mechanism to grant access for security - # groups will turn up here. It'll use ipsets. - continue - - version = _get_ip_version(rule.cidr) - if version == 4: - rules = ipv4_rules - else: - rules = ipv6_rules - - protocol = rule.protocol - if version == 6 and rule.protocol == 'icmp': - protocol = 'icmpv6' - - args = ['-p', protocol, '-s', rule.cidr] - - if rule.protocol in ['udp', 'tcp']: - if rule.from_port == rule.to_port: - args += ['--dport', '%s' % (rule.from_port,)] - else: - args += ['-m', 'multiport', - '--dports', '%s:%s' % (rule.from_port, - rule.to_port)] - elif rule.protocol == 'icmp': - icmp_type = rule.from_port - icmp_code = rule.to_port - - if icmp_type == -1: - icmp_type_arg = None - else: - icmp_type_arg = '%s' % icmp_type - if not icmp_code == -1: - icmp_type_arg += '/%s' % icmp_code - - if icmp_type_arg: - if version == 4: - args += ['-m', 'icmp', '--icmp-type', - icmp_type_arg] - elif version == 6: - args += ['-m', 'icmp6', '--icmpv6-type', - icmp_type_arg] - - args += ['-j ACCEPT'] - rules += [' '.join(args)] - - ipv4_rules += ['-j $sg-fallback'] - ipv6_rules += ['-j $sg-fallback'] - - return ipv4_rules, ipv6_rules - - def instance_filter_exists(self, instance): - """Check nova-instance-instance-xxx exists""" - return self.nwfilter.instance_filter_exists(instance) - - def refresh_security_group_members(self, security_group): - pass - - def refresh_security_group_rules(self, security_group, network_info=None): - self.do_refresh_security_group_rules(security_group, network_info) - self.iptables.apply() - - @utils.synchronized('iptables', external=True) - def do_refresh_security_group_rules(self, - security_group, - network_info=None): - for instance in self.instances.values(): - self.remove_filters_for_instance(instance) - if not network_info: - network_info = _get_network_info(instance) - self.add_filters_for_instance(instance, network_info) - - def _security_group_chain_name(self, security_group_id): - return 'nova-sg-%s' % (security_group_id,) - - def _instance_chain_name(self, instance): - return 'inst-%s' % (instance['id'],) diff --git a/nova/virt/libvirt/firewall.py b/nova/virt/libvirt/firewall.py new file mode 100644 index 000000000..7e00662cd --- /dev/null +++ b/nova/virt/libvirt/firewall.py @@ -0,0 +1,642 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright (c) 2010 Citrix Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from eventlet import tpool + +from nova import context +from nova import db +from nova import flags +from nova import log as logging +from nova import utils +from nova.virt.libvirt import netutils + + +LOG = logging.getLogger("nova.virt.libvirt.firewall") +FLAGS = flags.FLAGS + + +try: + import libvirt +except ImportError: + LOG.warn(_("Libvirt module could not be loaded. NWFilterFirewall will " + "not work correctly.")) + + +class FirewallDriver(object): + def prepare_instance_filter(self, instance, network_info=None): + """Prepare filters for the instance. + + At this point, the instance isn't running yet.""" + raise NotImplementedError() + + def unfilter_instance(self, instance): + """Stop filtering instance""" + raise NotImplementedError() + + def apply_instance_filter(self, instance): + """Apply instance filter. + + Once this method returns, the instance should be firewalled + appropriately. This method should as far as possible be a + no-op. It's vastly preferred to get everything set up in + prepare_instance_filter. + """ + raise NotImplementedError() + + def refresh_security_group_rules(self, + security_group_id, + network_info=None): + """Refresh security group rules from data store + + Gets called when a rule has been added to or removed from + the security group.""" + raise NotImplementedError() + + def refresh_security_group_members(self, security_group_id): + """Refresh security group members from data store + + Gets called when an instance gets added to or removed from + the security group.""" + raise NotImplementedError() + + def setup_basic_filtering(self, instance, network_info=None): + """Create rules to block spoofing and allow dhcp. + + This gets called when spawning an instance, before + :method:`prepare_instance_filter`. + + """ + raise NotImplementedError() + + def instance_filter_exists(self, instance): + """Check nova-instance-instance-xxx exists""" + raise NotImplementedError() + + +class NWFilterFirewall(FirewallDriver): + """ + This class implements a network filtering mechanism versatile + enough for EC2 style Security Group filtering by leveraging + libvirt's nwfilter. + + First, all instances get a filter ("nova-base-filter") applied. + This filter provides some basic security such as protection against + MAC spoofing, IP spoofing, and ARP spoofing. + + This filter drops all incoming ipv4 and ipv6 connections. + Outgoing connections are never blocked. + + Second, every security group maps to a nwfilter filter(*). + NWFilters can be updated at runtime and changes are applied + immediately, so changes to security groups can be applied at + runtime (as mandated by the spec). + + Security group rules are named "nova-secgroup-<id>" where <id> + is the internal id of the security group. They're applied only on + hosts that have instances in the security group in question. + + Updates to security groups are done by updating the data model + (in response to API calls) followed by a request sent to all + the nodes with instances in the security group to refresh the + security group. + + Each instance has its own NWFilter, which references the above + mentioned security group NWFilters. This was done because + interfaces can only reference one filter while filters can + reference multiple other filters. This has the added benefit of + actually being able to add and remove security groups from an + instance at run time. This functionality is not exposed anywhere, + though. + + Outstanding questions: + + The name is unique, so would there be any good reason to sync + the uuid across the nodes (by assigning it from the datamodel)? + + + (*) This sentence brought to you by the redundancy department of + redundancy. + + """ + + def __init__(self, get_connection, **kwargs): + self._libvirt_get_connection = get_connection + self.static_filters_configured = False + self.handle_security_groups = False + + def apply_instance_filter(self, instance): + """No-op. Everything is done in prepare_instance_filter""" + pass + + def _get_connection(self): + return self._libvirt_get_connection() + _conn = property(_get_connection) + + def nova_dhcp_filter(self): + """The standard allow-dhcp-server filter is an <ip> one, so it uses + ebtables to allow traffic through. Without a corresponding rule in + iptables, it'll get blocked anyway.""" + + return '''<filter name='nova-allow-dhcp-server' chain='ipv4'> + <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid> + <rule action='accept' direction='out' + priority='100'> + <udp srcipaddr='0.0.0.0' + dstipaddr='255.255.255.255' + srcportstart='68' + dstportstart='67'/> + </rule> + <rule action='accept' direction='in' + priority='100'> + <udp srcipaddr='$DHCPSERVER' + srcportstart='67' + dstportstart='68'/> + </rule> + </filter>''' + + def nova_ra_filter(self): + return '''<filter name='nova-allow-ra-server' chain='root'> + <uuid>d707fa71-4fb5-4b27-9ab7-ba5ca19c8804</uuid> + <rule action='accept' direction='inout' + priority='100'> + <icmpv6 srcipaddr='$RASERVER'/> + </rule> + </filter>''' + + def setup_basic_filtering(self, instance, network_info=None): + """Set up basic filtering (MAC, IP, and ARP spoofing protection)""" + logging.info('called setup_basic_filtering in nwfilter') + + if not network_info: + network_info = netutils.get_network_info(instance) + + if self.handle_security_groups: + # No point in setting up a filter set that we'll be overriding + # anyway. + return + + logging.info('ensuring static filters') + self._ensure_static_filters() + + if instance['image_id'] == str(FLAGS.vpn_image_id): + base_filter = 'nova-vpn' + else: + base_filter = 'nova-base' + + for (network, mapping) in network_info: + nic_id = mapping['mac'].replace(':', '') + instance_filter_name = self._instance_filter_name(instance, nic_id) + self._define_filter(self._filter_container(instance_filter_name, + [base_filter])) + + def _ensure_static_filters(self): + if self.static_filters_configured: + return + + self._define_filter(self._filter_container('nova-base', + ['no-mac-spoofing', + 'no-ip-spoofing', + 'no-arp-spoofing', + 'allow-dhcp-server'])) + self._define_filter(self._filter_container('nova-vpn', + ['allow-dhcp-server'])) + self._define_filter(self.nova_base_ipv4_filter) + self._define_filter(self.nova_base_ipv6_filter) + self._define_filter(self.nova_dhcp_filter) + self._define_filter(self.nova_ra_filter) + if FLAGS.allow_project_net_traffic: + self._define_filter(self.nova_project_filter) + if FLAGS.use_ipv6: + self._define_filter(self.nova_project_filter_v6) + + self.static_filters_configured = True + + def _filter_container(self, name, filters): + xml = '''<filter name='%s' chain='root'>%s</filter>''' % ( + name, + ''.join(["<filterref filter='%s'/>" % (f,) for f in filters])) + return xml + + def nova_base_ipv4_filter(self): + retval = "<filter name='nova-base-ipv4' chain='ipv4'>" + for protocol in ['tcp', 'udp', 'icmp']: + for direction, action, priority in [('out', 'accept', 399), + ('in', 'drop', 400)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + def nova_base_ipv6_filter(self): + retval = "<filter name='nova-base-ipv6' chain='ipv6'>" + for protocol in ['tcp-ipv6', 'udp-ipv6', 'icmpv6']: + for direction, action, priority in [('out', 'accept', 399), + ('in', 'drop', 400)]: + retval += """<rule action='%s' direction='%s' priority='%d'> + <%s /> + </rule>""" % (action, direction, + priority, protocol) + retval += '</filter>' + return retval + + def nova_project_filter(self): + retval = "<filter name='nova-project' chain='ipv4'>" + for protocol in ['tcp', 'udp', 'icmp']: + retval += """<rule action='accept' direction='in' priority='200'> + <%s srcipaddr='$PROJNET' srcipmask='$PROJMASK' /> + </rule>""" % protocol + retval += '</filter>' + return retval + + def nova_project_filter_v6(self): + retval = "<filter name='nova-project-v6' chain='ipv6'>" + for protocol in ['tcp-ipv6', 'udp-ipv6', 'icmpv6']: + retval += """<rule action='accept' direction='inout' + priority='200'> + <%s srcipaddr='$PROJNETV6' + srcipmask='$PROJMASKV6' /> + </rule>""" % (protocol) + retval += '</filter>' + return retval + + def _define_filter(self, xml): + if callable(xml): + xml = xml() + # execute in a native thread and block current greenthread until done + tpool.execute(self._conn.nwfilterDefineXML, xml) + + def unfilter_instance(self, instance): + # Nothing to do + pass + + def prepare_instance_filter(self, instance, network_info=None): + """ + Creates an NWFilter for the given instance. In the process, + it makes sure the filters for the security groups as well as + the base filter are all in place. + """ + if not network_info: + network_info = netutils.get_network_info(instance) + + ctxt = context.get_admin_context() + + instance_secgroup_filter_name = \ + '%s-secgroup' % (self._instance_filter_name(instance)) + #% (instance_filter_name,) + + instance_secgroup_filter_children = ['nova-base-ipv4', + 'nova-base-ipv6', + 'nova-allow-dhcp-server'] + + if FLAGS.use_ipv6: + networks = [network for (network, _m) in network_info if + network['gateway_v6']] + + if networks: + instance_secgroup_filter_children.\ + append('nova-allow-ra-server') + + for security_group in \ + db.security_group_get_by_instance(ctxt, instance['id']): + + self.refresh_security_group_rules(security_group['id']) + + instance_secgroup_filter_children.append('nova-secgroup-%s' % + security_group['id']) + + self._define_filter( + self._filter_container(instance_secgroup_filter_name, + instance_secgroup_filter_children)) + + network_filters = self.\ + _create_network_filters(instance, network_info, + instance_secgroup_filter_name) + + for (name, children) in network_filters: + self._define_filters(name, children) + + def _create_network_filters(self, instance, network_info, + instance_secgroup_filter_name): + if instance['image_id'] == str(FLAGS.vpn_image_id): + base_filter = 'nova-vpn' + else: + base_filter = 'nova-base' + + result = [] + for (_n, mapping) in network_info: + nic_id = mapping['mac'].replace(':', '') + instance_filter_name = self._instance_filter_name(instance, nic_id) + instance_filter_children = [base_filter, + instance_secgroup_filter_name] + + if FLAGS.allow_project_net_traffic: + instance_filter_children.append('nova-project') + if FLAGS.use_ipv6: + instance_filter_children.append('nova-project-v6') + + result.append((instance_filter_name, instance_filter_children)) + + return result + + def _define_filters(self, filter_name, filter_children): + self._define_filter(self._filter_container(filter_name, + filter_children)) + + def refresh_security_group_rules(self, + security_group_id, + network_info=None): + return self._define_filter( + self.security_group_to_nwfilter_xml(security_group_id)) + + def security_group_to_nwfilter_xml(self, security_group_id): + security_group = db.security_group_get(context.get_admin_context(), + security_group_id) + rule_xml = "" + v6protocol = {'tcp': 'tcp-ipv6', 'udp': 'udp-ipv6', 'icmp': 'icmpv6'} + for rule in security_group.rules: + rule_xml += "<rule action='accept' direction='in' priority='300'>" + if rule.cidr: + version = netutils.get_ip_version(rule.cidr) + if(FLAGS.use_ipv6 and version == 6): + net, prefixlen = netutils.get_net_and_prefixlen(rule.cidr) + rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \ + (v6protocol[rule.protocol], net, prefixlen) + else: + net, mask = netutils.get_net_and_mask(rule.cidr) + rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % \ + (rule.protocol, net, mask) + if rule.protocol in ['tcp', 'udp']: + rule_xml += "dstportstart='%s' dstportend='%s' " % \ + (rule.from_port, rule.to_port) + elif rule.protocol == 'icmp': + LOG.info('rule.protocol: %r, rule.from_port: %r, ' + 'rule.to_port: %r', rule.protocol, + rule.from_port, rule.to_port) + if rule.from_port != -1: + rule_xml += "type='%s' " % rule.from_port + if rule.to_port != -1: + rule_xml += "code='%s' " % rule.to_port + + rule_xml += '/>\n' + rule_xml += "</rule>\n" + xml = "<filter name='nova-secgroup-%s' " % security_group_id + if(FLAGS.use_ipv6): + xml += "chain='root'>%s</filter>" % rule_xml + else: + xml += "chain='ipv4'>%s</filter>" % rule_xml + return xml + + def _instance_filter_name(self, instance, nic_id=None): + if not nic_id: + return 'nova-instance-%s' % (instance['name']) + return 'nova-instance-%s-%s' % (instance['name'], nic_id) + + def instance_filter_exists(self, instance): + """Check nova-instance-instance-xxx exists""" + network_info = netutils.get_network_info(instance) + for (network, mapping) in network_info: + nic_id = mapping['mac'].replace(':', '') + instance_filter_name = self._instance_filter_name(instance, nic_id) + try: + self._conn.nwfilterLookupByName(instance_filter_name) + except libvirt.libvirtError: + name = instance.name + LOG.debug(_('The nwfilter(%(instance_filter_name)s) for' + '%(name)s is not found.') % locals()) + return False + return True + + +class IptablesFirewallDriver(FirewallDriver): + def __init__(self, execute=None, **kwargs): + from nova.network import linux_net + self.iptables = linux_net.iptables_manager + self.instances = {} + self.nwfilter = NWFilterFirewall(kwargs['get_connection']) + + self.iptables.ipv4['filter'].add_chain('sg-fallback') + self.iptables.ipv4['filter'].add_rule('sg-fallback', '-j DROP') + self.iptables.ipv6['filter'].add_chain('sg-fallback') + self.iptables.ipv6['filter'].add_rule('sg-fallback', '-j DROP') + + def setup_basic_filtering(self, instance, network_info=None): + """Use NWFilter from libvirt for this.""" + if not network_info: + network_info = netutils.get_network_info(instance) + return self.nwfilter.setup_basic_filtering(instance, network_info) + + def apply_instance_filter(self, instance): + """No-op. Everything is done in prepare_instance_filter""" + pass + + def unfilter_instance(self, instance): + if self.instances.pop(instance['id'], None): + self.remove_filters_for_instance(instance) + self.iptables.apply() + else: + LOG.info(_('Attempted to unfilter instance %s which is not ' + 'filtered'), instance['id']) + + def prepare_instance_filter(self, instance, network_info=None): + if not network_info: + network_info = netutils.get_network_info(instance) + self.instances[instance['id']] = instance + self.add_filters_for_instance(instance, network_info) + self.iptables.apply() + + def _create_filter(self, ips, chain_name): + return ['-d %s -j $%s' % (ip, chain_name) for ip in ips] + + def _filters_for_instance(self, chain_name, network_info): + ips_v4 = [ip['ip'] for (_n, mapping) in network_info + for ip in mapping['ips']] + ipv4_rules = self._create_filter(ips_v4, chain_name) + + ipv6_rules = [] + if FLAGS.use_ipv6: + ips_v6 = [ip['ip'] for (_n, mapping) in network_info + for ip in mapping['ip6s']] + ipv6_rules = self._create_filter(ips_v6, chain_name) + + return ipv4_rules, ipv6_rules + + def _add_filters(self, chain_name, ipv4_rules, ipv6_rules): + for rule in ipv4_rules: + self.iptables.ipv4['filter'].add_rule(chain_name, rule) + + if FLAGS.use_ipv6: + for rule in ipv6_rules: + self.iptables.ipv6['filter'].add_rule(chain_name, rule) + + def add_filters_for_instance(self, instance, network_info=None): + chain_name = self._instance_chain_name(instance) + if FLAGS.use_ipv6: + self.iptables.ipv6['filter'].add_chain(chain_name) + self.iptables.ipv4['filter'].add_chain(chain_name) + ipv4_rules, ipv6_rules = self._filters_for_instance(chain_name, + network_info) + self._add_filters('local', ipv4_rules, ipv6_rules) + ipv4_rules, ipv6_rules = self.instance_rules(instance, network_info) + self._add_filters(chain_name, ipv4_rules, ipv6_rules) + + def remove_filters_for_instance(self, instance): + chain_name = self._instance_chain_name(instance) + + self.iptables.ipv4['filter'].remove_chain(chain_name) + if FLAGS.use_ipv6: + self.iptables.ipv6['filter'].remove_chain(chain_name) + + def instance_rules(self, instance, network_info=None): + if not network_info: + network_info = netutils.get_network_info(instance) + ctxt = context.get_admin_context() + + ipv4_rules = [] + ipv6_rules = [] + + # Always drop invalid packets + ipv4_rules += ['-m state --state ' 'INVALID -j DROP'] + ipv6_rules += ['-m state --state ' 'INVALID -j DROP'] + + # Allow established connections + ipv4_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] + ipv6_rules += ['-m state --state ESTABLISHED,RELATED -j ACCEPT'] + + dhcp_servers = [network['gateway'] for (network, _m) in network_info] + + for dhcp_server in dhcp_servers: + ipv4_rules.append('-s %s -p udp --sport 67 --dport 68 ' + '-j ACCEPT' % (dhcp_server,)) + + #Allow project network traffic + if FLAGS.allow_project_net_traffic: + cidrs = [network['cidr'] for (network, _m) in network_info] + for cidr in cidrs: + ipv4_rules.append('-s %s -j ACCEPT' % (cidr,)) + + # We wrap these in FLAGS.use_ipv6 because they might cause + # a DB lookup. The other ones are just list operations, so + # they're not worth the clutter. + if FLAGS.use_ipv6: + # Allow RA responses + gateways_v6 = [network['gateway_v6'] for (network, _) in + network_info] + for gateway_v6 in gateways_v6: + ipv6_rules.append( + '-s %s/128 -p icmpv6 -j ACCEPT' % (gateway_v6,)) + + #Allow project network traffic + if FLAGS.allow_project_net_traffic: + cidrv6s = [network['cidr_v6'] for (network, _m) + in network_info] + + for cidrv6 in cidrv6s: + ipv6_rules.append('-s %s -j ACCEPT' % (cidrv6,)) + + security_groups = db.security_group_get_by_instance(ctxt, + instance['id']) + + # then, security group chains and rules + for security_group in security_groups: + rules = db.security_group_rule_get_by_security_group(ctxt, + security_group['id']) + + for rule in rules: + logging.info('%r', rule) + + if not rule.cidr: + # Eventually, a mechanism to grant access for security + # groups will turn up here. It'll use ipsets. + continue + + version = netutils.get_ip_version(rule.cidr) + if version == 4: + rules = ipv4_rules + else: + rules = ipv6_rules + + protocol = rule.protocol + if version == 6 and rule.protocol == 'icmp': + protocol = 'icmpv6' + + args = ['-p', protocol, '-s', rule.cidr] + + if rule.protocol in ['udp', 'tcp']: + if rule.from_port == rule.to_port: + args += ['--dport', '%s' % (rule.from_port,)] + else: + args += ['-m', 'multiport', + '--dports', '%s:%s' % (rule.from_port, + rule.to_port)] + elif rule.protocol == 'icmp': + icmp_type = rule.from_port + icmp_code = rule.to_port + + if icmp_type == -1: + icmp_type_arg = None + else: + icmp_type_arg = '%s' % icmp_type + if not icmp_code == -1: + icmp_type_arg += '/%s' % icmp_code + + if icmp_type_arg: + if version == 4: + args += ['-m', 'icmp', '--icmp-type', + icmp_type_arg] + elif version == 6: + args += ['-m', 'icmp6', '--icmpv6-type', + icmp_type_arg] + + args += ['-j ACCEPT'] + rules += [' '.join(args)] + + ipv4_rules += ['-j $sg-fallback'] + ipv6_rules += ['-j $sg-fallback'] + + return ipv4_rules, ipv6_rules + + def instance_filter_exists(self, instance): + """Check nova-instance-instance-xxx exists""" + return self.nwfilter.instance_filter_exists(instance) + + def refresh_security_group_members(self, security_group): + pass + + def refresh_security_group_rules(self, security_group, network_info=None): + self.do_refresh_security_group_rules(security_group, network_info) + self.iptables.apply() + + @utils.synchronized('iptables', external=True) + def do_refresh_security_group_rules(self, + security_group, + network_info=None): + for instance in self.instances.values(): + self.remove_filters_for_instance(instance) + if not network_info: + network_info = netutils.get_network_info(instance) + self.add_filters_for_instance(instance, network_info) + + def _security_group_chain_name(self, security_group_id): + return 'nova-sg-%s' % (security_group_id,) + + def _instance_chain_name(self, instance): + return 'inst-%s' % (instance['id'],) diff --git a/nova/virt/libvirt/netutils.py b/nova/virt/libvirt/netutils.py new file mode 100644 index 000000000..4d596078a --- /dev/null +++ b/nova/virt/libvirt/netutils.py @@ -0,0 +1,97 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright (c) 2010 Citrix Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +"""Network-releated utilities for supporting libvirt connection code.""" + + +import IPy + +from nova import context +from nova import db +from nova import flags +from nova import ipv6 +from nova import utils + + +FLAGS = flags.FLAGS + + +def get_net_and_mask(cidr): + net = IPy.IP(cidr) + return str(net.net()), str(net.netmask()) + + +def get_net_and_prefixlen(cidr): + net = IPy.IP(cidr) + return str(net.net()), str(net.prefixlen()) + + +def get_ip_version(cidr): + net = IPy.IP(cidr) + return int(net.version()) + + +def get_network_info(instance): + # TODO(adiantum) If we will keep this function + # we should cache network_info + admin_context = context.get_admin_context() + + ip_addresses = db.fixed_ip_get_all_by_instance(admin_context, + instance['id']) + networks = db.network_get_all_by_instance(admin_context, + instance['id']) + flavor = db.instance_type_get_by_id(admin_context, + instance['instance_type_id']) + network_info = [] + + for network in networks: + network_ips = [ip for ip in ip_addresses + if ip['network_id'] == network['id']] + + def ip_dict(ip): + return { + 'ip': ip['address'], + 'netmask': network['netmask'], + 'enabled': '1'} + + def ip6_dict(): + prefix = network['cidr_v6'] + mac = instance['mac_address'] + project_id = instance['project_id'] + return { + 'ip': ipv6.to_global(prefix, mac, project_id), + 'netmask': network['netmask_v6'], + 'enabled': '1'} + + mapping = { + 'label': network['label'], + 'gateway': network['gateway'], + 'broadcast': network['broadcast'], + 'mac': instance['mac_address'], + 'rxtx_cap': flavor['rxtx_cap'], + 'dns': [network['dns']], + 'ips': [ip_dict(ip) for ip in network_ips]} + + if FLAGS.use_ipv6: + mapping['ip6s'] = [ip6_dict()] + mapping['gateway6'] = network['gateway_v6'] + + network_info.append((network, mapping)) + return network_info diff --git a/nova/virt/xenapi/vmops.py b/nova/virt/xenapi/vmops.py index 0074444f8..be6ef48ea 100644 --- a/nova/virt/xenapi/vmops.py +++ b/nova/virt/xenapi/vmops.py @@ -202,6 +202,13 @@ class VMOps(object): for path, contents in instance.injected_files: LOG.debug(_("Injecting file path: '%s'") % path) self.inject_file(instance, path, contents) + + def _set_admin_password(): + admin_password = instance.admin_pass + if admin_password: + LOG.debug(_("Setting admin password")) + self.set_admin_password(instance, admin_password) + # NOTE(armando): Do we really need to do this in virt? # NOTE(tr3buchet): not sure but wherever we do it, we need to call # reset_network afterwards @@ -214,6 +221,7 @@ class VMOps(object): LOG.debug(_('Instance %s: booted'), instance_name) timer.stop() _inject_files() + _set_admin_password() return True except Exception, exc: LOG.warn(exc) @@ -253,7 +261,8 @@ class VMOps(object): instance_name = instance_or_vm.name vm_ref = VMHelper.lookup(self._session, instance_name) if vm_ref is None: - raise exception.InstanceNotFound(instance_id=instance_obj.id) + raise exception.NotFound(_("No opaque_ref could be determined " + "for '%s'.") % instance_or_vm) return vm_ref def _acquire_bootlock(self, vm): @@ -457,6 +466,9 @@ class VMOps(object): # Successful return code from password is '0' if resp_dict['returncode'] != '0': raise RuntimeError(resp_dict['message']) + db.instance_update(context.get_admin_context(), + instance['id'], + dict(admin_pass=new_pass)) return resp_dict['message'] def inject_file(self, instance, path, contents): @@ -1171,13 +1183,13 @@ class SimpleDH(object): shared = self._shared cmd = base_cmd % locals() proc = _runproc(cmd) - proc.stdin.write(text) + proc.stdin.write(text + '\n') proc.stdin.close() proc.wait() err = proc.stderr.read() if err: raise RuntimeError(_('OpenSSL error: %s') % err) - return proc.stdout.read() + return proc.stdout.read().strip('\n') def encrypt(self, text): return self._run_ssl(text, 'enc') |
