summaryrefslogtreecommitdiffstats
path: root/nova
diff options
context:
space:
mode:
authorJustin Santa Barbara <justinsb@justinsb-desktop>2010-10-14 12:59:36 -0700
committerJustin Santa Barbara <justinsb@justinsb-desktop>2010-10-14 12:59:36 -0700
commitd8643f1e15f241db96893d1ea41083a2bee65dbd (patch)
tree12e9e85733306f97b12b99339edbe49ef4031418 /nova
parent759bab6059ef2e4c463a73e12fe85fe4b147eba7 (diff)
parent3363b133a927509432cb42d77abf18d3d5248abf (diff)
downloadnova-d8643f1e15f241db96893d1ea41083a2bee65dbd.tar.gz
nova-d8643f1e15f241db96893d1ea41083a2bee65dbd.tar.xz
nova-d8643f1e15f241db96893d1ea41083a2bee65dbd.zip
Merged with trunk, fixed broken stuff
Diffstat (limited to 'nova')
-rw-r--r--nova/adminclient.py73
-rw-r--r--nova/api/__init__.py81
-rw-r--r--nova/api/cloud.py42
-rw-r--r--nova/api/cloudpipe/__init__.py69
-rw-r--r--nova/api/context.py46
-rw-r--r--nova/api/ec2/__init__.py230
-rw-r--r--nova/api/ec2/admin.py (renamed from nova/endpoint/admin.py)42
-rw-r--r--nova/api/ec2/apirequest.py131
-rw-r--r--nova/api/ec2/cloud.py1027
-rw-r--r--nova/api/ec2/images.py (renamed from nova/endpoint/images.py)23
-rw-r--r--nova/api/ec2/metadatarequesthandler.py73
-rw-r--r--nova/api/openstack/__init__.py190
-rw-r--r--nova/api/openstack/auth.py101
-rw-r--r--nova/api/openstack/backup_schedules.py38
-rw-r--r--nova/api/openstack/base.py (renamed from nova/api/rackspace/base.py)0
-rw-r--r--nova/api/openstack/context.py33
-rw-r--r--nova/api/openstack/faults.py62
-rw-r--r--nova/api/openstack/flavors.py58
-rw-r--r--nova/api/openstack/images.py71
-rw-r--r--nova/api/openstack/notes.txt23
-rw-r--r--nova/api/openstack/ratelimiting/__init__.py122
-rw-r--r--nova/api/openstack/servers.py273
-rw-r--r--nova/api/openstack/sharedipgroups.py (renamed from nova/api/rackspace/sharedipgroups.py)4
-rw-r--r--nova/api/rackspace/__init__.py81
-rw-r--r--nova/api/rackspace/servers.py83
-rw-r--r--nova/auth/dbdriver.py236
-rw-r--r--nova/auth/fakeldap.py42
-rw-r--r--nova/auth/ldapdriver.py150
-rw-r--r--nova/auth/manager.py236
-rw-r--r--nova/auth/rbac.py55
-rw-r--r--nova/auth/signer.py51
-rw-r--r--nova/cloudpipe/api.py59
-rw-r--r--nova/cloudpipe/pipelib.py14
-rw-r--r--nova/compute/instance_types.py14
-rw-r--r--nova/compute/manager.py180
-rw-r--r--nova/compute/model.py314
-rw-r--r--nova/compute/service.py367
-rw-r--r--nova/crypto.py2
-rw-r--r--nova/datastore.py209
-rw-r--r--nova/db/__init__.py (renamed from nova/api/rackspace/images.py)9
-rw-r--r--nova/db/api.py770
-rw-r--r--nova/db/sqlalchemy/__init__.py (renamed from nova/api/rackspace/flavors.py)10
-rw-r--r--nova/db/sqlalchemy/api.py1680
-rw-r--r--nova/db/sqlalchemy/models.py513
-rw-r--r--nova/db/sqlalchemy/session.py43
-rwxr-xr-xnova/endpoint/api.py344
-rw-r--r--nova/endpoint/cloud.py745
-rw-r--r--nova/exception.py15
-rw-r--r--nova/fakerabbit.py15
-rw-r--r--nova/flags.py46
-rw-r--r--nova/image/__init__.py (renamed from nova/endpoint/__init__.py)0
-rw-r--r--nova/image/service.py285
-rw-r--r--nova/manager.py52
-rw-r--r--nova/network/linux_net.py326
-rw-r--r--nova/network/manager.py428
-rw-r--r--nova/network/model.py634
-rw-r--r--nova/network/service.py257
-rw-r--r--nova/network/vpn.py126
-rw-r--r--nova/objectstore/__init__.py2
-rw-r--r--nova/objectstore/handler.py160
-rw-r--r--nova/objectstore/image.py20
-rw-r--r--nova/process.py99
-rw-r--r--nova/quota.py92
-rw-r--r--nova/rpc.py71
-rw-r--r--nova/scheduler/__init__.py25
-rw-r--r--nova/scheduler/chance.py (renamed from nova/network/exception.py)34
-rw-r--r--nova/scheduler/driver.py59
-rw-r--r--nova/scheduler/manager.py66
-rw-r--r--nova/scheduler/simple.py90
-rw-r--r--nova/server.py12
-rw-r--r--nova/service.py154
-rw-r--r--nova/test.py70
-rw-r--r--nova/tests/access_unittest.py115
-rw-r--r--nova/tests/api/__init__.py83
-rw-r--r--nova/tests/api/fakes.py8
-rw-r--r--nova/tests/api/openstack/__init__.py108
-rw-r--r--nova/tests/api/openstack/fakes.py210
-rw-r--r--nova/tests/api/openstack/test_auth.py110
-rw-r--r--nova/tests/api/openstack/test_faults.py40
-rw-r--r--nova/tests/api/openstack/test_flavors.py48
-rw-r--r--nova/tests/api/openstack/test_images.py141
-rw-r--r--nova/tests/api/openstack/test_ratelimiting.py237
-rw-r--r--nova/tests/api/openstack/test_servers.py249
-rw-r--r--nova/tests/api/openstack/test_sharedipgroups.py39
-rw-r--r--nova/tests/api/test_wsgi.py147
-rw-r--r--nova/tests/api_unittest.py340
-rw-r--r--nova/tests/auth_unittest.py447
-rw-r--r--nova/tests/bundle/1mb.manifest.xml2
-rw-r--r--nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml1
-rw-r--r--nova/tests/cloud_unittest.py180
-rw-r--r--nova/tests/compute_unittest.py154
-rw-r--r--nova/tests/fake_flags.py15
-rw-r--r--nova/tests/model_unittest.py292
-rw-r--r--nova/tests/network_unittest.py367
-rw-r--r--nova/tests/objectstore_unittest.py39
-rw-r--r--nova/tests/quota_unittest.py152
-rw-r--r--nova/tests/real_flags.py1
-rw-r--r--nova/tests/rpc_unittest.py17
-rw-r--r--nova/tests/scheduler_unittest.py242
-rw-r--r--nova/tests/service_unittest.py190
-rw-r--r--nova/tests/storage_unittest.py115
-rw-r--r--nova/tests/virt_unittest.py231
-rw-r--r--nova/tests/volume_unittest.py188
-rw-r--r--nova/twistd.py16
-rw-r--r--nova/utils.py73
-rw-r--r--nova/virt/connection.py6
-rw-r--r--nova/virt/fake.py38
-rw-r--r--nova/virt/images.py1
-rw-r--r--nova/virt/interfaces.template1
-rw-r--r--nova/virt/libvirt.qemu.xml.template5
-rw-r--r--nova/virt/libvirt.uml.xml.template27
-rw-r--r--nova/virt/libvirt.xen.xml.template35
-rw-r--r--nova/virt/libvirt_conn.py492
-rw-r--r--nova/virt/xenapi.py241
-rw-r--r--nova/volume/driver.py136
-rw-r--r--nova/volume/manager.py132
-rw-r--r--nova/volume/service.py322
-rw-r--r--nova/wsgi.py135
118 files changed, 12690 insertions, 5635 deletions
diff --git a/nova/adminclient.py b/nova/adminclient.py
index 0ca32b1e5..fc9fcfde0 100644
--- a/nova/adminclient.py
+++ b/nova/adminclient.py
@@ -20,11 +20,17 @@ Nova User API client library.
"""
import base64
-
import boto
+import httplib
from boto.ec2.regioninfo import RegionInfo
+DEFAULT_CLC_URL='http://127.0.0.1:8773'
+DEFAULT_REGION='nova'
+DEFAULT_ACCESS_KEY='admin'
+DEFAULT_SECRET_KEY='admin'
+
+
class UserInfo(object):
"""
Information about a Nova user, as parsed through SAX
@@ -68,13 +74,13 @@ class UserRole(object):
def __init__(self, connection=None):
self.connection = connection
self.role = None
-
+
def __repr__(self):
return 'UserRole:%s' % self.role
def startElement(self, name, attrs, connection):
return None
-
+
def endElement(self, name, value, connection):
if name == 'role':
self.role = value
@@ -128,20 +134,20 @@ class ProjectMember(object):
def __init__(self, connection=None):
self.connection = connection
self.memberId = None
-
+
def __repr__(self):
return 'ProjectMember:%s' % self.memberId
def startElement(self, name, attrs, connection):
return None
-
+
def endElement(self, name, value, connection):
if name == 'member':
self.memberId = value
else:
setattr(self, name, str(value))
-
+
class HostInfo(object):
"""
Information about a Nova Host, as parsed through SAX:
@@ -171,35 +177,56 @@ class HostInfo(object):
class NovaAdminClient(object):
- def __init__(self, clc_ip='127.0.0.1', region='nova', access_key='admin',
- secret_key='admin', **kwargs):
- self.clc_ip = clc_ip
+ def __init__(self, clc_url=DEFAULT_CLC_URL, region=DEFAULT_REGION,
+ access_key=DEFAULT_ACCESS_KEY, secret_key=DEFAULT_SECRET_KEY,
+ **kwargs):
+ parts = self.split_clc_url(clc_url)
+
+ self.clc_url = clc_url
self.region = region
self.access = access_key
self.secret = secret_key
self.apiconn = boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
- is_secure=False,
- region=RegionInfo(None, region, clc_ip),
- port=8773,
+ is_secure=parts['is_secure'],
+ region=RegionInfo(None,
+ region,
+ parts['ip']),
+ port=parts['port'],
path='/services/Admin',
**kwargs)
self.apiconn.APIVersion = 'nova'
- def connection_for(self, username, project, **kwargs):
+ def connection_for(self, username, project, clc_url=None, region=None,
+ **kwargs):
"""
Returns a boto ec2 connection for the given username.
"""
+ if not clc_url:
+ clc_url = self.clc_url
+ if not region:
+ region = self.region
+ parts = self.split_clc_url(clc_url)
user = self.get_user(username)
access_key = '%s:%s' % (user.accesskey, project)
- return boto.connect_ec2(
- aws_access_key_id=access_key,
- aws_secret_access_key=user.secretkey,
- is_secure=False,
- region=RegionInfo(None, self.region, self.clc_ip),
- port=8773,
- path='/services/Cloud',
- **kwargs)
+ return boto.connect_ec2(aws_access_key_id=access_key,
+ aws_secret_access_key=user.secretkey,
+ is_secure=parts['is_secure'],
+ region=RegionInfo(None,
+ self.region,
+ parts['ip']),
+ port=parts['port'],
+ path='/services/Cloud',
+ **kwargs)
+
+ def split_clc_url(self, clc_url):
+ """
+ Splits a cloud controller endpoint url.
+ """
+ parts = httplib.urlsplit(clc_url)
+ is_secure = parts.scheme == 'https'
+ ip, port = parts.netloc.split(':')
+ return {'ip': ip, 'port': int(port), 'is_secure': is_secure}
def get_users(self):
""" grabs the list of all users """
@@ -289,7 +316,7 @@ class NovaAdminClient(object):
if project.projectname != None:
return project
-
+
def create_project(self, projectname, manager_user, description=None,
member_users=None):
"""
@@ -322,7 +349,7 @@ class NovaAdminClient(object):
Adds a user to a project.
"""
return self.modify_project_member(user, project, operation='add')
-
+
def remove_project_member(self, user, project):
"""
Removes a user from a project.
diff --git a/nova/api/__init__.py b/nova/api/__init__.py
index a6bb93348..8ec7094d7 100644
--- a/nova/api/__init__.py
+++ b/nova/api/__init__.py
@@ -21,18 +21,91 @@ Root WSGI middleware for all API controllers.
"""
import routes
+import webob.dec
+from nova import flags
from nova import wsgi
+from nova.api import cloudpipe
from nova.api import ec2
-from nova.api import rackspace
+from nova.api import openstack
+from nova.api.ec2 import metadatarequesthandler
+
+
+flags.DEFINE_string('osapi_subdomain', 'api',
+ 'subdomain running the OpenStack API')
+flags.DEFINE_string('ec2api_subdomain', 'ec2',
+ 'subdomain running the EC2 API')
+flags.DEFINE_string('FAKE_subdomain', None,
+ 'set to api or ec2 to fake the subdomain of the host for testing')
+FLAGS = flags.FLAGS
class API(wsgi.Router):
"""Routes top-level requests to the appropriate controller."""
def __init__(self):
+ osapidomain = {'sub_domain': [FLAGS.osapi_subdomain]}
+ ec2domain = {'sub_domain': [FLAGS.ec2api_subdomain]}
+ # If someone wants to pretend they're hitting the OSAPI subdomain
+ # on their local box, they can set FAKE_subdomain to 'api', which
+ # removes subdomain restrictions from the OpenStack API routes below.
+ if FLAGS.FAKE_subdomain == 'api':
+ osapidomain = {}
+ elif FLAGS.FAKE_subdomain == 'ec2':
+ ec2domain = {}
mapper = routes.Mapper()
- mapper.connect(None, "/v1.0/{path_info:.*}",
- controller=rackspace.API())
- mapper.connect(None, "/ec2/{path_info:.*}", controller=ec2.API())
+ mapper.sub_domains = True
+ mapper.connect("/", controller=self.osapi_versions,
+ conditions=osapidomain)
+ mapper.connect("/v1.0/{path_info:.*}", controller=openstack.API(),
+ conditions=osapidomain)
+
+ mapper.connect("/", controller=self.ec2api_versions,
+ conditions=ec2domain)
+ mapper.connect("/services/{path_info:.*}", controller=ec2.API(),
+ conditions=ec2domain)
+ mapper.connect("/cloudpipe/{path_info:.*}", controller=cloudpipe.API())
+ mrh = metadatarequesthandler.MetadataRequestHandler()
+ for s in ['/latest',
+ '/2009-04-04',
+ '/2008-09-01',
+ '/2008-02-01',
+ '/2007-12-15',
+ '/2007-10-10',
+ '/2007-08-29',
+ '/2007-03-01',
+ '/2007-01-19',
+ '/1.0']:
+ mapper.connect('%s/{path_info:.*}' % s, controller=mrh,
+ conditions=ec2domain)
super(API, self).__init__(mapper)
+
+ @webob.dec.wsgify
+ def osapi_versions(self, req):
+ """Respond to a request for all OpenStack API versions."""
+ response = {
+ "versions": [
+ dict(status="CURRENT", id="v1.0")]}
+ metadata = {
+ "application/xml": {
+ "attributes": dict(version=["status", "id"])}}
+ return wsgi.Serializer(req.environ, metadata).to_content_type(response)
+
+ @webob.dec.wsgify
+ def ec2api_versions(self, req):
+ """Respond to a request for all EC2 versions."""
+ # available api versions
+ versions = [
+ '1.0',
+ '2007-01-19',
+ '2007-03-01',
+ '2007-08-29',
+ '2007-10-10',
+ '2007-12-15',
+ '2008-02-01',
+ '2008-09-01',
+ '2009-04-04',
+ ]
+ return ''.join('%s\n' % v for v in versions)
+
+
diff --git a/nova/api/cloud.py b/nova/api/cloud.py
new file mode 100644
index 000000000..57e94a17a
--- /dev/null
+++ b/nova/api/cloud.py
@@ -0,0 +1,42 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Methods for API calls to control instances via AMQP.
+"""
+
+
+from nova import db
+from nova import flags
+from nova import rpc
+
+FLAGS = flags.FLAGS
+
+
+def reboot(instance_id, context=None):
+ """Reboot the given instance.
+
+ #TODO(gundlach) not actually sure what context is used for by ec2 here
+ -- I think we can just remove it and use None all the time.
+ """
+ instance_ref = db.instance_get_by_internal_id(None, instance_id)
+ host = instance_ref['host']
+ rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+ {"method": "reboot_instance",
+ "args": {"context": None,
+ "instance_id": instance_ref['id']}})
diff --git a/nova/api/cloudpipe/__init__.py b/nova/api/cloudpipe/__init__.py
new file mode 100644
index 000000000..6d40990a8
--- /dev/null
+++ b/nova/api/cloudpipe/__init__.py
@@ -0,0 +1,69 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+REST API Request Handlers for CloudPipe
+"""
+
+import logging
+import urllib
+import webob
+import webob.dec
+import webob.exc
+
+from nova import crypto
+from nova import wsgi
+from nova.auth import manager
+from nova.api.ec2 import cloud
+
+
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+class API(wsgi.Application):
+
+ def __init__(self):
+ self.controller = cloud.CloudController()
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ if req.method == 'POST':
+ return self.sign_csr(req)
+ _log.debug("Cloudpipe path is %s" % req.path_info)
+ if req.path_info.endswith("/getca/"):
+ return self.send_root_ca(req)
+ return webob.exc.HTTPNotFound()
+
+ def get_project_id_from_ip(self, ip):
+ # TODO(eday): This was removed with the ORM branch, fix!
+ instance = self.controller.get_instance_by_ip(ip)
+ return instance['project_id']
+
+ def send_root_ca(self, req):
+ _log.debug("Getting root ca")
+ project_id = self.get_project_id_from_ip(req.remote_addr)
+ res = webob.Response()
+ res.headers["Content-Type"] = "text/plain"
+ res.body = crypto.fetch_ca(project_id)
+ return res
+
+ def sign_csr(self, req):
+ project_id = self.get_project_id_from_ip(req.remote_addr)
+ cert = self.str_params['cert']
+ return crypto.sign_csr(urllib.unquote(cert), project_id)
diff --git a/nova/api/context.py b/nova/api/context.py
new file mode 100644
index 000000000..b66cfe468
--- /dev/null
+++ b/nova/api/context.py
@@ -0,0 +1,46 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+APIRequestContext
+"""
+
+import random
+
+
+class APIRequestContext(object):
+ def __init__(self, user, project):
+ self.user = user
+ self.project = project
+ self.request_id = ''.join(
+ [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
+ for x in xrange(20)]
+ )
+ if user:
+ self.is_admin = user.is_admin()
+ else:
+ self.is_admin = False
+ self.read_deleted = False
+
+
+def get_admin_context(user=None, read_deleted=False):
+ context_ref = APIRequestContext(user=user, project=None)
+ context_ref.is_admin = True
+ context_ref.read_deleted = read_deleted
+ return context_ref
+
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 6eec0abf7..6e771f064 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -1,6 +1,7 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
-# Copyright 2010 OpenStack LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -15,28 +16,227 @@
# License for the specific language governing permissions and limitations
# under the License.
-"""
-WSGI middleware for EC2 API controllers.
-"""
+"""Starting point for routing EC2 requests"""
+import logging
import routes
+import webob
import webob.dec
+import webob.exc
+from nova import exception
+from nova import flags
from nova import wsgi
+from nova.api import context
+from nova.api.ec2 import apirequest
+from nova.api.ec2 import admin
+from nova.api.ec2 import cloud
+from nova.auth import manager
-class API(wsgi.Router):
- """Routes EC2 requests to the appropriate controller."""
+FLAGS = flags.FLAGS
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+class API(wsgi.Middleware):
+
+ """Routing for all EC2 API requests."""
def __init__(self):
- mapper = routes.Mapper()
- mapper.connect(None, "{all:.*}", controller=self.dummy)
- super(API, self).__init__(mapper)
+ self.application = Authenticate(Router(Authorizer(Executor())))
+
+
+class Authenticate(wsgi.Middleware):
+
+ """Authenticate an EC2 request and add 'ec2.context' to WSGI environ."""
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ # Read request signature and access id.
+ try:
+ signature = req.params['Signature']
+ access = req.params['AWSAccessKeyId']
+ except:
+ raise webob.exc.HTTPBadRequest()
+
+ # Make a copy of args for authentication and signature verification.
+ auth_params = dict(req.params)
+ auth_params.pop('Signature') # not part of authentication args
+
+ # Authenticate the request.
+ try:
+ (user, project) = manager.AuthManager().authenticate(
+ access,
+ signature,
+ auth_params,
+ req.method,
+ req.host,
+ req.path)
+ except exception.Error, ex:
+ logging.debug("Authentication Failure: %s" % ex)
+ raise webob.exc.HTTPForbidden()
+
+ # Authenticated!
+ req.environ['ec2.context'] = context.APIRequestContext(user, project)
+ return self.application
+
+
+class Router(wsgi.Middleware):
+
+ """Add ec2.'controller', .'action', and .'action_args' to WSGI environ."""
+
+ def __init__(self, application):
+ super(Router, self).__init__(application)
+ self.map = routes.Mapper()
+ self.map.connect("/{controller_name}/")
+ self.controllers = dict(Cloud=cloud.CloudController(),
+ Admin=admin.AdminController())
- @staticmethod
@webob.dec.wsgify
- def dummy(req):
- """Temporary dummy controller."""
- msg = "dummy response -- please hook up __init__() to cloud.py instead"
- return repr({'dummy': msg,
- 'kwargs': repr(req.environ['wsgiorg.routing_args'][1])})
+ def __call__(self, req):
+ # Obtain the appropriate controller and action for this request.
+ try:
+ match = self.map.match(req.path_info)
+ controller_name = match['controller_name']
+ controller = self.controllers[controller_name]
+ except:
+ raise webob.exc.HTTPNotFound()
+ non_args = ['Action', 'Signature', 'AWSAccessKeyId', 'SignatureMethod',
+ 'SignatureVersion', 'Version', 'Timestamp']
+ args = dict(req.params)
+ try:
+ action = req.params['Action'] # raise KeyError if omitted
+ for non_arg in non_args:
+ args.pop(non_arg) # remove, but raise KeyError if omitted
+ except:
+ raise webob.exc.HTTPBadRequest()
+
+ _log.debug('action: %s' % action)
+ for key, value in args.items():
+ _log.debug('arg: %s\t\tval: %s' % (key, value))
+
+ # Success!
+ req.environ['ec2.controller'] = controller
+ req.environ['ec2.action'] = action
+ req.environ['ec2.action_args'] = args
+ return self.application
+
+
+class Authorizer(wsgi.Middleware):
+
+ """Authorize an EC2 API request.
+
+ Return a 401 if ec2.controller and ec2.action in WSGI environ may not be
+ executed in ec2.context.
+ """
+
+ def __init__(self, application):
+ super(Authorizer, self).__init__(application)
+ self.action_roles = {
+ 'CloudController': {
+ 'DescribeAvailabilityzones': ['all'],
+ 'DescribeRegions': ['all'],
+ 'DescribeSnapshots': ['all'],
+ 'DescribeKeyPairs': ['all'],
+ 'CreateKeyPair': ['all'],
+ 'DeleteKeyPair': ['all'],
+ 'DescribeSecurityGroups': ['all'],
+ 'AuthorizeSecurityGroupIngress': ['netadmin'],
+ 'RevokeSecurityGroupIngress': ['netadmin'],
+ 'CreateSecurityGroup': ['netadmin'],
+ 'DeleteSecurityGroup': ['netadmin'],
+ 'GetConsoleOutput': ['projectmanager', 'sysadmin'],
+ 'DescribeVolumes': ['projectmanager', 'sysadmin'],
+ 'CreateVolume': ['projectmanager', 'sysadmin'],
+ 'AttachVolume': ['projectmanager', 'sysadmin'],
+ 'DetachVolume': ['projectmanager', 'sysadmin'],
+ 'DescribeInstances': ['all'],
+ 'DescribeAddresses': ['all'],
+ 'AllocateAddress': ['netadmin'],
+ 'ReleaseAddress': ['netadmin'],
+ 'AssociateAddress': ['netadmin'],
+ 'DisassociateAddress': ['netadmin'],
+ 'RunInstances': ['projectmanager', 'sysadmin'],
+ 'TerminateInstances': ['projectmanager', 'sysadmin'],
+ 'RebootInstances': ['projectmanager', 'sysadmin'],
+ 'UpdateInstance': ['projectmanager', 'sysadmin'],
+ 'DeleteVolume': ['projectmanager', 'sysadmin'],
+ 'DescribeImages': ['all'],
+ 'DeregisterImage': ['projectmanager', 'sysadmin'],
+ 'RegisterImage': ['projectmanager', 'sysadmin'],
+ 'DescribeImageAttribute': ['all'],
+ 'ModifyImageAttribute': ['projectmanager', 'sysadmin'],
+ 'UpdateImage': ['projectmanager', 'sysadmin'],
+ },
+ 'AdminController': {
+ # All actions have the same permission: ['none'] (the default)
+ # superusers will be allowed to run them
+ # all others will get HTTPUnauthorized.
+ },
+ }
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ context = req.environ['ec2.context']
+ controller_name = req.environ['ec2.controller'].__class__.__name__
+ action = req.environ['ec2.action']
+ allowed_roles = self.action_roles[controller_name].get(action, ['none'])
+ if self._matches_any_role(context, allowed_roles):
+ return self.application
+ else:
+ raise webob.exc.HTTPUnauthorized()
+
+ def _matches_any_role(self, context, roles):
+ """Return True if any role in roles is allowed in context."""
+ if context.user.is_superuser():
+ return True
+ if 'all' in roles:
+ return True
+ if 'none' in roles:
+ return False
+ return any(context.project.has_role(context.user.id, role)
+ for role in roles)
+
+
+class Executor(wsgi.Application):
+
+ """Execute an EC2 API request.
+
+ Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
+ 'ec2.action_args' (all variables in WSGI environ.) Returns an XML
+ response, or a 400 upon failure.
+ """
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ context = req.environ['ec2.context']
+ controller = req.environ['ec2.controller']
+ action = req.environ['ec2.action']
+ args = req.environ['ec2.action_args']
+
+ api_request = apirequest.APIRequest(controller, action)
+ try:
+ result = api_request.send(context, **args)
+ req.headers['Content-Type'] = 'text/xml'
+ return result
+ except exception.ApiError as ex:
+
+ if ex.code:
+ return self._error(req, ex.code, ex.message)
+ else:
+ return self._error(req, type(ex).__name__, ex.message)
+ # TODO(vish): do something more useful with unknown exceptions
+ except Exception as ex:
+ return self._error(req, type(ex).__name__, str(ex))
+
+ def _error(self, req, code, message):
+ resp = webob.Response()
+ resp.status = 400
+ resp.headers['Content-Type'] = 'text/xml'
+ resp.body = ('<?xml version="1.0"?>\n'
+ '<Response><Errors><Error><Code>%s</Code>'
+ '<Message>%s</Message></Error></Errors>'
+ '<RequestID>?</RequestID></Response>') % (code, message)
+ return resp
+
diff --git a/nova/endpoint/admin.py b/nova/api/ec2/admin.py
index d6f622755..36feae451 100644
--- a/nova/endpoint/admin.py
+++ b/nova/api/ec2/admin.py
@@ -22,8 +22,9 @@ Admin API controller, exposed through http via the api worker.
import base64
+from nova import db
+from nova import exception
from nova.auth import manager
-from nova.compute import model
def user_dict(user, base64_file=None):
@@ -57,46 +58,27 @@ def host_dict(host):
return {}
-def admin_only(target):
- """Decorator for admin-only API calls"""
- def wrapper(*args, **kwargs):
- """Internal wrapper method for admin-only API calls"""
- context = args[1]
- if context.user.is_admin():
- return target(*args, **kwargs)
- else:
- return {}
-
- return wrapper
-
-
class AdminController(object):
"""
API Controller for users, hosts, nodes, and workers.
- Trivial admin_only wrapper will be replaced with RBAC,
- allowing project managers to administer project users.
"""
def __str__(self):
return 'AdminController'
- @admin_only
def describe_user(self, _context, name, **_kwargs):
"""Returns user data, including access and secret keys."""
return user_dict(manager.AuthManager().get_user(name))
- @admin_only
def describe_users(self, _context, **_kwargs):
"""Returns all users - should be changed to deal with a list."""
return {'userSet':
[user_dict(u) for u in manager.AuthManager().get_users()] }
- @admin_only
def register_user(self, _context, name, **_kwargs):
"""Creates a new user, and returns generated credentials."""
return user_dict(manager.AuthManager().create_user(name))
- @admin_only
def deregister_user(self, _context, name, **_kwargs):
"""Deletes a single user (NOT undoable.)
Should throw an exception if the user has instances,
@@ -106,13 +88,11 @@ class AdminController(object):
return True
- @admin_only
def describe_roles(self, context, project_roles=True, **kwargs):
"""Returns a list of allowed roles."""
roles = manager.AuthManager().get_roles(project_roles)
return { 'roles': [{'role': r} for r in roles]}
- @admin_only
def describe_user_roles(self, context, user, project=None, **kwargs):
"""Returns a list of roles for the given user.
Omitting project will return any global roles that the user has.
@@ -121,7 +101,6 @@ class AdminController(object):
roles = manager.AuthManager().get_user_roles(user, project=project)
return { 'roles': [{'role': r} for r in roles]}
- @admin_only
def modify_user_role(self, context, user, role, project=None,
operation='add', **kwargs):
"""Add or remove a role for a user and project."""
@@ -134,7 +113,6 @@ class AdminController(object):
return True
- @admin_only
def generate_x509_for_user(self, _context, name, project=None, **kwargs):
"""Generates and returns an x509 certificate for a single user.
Is usually called from a client that will wrap this with
@@ -146,19 +124,16 @@ class AdminController(object):
user = manager.AuthManager().get_user(name)
return user_dict(user, base64.b64encode(project.get_credentials(user)))
- @admin_only
def describe_project(self, context, name, **kwargs):
"""Returns project data, including member ids."""
return project_dict(manager.AuthManager().get_project(name))
- @admin_only
def describe_projects(self, context, user=None, **kwargs):
"""Returns all projects - should be changed to deal with a list."""
return {'projectSet':
[project_dict(u) for u in
manager.AuthManager().get_projects(user=user)]}
- @admin_only
def register_project(self, context, name, manager_user, description=None,
member_users=None, **kwargs):
"""Creates a new project"""
@@ -169,20 +144,17 @@ class AdminController(object):
description=None,
member_users=None))
- @admin_only
def deregister_project(self, context, name):
"""Permanently deletes a project."""
manager.AuthManager().delete_project(name)
return True
- @admin_only
def describe_project_members(self, context, name, **kwargs):
project = manager.AuthManager().get_project(name)
result = {
'members': [{'member': m} for m in project.member_ids]}
return result
-
- @admin_only
+
def modify_project_member(self, context, user, project, operation, **kwargs):
"""Add or remove a user from a project."""
if operation =='add':
@@ -193,7 +165,8 @@ class AdminController(object):
raise exception.ApiError('operation must be add or remove')
return True
- @admin_only
+ # FIXME(vish): these host commands don't work yet, perhaps some of the
+ # required data can be retrieved from service objects?
def describe_hosts(self, _context, **_kwargs):
"""Returns status info for all nodes. Includes:
* Disk Space
@@ -203,9 +176,8 @@ class AdminController(object):
* DHCP servers running
* Iptables / bridges
"""
- return {'hostSet': [host_dict(h) for h in model.Host.all()]}
+ return {'hostSet': [host_dict(h) for h in db.host_get_all()]}
- @admin_only
def describe_host(self, _context, name, **_kwargs):
"""Returns status info for single node."""
- return host_dict(model.Host.lookup(name))
+ return host_dict(db.host_get(name))
diff --git a/nova/api/ec2/apirequest.py b/nova/api/ec2/apirequest.py
new file mode 100644
index 000000000..a87c21fb3
--- /dev/null
+++ b/nova/api/ec2/apirequest.py
@@ -0,0 +1,131 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+APIRequest class
+"""
+
+import logging
+import re
+# TODO(termie): replace minidom with etree
+from xml.dom import minidom
+
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
+
+
+def _camelcase_to_underscore(str):
+ return _c2u.sub(r'_\1', str).lower().strip('_')
+
+
+def _underscore_to_camelcase(str):
+ return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
+
+
+def _underscore_to_xmlcase(str):
+ res = _underscore_to_camelcase(str)
+ return res[:1].lower() + res[1:]
+
+
+class APIRequest(object):
+ def __init__(self, controller, action):
+ self.controller = controller
+ self.action = action
+
+ def send(self, context, **kwargs):
+ try:
+ method = getattr(self.controller,
+ _camelcase_to_underscore(self.action))
+ except AttributeError:
+ _error = ('Unsupported API request: controller = %s,'
+ 'action = %s') % (self.controller, self.action)
+ _log.warning(_error)
+ # TODO: Raise custom exception, trap in apiserver,
+ # and reraise as 400 error.
+ raise Exception(_error)
+
+ args = {}
+ for key, value in kwargs.items():
+ parts = key.split(".")
+ key = _camelcase_to_underscore(parts[0])
+ if len(parts) > 1:
+ d = args.get(key, {})
+ d[parts[1]] = value
+ value = d
+ args[key] = value
+
+ for key in args.keys():
+ if isinstance(args[key], dict):
+ if args[key] != {} and args[key].keys()[0].isdigit():
+ s = args[key].items()
+ s.sort()
+ args[key] = [v for k, v in s]
+
+ result = method(context, **args)
+ return self._render_response(result, context.request_id)
+
+ def _render_response(self, response_data, request_id):
+ xml = minidom.Document()
+
+ response_el = xml.createElement(self.action + 'Response')
+ response_el.setAttribute('xmlns',
+ 'http://ec2.amazonaws.com/doc/2009-11-30/')
+ request_id_el = xml.createElement('requestId')
+ request_id_el.appendChild(xml.createTextNode(request_id))
+ response_el.appendChild(request_id_el)
+ if(response_data == True):
+ self._render_dict(xml, response_el, {'return': 'true'})
+ else:
+ self._render_dict(xml, response_el, response_data)
+
+ xml.appendChild(response_el)
+
+ response = xml.toxml()
+ xml.unlink()
+ _log.debug(response)
+ return response
+
+ def _render_dict(self, xml, el, data):
+ try:
+ for key in data.keys():
+ val = data[key]
+ el.appendChild(self._render_data(xml, key, val))
+ except:
+ _log.debug(data)
+ raise
+
+ def _render_data(self, xml, el_name, data):
+ el_name = _underscore_to_xmlcase(el_name)
+ data_el = xml.createElement(el_name)
+
+ if isinstance(data, list):
+ for item in data:
+ data_el.appendChild(self._render_data(xml, 'item', item))
+ elif isinstance(data, dict):
+ self._render_dict(xml, data_el, data)
+ elif hasattr(data, '__dict__'):
+ self._render_dict(xml, data_el, data.__dict__)
+ elif isinstance(data, bool):
+ data_el.appendChild(xml.createTextNode(str(data).lower()))
+ elif data != None:
+ data_el.appendChild(xml.createTextNode(str(data)))
+
+ return data_el
diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py
new file mode 100644
index 000000000..ee45374b2
--- /dev/null
+++ b/nova/api/ec2/cloud.py
@@ -0,0 +1,1027 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Cloud Controller: Implementation of EC2 REST API calls, which are
+dispatched to other nodes via AMQP RPC. State is via distributed
+datastore.
+"""
+
+import base64
+import datetime
+import logging
+import os
+import time
+
+import IPy
+
+from nova import crypto
+from nova import db
+from nova import exception
+from nova import flags
+from nova import quota
+from nova import rpc
+from nova import utils
+from nova.compute.instance_types import INSTANCE_TYPES
+from nova.api import cloud
+from nova.api.ec2 import images
+
+
+FLAGS = flags.FLAGS
+flags.DECLARE('storage_availability_zone', 'nova.volume.manager')
+
+InvalidInputException = exception.InvalidInputException
+
+class QuotaError(exception.ApiError):
+ """Quota Exceeeded"""
+ pass
+
+
+def _gen_key(context, user_id, key_name):
+ """Generate a key
+
+ This is a module level method because it is slow and we need to defer
+ it into a process pool."""
+ # NOTE(vish): generating key pair is slow so check for legal
+ # creation before creating key_pair
+ try:
+ db.key_pair_get(context, user_id, key_name)
+ raise exception.Duplicate("The key_pair %s already exists"
+ % key_name)
+ except exception.NotFound:
+ pass
+ private_key, public_key, fingerprint = crypto.generate_key_pair()
+ key = {}
+ key['user_id'] = user_id
+ key['name'] = key_name
+ key['public_key'] = public_key
+ key['fingerprint'] = fingerprint
+ db.key_pair_create(context, key)
+ return {'private_key': private_key, 'fingerprint': fingerprint}
+
+
+def ec2_id_to_internal_id(ec2_id):
+ """Convert an ec2 ID (i-[base 36 number]) to an internal id (int)"""
+ return int(ec2_id[2:], 36)
+
+
+def internal_id_to_ec2_id(internal_id):
+ """Convert an internal ID (int) to an ec2 ID (i-[base 36 number])"""
+ digits = []
+ while internal_id != 0:
+ internal_id, remainder = divmod(internal_id, 36)
+ digits.append('0123456789abcdefghijklmnopqrstuvwxyz'[remainder])
+ return "i-%s" % ''.join(reversed(digits))
+
+
+class CloudController(object):
+ """ CloudController provides the critical dispatch between
+ inbound API calls through the endpoint and messages
+ sent to the other nodes.
+"""
+ def __init__(self):
+ self.network_manager = utils.import_object(FLAGS.network_manager)
+ self.setup()
+
+ def __str__(self):
+ return 'CloudController'
+
+ def setup(self):
+ """ Ensure the keychains and folders exist. """
+ # FIXME(ja): this should be moved to a nova-manage command,
+ # if not setup throw exceptions instead of running
+ # Create keys folder, if it doesn't exist
+ if not os.path.exists(FLAGS.keys_path):
+ os.makedirs(FLAGS.keys_path)
+ # Gen root CA, if we don't have one
+ root_ca_path = os.path.join(FLAGS.ca_path, FLAGS.ca_file)
+ if not os.path.exists(root_ca_path):
+ start = os.getcwd()
+ os.chdir(FLAGS.ca_path)
+ # TODO(vish): Do this with M2Crypto instead
+ utils.runthis("Generating root CA: %s", "sh genrootca.sh")
+ os.chdir(start)
+
+ def _get_mpi_data(self, project_id):
+ result = {}
+ for instance in db.instance_get_all_by_project(None, project_id):
+ if instance['fixed_ip']:
+ line = '%s slots=%d' % (instance['fixed_ip']['address'],
+ INSTANCE_TYPES[instance['instance_type']]['vcpus'])
+ key = str(instance['key_name'])
+ if key in result:
+ result[key].append(line)
+ else:
+ result[key] = [line]
+ return result
+
+ def _trigger_refresh_security_group(self, security_group):
+ nodes = set([instance['host'] for instance in security_group.instances
+ if instance['host'] is not None])
+ for node in nodes:
+ rpc.call('%s.%s' % (FLAGS.compute_topic, node),
+ { "method": "refresh_security_group",
+ "args": { "context": None,
+ "security_group_id": security_group.id}})
+
+ def get_metadata(self, address):
+ instance_ref = db.fixed_ip_get_instance(None, address)
+ if instance_ref is None:
+ return None
+ mpi = self._get_mpi_data(instance_ref['project_id'])
+ if instance_ref['key_name']:
+ keys = {
+ '0': {
+ '_name': instance_ref['key_name'],
+ 'openssh-key': instance_ref['key_data']
+ }
+ }
+ else:
+ keys = ''
+ hostname = instance_ref['hostname']
+ floating_ip = db.instance_get_floating_address(None,
+ instance_ref['id'])
+ data = {
+ 'user-data': base64.b64decode(instance_ref['user_data']),
+ 'meta-data': {
+ 'ami-id': instance_ref['image_id'],
+ 'ami-launch-index': instance_ref['launch_index'],
+ 'ami-manifest-path': 'FIXME',
+ 'block-device-mapping': { # TODO(vish): replace with real data
+ 'ami': 'sda1',
+ 'ephemeral0': 'sda2',
+ 'root': '/dev/sda1',
+ 'swap': 'sda3'
+ },
+ 'hostname': hostname,
+ 'instance-action': 'none',
+ 'instance-id': internal_id_to_ec2_id(instance_ref['internal_id']),
+ 'instance-type': instance_ref['instance_type'],
+ 'local-hostname': hostname,
+ 'local-ipv4': address,
+ 'kernel-id': instance_ref['kernel_id'],
+ 'placement': {
+ 'availability-zone': 'nova' # TODO(vish): real zone
+ },
+ 'public-hostname': hostname,
+ 'public-ipv4': floating_ip or '',
+ 'public-keys': keys,
+ 'ramdisk-id': instance_ref['ramdisk_id'],
+ 'reservation-id': instance_ref['reservation_id'],
+ 'security-groups': '',
+ 'mpi': mpi
+ }
+ }
+ if False: # TODO(vish): store ancestor ids
+ data['ancestor-ami-ids'] = []
+ if False: # TODO(vish): store product codes
+ data['product-codes'] = []
+ return data
+
+ def describe_availability_zones(self, context, **kwargs):
+ return {'availabilityZoneInfo': [{'zoneName': 'nova',
+ 'zoneState': 'available'}]}
+
+ def describe_regions(self, context, region_name=None, **kwargs):
+ if FLAGS.region_list:
+ regions = []
+ for region in FLAGS.region_list:
+ name, _sep, url = region.partition('=')
+ regions.append({'regionName': name,
+ 'regionEndpoint': url})
+ else:
+ regions = [{'regionName': 'nova',
+ 'regionEndpoint': FLAGS.ec2_url}]
+ if region_name:
+ regions = [r for r in regions if r['regionName'] in region_name]
+ return {'regionInfo': regions }
+
+ def describe_snapshots(self,
+ context,
+ snapshot_id=None,
+ owner=None,
+ restorable_by=None,
+ **kwargs):
+ return {'snapshotSet': [{'snapshotId': 'fixme',
+ 'volumeId': 'fixme',
+ 'status': 'fixme',
+ 'startTime': 'fixme',
+ 'progress': 'fixme',
+ 'ownerId': 'fixme',
+ 'volumeSize': 0,
+ 'description': 'fixme'}]}
+
+ def describe_key_pairs(self, context, key_name=None, **kwargs):
+ key_pairs = db.key_pair_get_all_by_user(context, context.user.id)
+ if not key_name is None:
+ key_pairs = [x for x in key_pairs if x['name'] in key_name]
+
+ result = []
+ for key_pair in key_pairs:
+ # filter out the vpn keys
+ suffix = FLAGS.vpn_key_suffix
+ if context.user.is_admin() or not key_pair['name'].endswith(suffix):
+ result.append({
+ 'keyName': key_pair['name'],
+ 'keyFingerprint': key_pair['fingerprint'],
+ })
+
+ return {'keypairsSet': result}
+
+ def create_key_pair(self, context, key_name, **kwargs):
+ data = _gen_key(None, context.user.id, key_name)
+ return {'keyName': key_name,
+ 'keyFingerprint': data['fingerprint'],
+ 'keyMaterial': data['private_key']}
+ # TODO(vish): when context is no longer an object, pass it here
+
+ def delete_key_pair(self, context, key_name, **kwargs):
+ try:
+ db.key_pair_destroy(context, context.user.id, key_name)
+ except exception.NotFound:
+ # aws returns true even if the key doesn't exist
+ pass
+ return True
+
+ def describe_security_groups(self, context, group_name=None, **kwargs):
+ self._ensure_default_security_group(context)
+ if context.user.is_admin():
+ groups = db.security_group_get_all(context)
+ else:
+ groups = db.security_group_get_by_project(context,
+ context.project.id)
+ groups = [self._format_security_group(context, g) for g in groups]
+ if not group_name is None:
+ groups = [g for g in groups if g.name in group_name]
+
+ return {'securityGroupInfo': groups }
+
+ def _format_security_group(self, context, group):
+ g = {}
+ g['groupDescription'] = group.description
+ g['groupName'] = group.name
+ g['ownerId'] = group.project_id
+ g['ipPermissions'] = []
+ for rule in group.rules:
+ r = {}
+ r['ipProtocol'] = rule.protocol
+ r['fromPort'] = rule.from_port
+ r['toPort'] = rule.to_port
+ r['groups'] = []
+ r['ipRanges'] = []
+ if rule.group_id:
+ source_group = db.security_group_get(context, rule.group_id)
+ r['groups'] += [{'groupName': source_group.name,
+ 'userId': source_group.project_id}]
+ else:
+ r['ipRanges'] += [{'cidrIp': rule.cidr}]
+ g['ipPermissions'] += [r]
+ return g
+
+
+ def _authorize_revoke_rule_args_to_dict(self, context,
+ to_port=None, from_port=None,
+ ip_protocol=None, cidr_ip=None,
+ user_id=None,
+ source_security_group_name=None,
+ source_security_group_owner_id=None):
+
+ values = {}
+
+ if source_security_group_name:
+ source_project_id = self._get_source_project_id(context,
+ source_security_group_owner_id)
+
+ source_security_group = \
+ db.security_group_get_by_name(context,
+ source_project_id,
+ source_security_group_name)
+ values['group_id'] = source_security_group['id']
+ elif cidr_ip:
+ # If this fails, it throws an exception. This is what we want.
+ IPy.IP(cidr_ip)
+ values['cidr'] = cidr_ip
+ else:
+ values['cidr'] = '0.0.0.0/0'
+
+ if ip_protocol and from_port and to_port:
+ from_port = int(from_port)
+ to_port = int(to_port)
+ ip_protocol = str(ip_protocol)
+
+ if ip_protocol.upper() not in ['TCP','UDP','ICMP']:
+ raise InvalidInputException('%s is not a valid ipProtocol' %
+ (ip_protocol,))
+ if ((min(from_port, to_port) < -1) or
+ (max(from_port, to_port) > 65535)):
+ raise InvalidInputException('Invalid port range')
+
+ values['protocol'] = ip_protocol
+ values['from_port'] = from_port
+ values['to_port'] = to_port
+ else:
+ # If cidr based filtering, protocol and ports are mandatory
+ if 'cidr' in values:
+ return None
+
+ return values
+
+
+ def _security_group_rule_exists(self, security_group, values):
+ """Indicates whether the specified rule values are already
+ defined in the given security group.
+ """
+ for rule in security_group.rules:
+ if 'group_id' in values:
+ if rule['group_id'] == values['group_id']:
+ return True
+ else:
+ is_duplicate = True
+ for key in ('cidr', 'from_port', 'to_port', 'protocol'):
+ if rule[key] != values[key]:
+ is_duplicate = False
+ break
+ if is_duplicate:
+ return True
+ return False
+
+
+ def revoke_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+ if criteria == None:
+ raise exception.ApiError("No rule for the specified parameters.")
+
+ for rule in security_group.rules:
+ match = True
+ for (k,v) in criteria.iteritems():
+ if getattr(rule, k, False) != v:
+ match = False
+ if match:
+ db.security_group_rule_destroy(context, rule['id'])
+ self._trigger_refresh_security_group(security_group)
+ return True
+ raise exception.ApiError("No rule for the specified parameters.")
+
+ # TODO(soren): This has only been tested with Boto as the client.
+ # Unfortunately, it seems Boto is using an old API
+ # for these operations, so support for newer API versions
+ # is sketchy.
+ def authorize_security_group_ingress(self, context, group_name, **kwargs):
+ self._ensure_default_security_group(context)
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+
+ values = self._authorize_revoke_rule_args_to_dict(context, **kwargs)
+ values['parent_group_id'] = security_group.id
+
+ if self._security_group_rule_exists(security_group, values):
+ raise exception.ApiError('This rule already exists in group %s' %
+ group_name)
+
+ security_group_rule = db.security_group_rule_create(context, values)
+
+ self._trigger_refresh_security_group(security_group)
+
+ return True
+
+
+ def _get_source_project_id(self, context, source_security_group_owner_id):
+ if source_security_group_owner_id:
+ # Parse user:project for source group.
+ source_parts = source_security_group_owner_id.split(':')
+
+ # If no project name specified, assume it's same as user name.
+ # Since we're looking up by project name, the user name is not
+ # used here. It's only read for EC2 API compatibility.
+ if len(source_parts) == 2:
+ source_project_id = source_parts[1]
+ else:
+ source_project_id = source_parts[0]
+ else:
+ source_project_id = context.project.id
+
+ return source_project_id
+
+
+ def create_security_group(self, context, group_name, group_description):
+ self._ensure_default_security_group(context)
+ if db.security_group_exists(context, context.project.id, group_name):
+ raise exception.ApiError('group %s already exists' % group_name)
+
+ group = {'user_id' : context.user.id,
+ 'project_id': context.project.id,
+ 'name': group_name,
+ 'description': group_description}
+ group_ref = db.security_group_create(context, group)
+
+ return {'securityGroupSet': [self._format_security_group(context,
+ group_ref)]}
+
+
+ def delete_security_group(self, context, group_name, **kwargs):
+ security_group = db.security_group_get_by_name(context,
+ context.project.id,
+ group_name)
+ db.security_group_destroy(context, security_group.id)
+ return True
+
+
+ def get_console_output(self, context, instance_id, **kwargs):
+ # instance_id is passed in as a list of instances
+ ec2_id = instance_id[0]
+ internal_id = ec2_id_to_internal_id(ec2_id)
+ instance_ref = db.instance_get_by_internal_id(context, internal_id)
+ output = rpc.call('%s.%s' % (FLAGS.compute_topic,
+ instance_ref['host']),
+ { "method" : "get_console_output",
+ "args" : { "context": None,
+ "instance_id": instance_ref['id']}})
+
+ now = datetime.datetime.utcnow()
+ return { "InstanceId" : ec2_id,
+ "Timestamp" : now,
+ "output" : base64.b64encode(output) }
+
+ def describe_volumes(self, context, **kwargs):
+ if context.user.is_admin():
+ volumes = db.volume_get_all(context)
+ else:
+ volumes = db.volume_get_all_by_project(context, context.project.id)
+
+ volumes = [self._format_volume(context, v) for v in volumes]
+
+ return {'volumeSet': volumes}
+
+ def _format_volume(self, context, volume):
+ v = {}
+ v['volumeId'] = volume['ec2_id']
+ v['status'] = volume['status']
+ v['size'] = volume['size']
+ v['availabilityZone'] = volume['availability_zone']
+ v['createTime'] = volume['created_at']
+ if context.user.is_admin():
+ v['status'] = '%s (%s, %s, %s, %s)' % (
+ volume['status'],
+ volume['user_id'],
+ volume['host'],
+ volume['instance_id'],
+ volume['mountpoint'])
+ if volume['attach_status'] == 'attached':
+ v['attachmentSet'] = [{'attachTime': volume['attach_time'],
+ 'deleteOnTermination': False,
+ 'device': volume['mountpoint'],
+ 'instanceId': volume['instance_id'],
+ 'status': 'attached',
+ 'volume_id': volume['ec2_id']}]
+ else:
+ v['attachmentSet'] = [{}]
+
+ v['display_name'] = volume['display_name']
+ v['display_description'] = volume['display_description']
+ return v
+
+ def create_volume(self, context, size, **kwargs):
+ # check quota
+ if quota.allowed_volumes(context, 1, size) < 1:
+ logging.warn("Quota exceeeded for %s, tried to create %sG volume",
+ context.project.id, size)
+ raise QuotaError("Volume quota exceeded. You cannot "
+ "create a volume of size %s" %
+ size)
+ vol = {}
+ vol['size'] = size
+ vol['user_id'] = context.user.id
+ vol['project_id'] = context.project.id
+ vol['availability_zone'] = FLAGS.storage_availability_zone
+ vol['status'] = "creating"
+ vol['attach_status'] = "detached"
+ vol['display_name'] = kwargs.get('display_name')
+ vol['display_description'] = kwargs.get('display_description')
+ volume_ref = db.volume_create(context, vol)
+
+ rpc.cast(FLAGS.scheduler_topic,
+ {"method": "create_volume",
+ "args": {"context": None,
+ "topic": FLAGS.volume_topic,
+ "volume_id": volume_ref['id']}})
+
+ return {'volumeSet': [self._format_volume(context, volume_ref)]}
+
+
+ def attach_volume(self, context, volume_id, instance_id, device, **kwargs):
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
+ # TODO(vish): abstract status checking?
+ if volume_ref['status'] != "available":
+ raise exception.ApiError("Volume status must be available")
+ if volume_ref['attach_status'] == "attached":
+ raise exception.ApiError("Volume is already attached")
+ internal_id = ec2_id_to_internal_id(instance_id)
+ instance_ref = db.instance_get_by_internal_id(context, internal_id)
+ host = instance_ref['host']
+ rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+ {"method": "attach_volume",
+ "args": {"context": None,
+ "volume_id": volume_ref['id'],
+ "instance_id": instance_ref['id'],
+ "mountpoint": device}})
+ return {'attachTime': volume_ref['attach_time'],
+ 'device': volume_ref['mountpoint'],
+ 'instanceId': instance_ref['id'],
+ 'requestId': context.request_id,
+ 'status': volume_ref['attach_status'],
+ 'volumeId': volume_ref['id']}
+
+ def detach_volume(self, context, volume_id, **kwargs):
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
+ instance_ref = db.volume_get_instance(context, volume_ref['id'])
+ if not instance_ref:
+ raise exception.ApiError("Volume isn't attached to anything!")
+ # TODO(vish): abstract status checking?
+ if volume_ref['status'] == "available":
+ raise exception.ApiError("Volume is already detached")
+ try:
+ host = instance_ref['host']
+ rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+ {"method": "detach_volume",
+ "args": {"context": None,
+ "instance_id": instance_ref['id'],
+ "volume_id": volume_ref['id']}})
+ except exception.NotFound:
+ # If the instance doesn't exist anymore,
+ # then we need to call detach blind
+ db.volume_detached(context)
+ internal_id = instance_ref['internal_id']
+ ec2_id = internal_id_to_ec2_id(internal_id)
+ return {'attachTime': volume_ref['attach_time'],
+ 'device': volume_ref['mountpoint'],
+ 'instanceId': internal_id,
+ 'requestId': context.request_id,
+ 'status': volume_ref['attach_status'],
+ 'volumeId': volume_ref['id']}
+
+ def _convert_to_set(self, lst, label):
+ if lst == None or lst == []:
+ return None
+ if not isinstance(lst, list):
+ lst = [lst]
+ return [{label: x} for x in lst]
+
+ def update_volume(self, context, volume_id, **kwargs):
+ updatable_fields = ['display_name', 'display_description']
+ changes = {}
+ for field in updatable_fields:
+ if field in kwargs:
+ changes[field] = kwargs[field]
+ if changes:
+ db.volume_update(context, volume_id, kwargs)
+ return True
+
+ def describe_instances(self, context, **kwargs):
+ return self._format_describe_instances(context)
+
+ def _format_describe_instances(self, context):
+ return { 'reservationSet': self._format_instances(context) }
+
+ def _format_run_instances(self, context, reservation_id):
+ i = self._format_instances(context, reservation_id)
+ assert len(i) == 1
+ return i[0]
+
+ def _format_instances(self, context, reservation_id=None):
+ reservations = {}
+ if reservation_id:
+ instances = db.instance_get_all_by_reservation(context,
+ reservation_id)
+ else:
+ if context.user.is_admin():
+ instances = db.instance_get_all(context)
+ else:
+ instances = db.instance_get_all_by_project(context,
+ context.project.id)
+ for instance in instances:
+ if not context.user.is_admin():
+ if instance['image_id'] == FLAGS.vpn_image_id:
+ continue
+ i = {}
+ internal_id = instance['internal_id']
+ ec2_id = internal_id_to_ec2_id(internal_id)
+ i['instanceId'] = ec2_id
+ i['imageId'] = instance['image_id']
+ i['instanceState'] = {
+ 'code': instance['state'],
+ 'name': instance['state_description']
+ }
+ fixed_addr = None
+ floating_addr = None
+ if instance['fixed_ip']:
+ fixed_addr = instance['fixed_ip']['address']
+ if instance['fixed_ip']['floating_ips']:
+ fixed = instance['fixed_ip']
+ floating_addr = fixed['floating_ips'][0]['address']
+ i['privateDnsName'] = fixed_addr
+ i['publicDnsName'] = floating_addr
+ i['dnsName'] = i['publicDnsName'] or i['privateDnsName']
+ i['keyName'] = instance['key_name']
+ if context.user.is_admin():
+ i['keyName'] = '%s (%s, %s)' % (i['keyName'],
+ instance['project_id'],
+ instance['host'])
+ i['productCodesSet'] = self._convert_to_set([], 'product_codes')
+ i['instanceType'] = instance['instance_type']
+ i['launchTime'] = instance['created_at']
+ i['amiLaunchIndex'] = instance['launch_index']
+ i['displayName'] = instance['display_name']
+ i['displayDescription'] = instance['display_description']
+ if not reservations.has_key(instance['reservation_id']):
+ r = {}
+ r['reservationId'] = instance['reservation_id']
+ r['ownerId'] = instance['project_id']
+ r['groupSet'] = self._convert_to_set([], 'groups')
+ r['instancesSet'] = []
+ reservations[instance['reservation_id']] = r
+ reservations[instance['reservation_id']]['instancesSet'].append(i)
+
+ return list(reservations.values())
+
+ def describe_addresses(self, context, **kwargs):
+ return self.format_addresses(context)
+
+ def format_addresses(self, context):
+ addresses = []
+ if context.user.is_admin():
+ iterator = db.floating_ip_get_all(context)
+ else:
+ iterator = db.floating_ip_get_all_by_project(context,
+ context.project.id)
+ for floating_ip_ref in iterator:
+ address = floating_ip_ref['address']
+ instance_id = None
+ if (floating_ip_ref['fixed_ip']
+ and floating_ip_ref['fixed_ip']['instance']):
+ internal_id = floating_ip_ref['fixed_ip']['instance']['ec2_id']
+ ec2_id = internal_id_to_ec2_id(internal_id)
+ address_rv = {'public_ip': address,
+ 'instance_id': ec2_id}
+ if context.user.is_admin():
+ details = "%s (%s)" % (address_rv['instance_id'],
+ floating_ip_ref['project_id'])
+ address_rv['instance_id'] = details
+ addresses.append(address_rv)
+ return {'addressesSet': addresses}
+
+ def allocate_address(self, context, **kwargs):
+ # check quota
+ if quota.allowed_floating_ips(context, 1) < 1:
+ logging.warn("Quota exceeeded for %s, tried to allocate address",
+ context.project.id)
+ raise QuotaError("Address quota exceeded. You cannot "
+ "allocate any more addresses")
+ network_topic = self._get_network_topic(context)
+ public_ip = rpc.call(network_topic,
+ {"method": "allocate_floating_ip",
+ "args": {"context": None,
+ "project_id": context.project.id}})
+ return {'addressSet': [{'publicIp': public_ip}]}
+
+ def release_address(self, context, public_ip, **kwargs):
+ # NOTE(vish): Should we make sure this works?
+ floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
+ network_topic = self._get_network_topic(context)
+ rpc.cast(network_topic,
+ {"method": "deallocate_floating_ip",
+ "args": {"context": None,
+ "floating_address": floating_ip_ref['address']}})
+ return {'releaseResponse': ["Address released."]}
+
+ def associate_address(self, context, ec2_id, public_ip, **kwargs):
+ internal_id = ec2_id_to_internal_id(ec2_id)
+ instance_ref = db.instance_get_by_internal_id(context, internal_id)
+ fixed_address = db.instance_get_fixed_address(context,
+ instance_ref['id'])
+ floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
+ network_topic = self._get_network_topic(context)
+ rpc.cast(network_topic,
+ {"method": "associate_floating_ip",
+ "args": {"context": None,
+ "floating_address": floating_ip_ref['address'],
+ "fixed_address": fixed_address}})
+ return {'associateResponse': ["Address associated."]}
+
+ def disassociate_address(self, context, public_ip, **kwargs):
+ floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
+ network_topic = self._get_network_topic(context)
+ rpc.cast(network_topic,
+ {"method": "disassociate_floating_ip",
+ "args": {"context": None,
+ "floating_address": floating_ip_ref['address']}})
+ return {'disassociateResponse': ["Address disassociated."]}
+
+ def _get_network_topic(self, context):
+ """Retrieves the network host for a project"""
+ network_ref = self.network_manager.get_network(context)
+ host = network_ref['host']
+ if not host:
+ host = rpc.call(FLAGS.network_topic,
+ {"method": "set_network_host",
+ "args": {"context": None,
+ "network_id": network_ref['id']}})
+ return db.queue_get_for(context, FLAGS.network_topic, host)
+
+ def _ensure_default_security_group(self, context):
+ try:
+ db.security_group_get_by_name(context,
+ context.project.id,
+ 'default')
+ except exception.NotFound:
+ values = { 'name' : 'default',
+ 'description' : 'default',
+ 'user_id' : context.user.id,
+ 'project_id' : context.project.id }
+ group = db.security_group_create(context, values)
+
+ def run_instances(self, context, **kwargs):
+ instance_type = kwargs.get('instance_type', 'm1.small')
+ if instance_type not in INSTANCE_TYPES:
+ raise exception.ApiError("Unknown instance type: %s",
+ instance_type)
+ # check quota
+ max_instances = int(kwargs.get('max_count', 1))
+ min_instances = int(kwargs.get('min_count', max_instances))
+ num_instances = quota.allowed_instances(context,
+ max_instances,
+ instance_type)
+ if num_instances < min_instances:
+ logging.warn("Quota exceeeded for %s, tried to run %s instances",
+ context.project.id, min_instances)
+ raise QuotaError("Instance quota exceeded. You can only "
+ "run %s more instances of this type." %
+ num_instances, "InstanceLimitExceeded")
+ # make sure user can access the image
+ # vpn image is private so it doesn't show up on lists
+ vpn = kwargs['image_id'] == FLAGS.vpn_image_id
+
+ if not vpn:
+ image = images.get(context, kwargs['image_id'])
+
+ # FIXME(ja): if image is vpn, this breaks
+ # get defaults from imagestore
+ image_id = image['imageId']
+ kernel_id = image.get('kernelId', FLAGS.default_kernel)
+ ramdisk_id = image.get('ramdiskId', FLAGS.default_ramdisk)
+
+ # API parameters overrides of defaults
+ kernel_id = kwargs.get('kernel_id', kernel_id)
+ ramdisk_id = kwargs.get('ramdisk_id', ramdisk_id)
+
+ if kernel_id == str(FLAGS.null_kernel):
+ kernel_id = None
+ ramdisk_id = None
+
+ # make sure we have access to kernel and ramdisk
+ if kernel_id:
+ images.get(context, kernel_id)
+ if ramdisk_id:
+ images.get(context, ramdisk_id)
+
+ logging.debug("Going to run %s instances...", num_instances)
+ launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
+ key_data = None
+ if kwargs.has_key('key_name'):
+ key_pair_ref = db.key_pair_get(context,
+ context.user.id,
+ kwargs['key_name'])
+ key_data = key_pair_ref['public_key']
+
+ security_group_arg = kwargs.get('security_group', ["default"])
+ if not type(security_group_arg) is list:
+ security_group_arg = [security_group_arg]
+
+ security_groups = []
+ self._ensure_default_security_group(context)
+ for security_group_name in security_group_arg:
+ group = db.security_group_get_by_name(context,
+ context.project.id,
+ security_group_name)
+ security_groups.append(group['id'])
+
+ reservation_id = utils.generate_uid('r')
+ base_options = {}
+ base_options['state_description'] = 'scheduling'
+ base_options['image_id'] = image_id
+ base_options['kernel_id'] = kernel_id or ''
+ base_options['ramdisk_id'] = ramdisk_id or ''
+ base_options['reservation_id'] = reservation_id
+ base_options['key_data'] = key_data
+ base_options['key_name'] = kwargs.get('key_name', None)
+ base_options['user_id'] = context.user.id
+ base_options['project_id'] = context.project.id
+ base_options['user_data'] = kwargs.get('user_data', '')
+
+ base_options['display_name'] = kwargs.get('display_name')
+ base_options['display_description'] = kwargs.get('display_description')
+
+ type_data = INSTANCE_TYPES[instance_type]
+ base_options['instance_type'] = instance_type
+ base_options['memory_mb'] = type_data['memory_mb']
+ base_options['vcpus'] = type_data['vcpus']
+ base_options['local_gb'] = type_data['local_gb']
+
+ for num in range(num_instances):
+ instance_ref = db.instance_create(context, base_options)
+ inst_id = instance_ref['id']
+
+ for security_group_id in security_groups:
+ db.instance_add_security_group(context, inst_id,
+ security_group_id)
+
+ inst = {}
+ inst['mac_address'] = utils.generate_mac()
+ inst['launch_index'] = num
+ internal_id = instance_ref['internal_id']
+ ec2_id = internal_id_to_ec2_id(internal_id)
+ inst['hostname'] = ec2_id
+ db.instance_update(context, inst_id, inst)
+ # TODO(vish): This probably should be done in the scheduler
+ # or in compute as a call. The network should be
+ # allocated after the host is assigned and setup
+ # can happen at the same time.
+ address = self.network_manager.allocate_fixed_ip(context,
+ inst_id,
+ vpn)
+ network_topic = self._get_network_topic(context)
+ rpc.call(network_topic,
+ {"method": "setup_fixed_ip",
+ "args": {"context": None,
+ "address": address}})
+
+ rpc.cast(FLAGS.scheduler_topic,
+ {"method": "run_instance",
+ "args": {"context": None,
+ "topic": FLAGS.compute_topic,
+ "instance_id": inst_id}})
+ logging.debug("Casting to scheduler for %s/%s's instance %s" %
+ (context.project.name, context.user.name, inst_id))
+ return self._format_run_instances(context, reservation_id)
+
+
+ def terminate_instances(self, context, instance_id, **kwargs):
+ """Terminate each instance in instance_id, which is a list of ec2 ids.
+
+ instance_id is a kwarg so its name cannot be modified.
+ """
+ ec2_id_list = instance_id
+ logging.debug("Going to start terminating instances")
+ for id_str in ec2_id_list:
+ internal_id = ec2_id_to_internal_id(id_str)
+ logging.debug("Going to try and terminate %s" % id_str)
+ try:
+ instance_ref = db.instance_get_by_internal_id(context,
+ internal_id)
+ except exception.NotFound:
+ logging.warning("Instance %s was not found during terminate",
+ id_str)
+ continue
+
+ if (instance_ref['state_description'] == 'terminating'):
+ logging.warning("Instance %s is already being terminated",
+ id_str)
+ continue
+ now = datetime.datetime.utcnow()
+ db.instance_update(context,
+ instance_ref['id'],
+ {'state_description': 'terminating',
+ 'state': 0,
+ 'terminated_at': now})
+ # FIXME(ja): where should network deallocate occur?
+ address = db.instance_get_floating_address(context,
+ instance_ref['id'])
+ if address:
+ logging.debug("Disassociating address %s" % address)
+ # NOTE(vish): Right now we don't really care if the ip is
+ # disassociated. We may need to worry about
+ # checking this later. Perhaps in the scheduler?
+ network_topic = self._get_network_topic(context)
+ rpc.cast(network_topic,
+ {"method": "disassociate_floating_ip",
+ "args": {"context": None,
+ "floating_address": address}})
+
+ address = db.instance_get_fixed_address(context,
+ instance_ref['id'])
+ if address:
+ logging.debug("Deallocating address %s" % address)
+ # NOTE(vish): Currently, nothing needs to be done on the
+ # network node until release. If this changes,
+ # we will need to cast here.
+ self.network_manager.deallocate_fixed_ip(context, address)
+
+ host = instance_ref['host']
+ if host:
+ rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+ {"method": "terminate_instance",
+ "args": {"context": None,
+ "instance_id": instance_ref['id']}})
+ else:
+ db.instance_destroy(context, instance_ref['id'])
+ return True
+
+ def reboot_instances(self, context, instance_id, **kwargs):
+ """instance_id is a list of instance ids"""
+ for id_str in instance_id:
+ cloud.reboot(id_str, context=context)
+ return True
+
+ def update_instance(self, context, ec2_id, **kwargs):
+ updatable_fields = ['display_name', 'display_description']
+ changes = {}
+ for field in updatable_fields:
+ if field in kwargs:
+ changes[field] = kwargs[field]
+ if changes:
+ db_context = {}
+ internal_id = ec2_id_to_internal_id(ec2_id)
+ inst = db.instance_get_by_internal_id(db_context, internal_id)
+ db.instance_update(db_context, inst['id'], kwargs)
+ return True
+
+ def delete_volume(self, context, volume_id, **kwargs):
+ # TODO: return error if not authorized
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
+ if volume_ref['status'] != "available":
+ raise exception.ApiError("Volume status must be available")
+ now = datetime.datetime.utcnow()
+ db.volume_update(context, volume_ref['id'], {'status': 'deleting',
+ 'terminated_at': now})
+ host = volume_ref['host']
+ rpc.cast(db.queue_get_for(context, FLAGS.volume_topic, host),
+ {"method": "delete_volume",
+ "args": {"context": None,
+ "volume_id": volume_ref['id']}})
+ return True
+
+ def describe_images(self, context, image_id=None, **kwargs):
+ # The objectstore does its own authorization for describe
+ imageSet = images.list(context, image_id)
+ return {'imagesSet': imageSet}
+
+ def deregister_image(self, context, image_id, **kwargs):
+ # FIXME: should the objectstore be doing these authorization checks?
+ images.deregister(context, image_id)
+ return {'imageId': image_id}
+
+ def register_image(self, context, image_location=None, **kwargs):
+ # FIXME: should the objectstore be doing these authorization checks?
+ if image_location is None and kwargs.has_key('name'):
+ image_location = kwargs['name']
+ image_id = images.register(context, image_location)
+ logging.debug("Registered %s as %s" % (image_location, image_id))
+ return {'imageId': image_id}
+
+ def describe_image_attribute(self, context, image_id, attribute, **kwargs):
+ if attribute != 'launchPermission':
+ raise exception.ApiError('attribute not supported: %s' % attribute)
+ try:
+ image = images.list(context, image_id)[0]
+ except IndexError:
+ raise exception.ApiError('invalid id: %s' % image_id)
+ result = {'image_id': image_id, 'launchPermission': []}
+ if image['isPublic']:
+ result['launchPermission'].append({'group': 'all'})
+ return result
+
+ def modify_image_attribute(self, context, image_id, attribute, operation_type, **kwargs):
+ # TODO(devcamcar): Support users and groups other than 'all'.
+ if attribute != 'launchPermission':
+ raise exception.ApiError('attribute not supported: %s' % attribute)
+ if not 'user_group' in kwargs:
+ raise exception.ApiError('user or group not specified')
+ if len(kwargs['user_group']) != 1 and kwargs['user_group'][0] != 'all':
+ raise exception.ApiError('only group "all" is supported')
+ if not operation_type in ['add', 'remove']:
+ raise exception.ApiError('operation_type must be add or remove')
+ return images.modify(context, image_id, operation_type)
+
+ def update_image(self, context, image_id, **kwargs):
+ result = images.update(context, image_id, dict(kwargs))
+ return result
diff --git a/nova/endpoint/images.py b/nova/api/ec2/images.py
index 2a88d66af..f0a43dad6 100644
--- a/nova/endpoint/images.py
+++ b/nova/api/ec2/images.py
@@ -18,7 +18,7 @@
"""
Proxy AMI-related calls from the cloud controller, to the running
-objectstore daemon.
+objectstore service.
"""
import json
@@ -26,6 +26,7 @@ import urllib
import boto.s3.connection
+from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
@@ -42,6 +43,14 @@ def modify(context, image_id, operation):
return True
+def update(context, image_id, attributes):
+ """update an image's attributes / info.json"""
+ attributes.update({"image_id": image_id})
+ conn(context).make_request(
+ method='POST',
+ bucket='_images',
+ query_args=qs(attributes))
+ return True
def register(context, image_location):
""" rpc call to register a new image based from a manifest """
@@ -55,12 +64,14 @@ def register(context, image_location):
return image_id
-
def list(context, filter_list=[]):
""" return a list of all images that a user can see
optionally filtered by a list of image_id """
+ if FLAGS.connection_type == 'fake':
+ return [{ 'imageId' : 'bar'}]
+
# FIXME: send along the list of only_images to check for
response = conn(context).make_request(
method='GET',
@@ -71,6 +82,14 @@ def list(context, filter_list=[]):
return [i for i in result if i['imageId'] in filter_list]
return result
+def get(context, image_id):
+ """return a image object if the context has permissions"""
+ result = list(context, [image_id])
+ if not result:
+ raise exception.NotFound('Image %s could not be found' % image_id)
+ image = result[0]
+ return image
+
def deregister(context, image_id):
""" unregister an image """
diff --git a/nova/api/ec2/metadatarequesthandler.py b/nova/api/ec2/metadatarequesthandler.py
new file mode 100644
index 000000000..08a8040ca
--- /dev/null
+++ b/nova/api/ec2/metadatarequesthandler.py
@@ -0,0 +1,73 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Metadata request handler."""
+
+import logging
+
+import webob.dec
+import webob.exc
+
+from nova.api.ec2 import cloud
+
+
+class MetadataRequestHandler(object):
+
+ """Serve metadata from the EC2 API."""
+
+ def print_data(self, data):
+ if isinstance(data, dict):
+ output = ''
+ for key in data:
+ if key == '_name':
+ continue
+ output += key
+ if isinstance(data[key], dict):
+ if '_name' in data[key]:
+ output += '=' + str(data[key]['_name'])
+ else:
+ output += '/'
+ output += '\n'
+ return output[:-1] # cut off last \n
+ elif isinstance(data, list):
+ return '\n'.join(data)
+ else:
+ return str(data)
+
+ def lookup(self, path, data):
+ items = path.split('/')
+ for item in items:
+ if item:
+ if not isinstance(data, dict):
+ return data
+ if not item in data:
+ return None
+ data = data[item]
+ return data
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ cc = cloud.CloudController()
+ meta_data = cc.get_metadata(req.remote_addr)
+ if meta_data is None:
+ logging.error('Failed to get metadata for ip: %s' % req.remote_addr)
+ raise webob.exc.HTTPNotFound()
+ data = self.lookup(req.path_info, meta_data)
+ if data is None:
+ raise webob.exc.HTTPNotFound()
+ return self.print_data(data)
diff --git a/nova/api/openstack/__init__.py b/nova/api/openstack/__init__.py
new file mode 100644
index 000000000..5e81ba2bd
--- /dev/null
+++ b/nova/api/openstack/__init__.py
@@ -0,0 +1,190 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+WSGI middleware for OpenStack API controllers.
+"""
+
+import json
+import time
+
+import routes
+import webob.dec
+import webob.exc
+import webob
+
+from nova import flags
+from nova import utils
+from nova import wsgi
+from nova.api.openstack import faults
+from nova.api.openstack import backup_schedules
+from nova.api.openstack import flavors
+from nova.api.openstack import images
+from nova.api.openstack import ratelimiting
+from nova.api.openstack import servers
+from nova.api.openstack import sharedipgroups
+from nova.auth import manager
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('nova_api_auth',
+ 'nova.api.openstack.auth.BasicApiAuthManager',
+ 'The auth mechanism to use for the OpenStack API implemenation')
+
+class API(wsgi.Middleware):
+ """WSGI entry point for all OpenStack API requests."""
+
+ def __init__(self):
+ app = AuthMiddleware(RateLimitingMiddleware(APIRouter()))
+ super(API, self).__init__(app)
+
+class AuthMiddleware(wsgi.Middleware):
+ """Authorize the openstack API request or return an HTTP Forbidden."""
+
+ def __init__(self, application):
+ self.auth_driver = utils.import_class(FLAGS.nova_api_auth)()
+ super(AuthMiddleware, self).__init__(application)
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ if not req.headers.has_key("X-Auth-Token"):
+ return self.auth_driver.authenticate(req)
+
+ user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"])
+
+ if not user:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ if not req.environ.has_key('nova.context'):
+ req.environ['nova.context'] = {}
+ req.environ['nova.context']['user'] = user
+ return self.application
+
+class RateLimitingMiddleware(wsgi.Middleware):
+ """Rate limit incoming requests according to the OpenStack rate limits."""
+
+ def __init__(self, application, service_host=None):
+ """Create a rate limiting middleware that wraps the given application.
+
+ By default, rate counters are stored in memory. If service_host is
+ specified, the middleware instead relies on the ratelimiting.WSGIApp
+ at the given host+port to keep rate counters.
+ """
+ super(RateLimitingMiddleware, self).__init__(application)
+ if not service_host:
+ #TODO(gundlach): These limits were based on limitations of Cloud
+ #Servers. We should revisit them in Nova.
+ self.limiter = ratelimiting.Limiter(limits={
+ 'DELETE': (100, ratelimiting.PER_MINUTE),
+ 'PUT': (10, ratelimiting.PER_MINUTE),
+ 'POST': (10, ratelimiting.PER_MINUTE),
+ 'POST servers': (50, ratelimiting.PER_DAY),
+ 'GET changes-since': (3, ratelimiting.PER_MINUTE),
+ })
+ else:
+ self.limiter = ratelimiting.WSGIAppProxy(service_host)
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ """Rate limit the request.
+
+ If the request should be rate limited, return a 413 status with a
+ Retry-After header giving the time when the request would succeed.
+ """
+ username = req.headers['X-Auth-User']
+ action_name = self.get_action_name(req)
+ if not action_name: # not rate limited
+ return self.application
+ delay = self.get_delay(action_name, username)
+ if delay:
+ # TODO(gundlach): Get the retry-after format correct.
+ exc = webob.exc.HTTPRequestEntityTooLarge(
+ explanation='Too many requests.',
+ headers={'Retry-After': time.time() + delay})
+ raise faults.Fault(exc)
+ return self.application
+
+ def get_delay(self, action_name, username):
+ """Return the delay for the given action and username, or None if
+ the action would not be rate limited.
+ """
+ if action_name == 'POST servers':
+ # "POST servers" is a POST, so it counts against "POST" too.
+ # Attempt the "POST" first, lest we are rate limited by "POST" but
+ # use up a precious "POST servers" call.
+ delay = self.limiter.perform("POST", username=username)
+ if delay:
+ return delay
+ return self.limiter.perform(action_name, username=username)
+
+ def get_action_name(self, req):
+ """Return the action name for this request."""
+ if req.method == 'GET' and 'changes-since' in req.GET:
+ return 'GET changes-since'
+ if req.method == 'POST' and req.path_info.startswith('/servers'):
+ return 'POST servers'
+ if req.method in ['PUT', 'POST', 'DELETE']:
+ return req.method
+ return None
+
+
+class APIRouter(wsgi.Router):
+ """
+ Routes requests on the OpenStack API to the appropriate controller
+ and method.
+ """
+
+ def __init__(self):
+ mapper = routes.Mapper()
+ mapper.resource("server", "servers", controller=servers.Controller(),
+ collection={ 'detail': 'GET'},
+ member={'action':'POST'})
+
+ mapper.resource("backup_schedule", "backup_schedules",
+ controller=backup_schedules.Controller(),
+ parent_resource=dict(member_name='server',
+ collection_name = 'servers'))
+
+ mapper.resource("image", "images", controller=images.Controller(),
+ collection={'detail': 'GET'})
+ mapper.resource("flavor", "flavors", controller=flavors.Controller(),
+ collection={'detail': 'GET'})
+ mapper.resource("sharedipgroup", "sharedipgroups",
+ controller=sharedipgroups.Controller())
+
+ super(APIRouter, self).__init__(mapper)
+
+
+def limited(items, req):
+ """Return a slice of items according to requested offset and limit.
+
+ items - a sliceable
+ req - wobob.Request possibly containing offset and limit GET variables.
+ offset is where to start in the list, and limit is the maximum number
+ of items to return.
+
+ If limit is not specified, 0, or > 1000, defaults to 1000.
+ """
+ offset = int(req.GET.get('offset', 0))
+ limit = int(req.GET.get('limit', 0))
+ if not limit:
+ limit = 1000
+ limit = min(1000, limit)
+ range_end = offset + limit
+ return items[offset:range_end]
+
diff --git a/nova/api/openstack/auth.py b/nova/api/openstack/auth.py
new file mode 100644
index 000000000..7aba55728
--- /dev/null
+++ b/nova/api/openstack/auth.py
@@ -0,0 +1,101 @@
+import datetime
+import hashlib
+import json
+import time
+
+import webob.exc
+import webob.dec
+
+from nova import auth
+from nova import db
+from nova import flags
+from nova import manager
+from nova import utils
+from nova.api.openstack import faults
+
+FLAGS = flags.FLAGS
+
+class Context(object):
+ pass
+
+class BasicApiAuthManager(object):
+ """ Implements a somewhat rudimentary version of OpenStack Auth"""
+
+ def __init__(self, host=None, db_driver=None):
+ if not host:
+ host = FLAGS.host
+ self.host = host
+ if not db_driver:
+ db_driver = FLAGS.db_driver
+ self.db = utils.import_object(db_driver)
+ self.auth = auth.manager.AuthManager()
+ self.context = Context()
+ super(BasicApiAuthManager, self).__init__()
+
+ def authenticate(self, req):
+ # Unless the request is explicitly made against /<version>/ don't
+ # honor it
+ path_info = req.path_info
+ if len(path_info) > 1:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ try:
+ username = req.headers['X-Auth-User']
+ key = req.headers['X-Auth-Key']
+ except KeyError:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ token, user = self._authorize_user(username, key)
+ if user and token:
+ res = webob.Response()
+ res.headers['X-Auth-Token'] = token.token_hash
+ res.headers['X-Server-Management-Url'] = \
+ token.server_management_url
+ res.headers['X-Storage-Url'] = token.storage_url
+ res.headers['X-CDN-Management-Url'] = token.cdn_management_url
+ res.content_type = 'text/plain'
+ res.status = '204'
+ return res
+ else:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ def authorize_token(self, token_hash):
+ """ retrieves user information from the datastore given a token
+
+ If the token has expired, returns None
+ If the token is not found, returns None
+ Otherwise returns dict(id=(the authorized user's id))
+
+ This method will also remove the token if the timestamp is older than
+ 2 days ago.
+ """
+ token = self.db.auth_get_token(self.context, token_hash)
+ if token:
+ delta = datetime.datetime.now() - token.created_at
+ if delta.days >= 2:
+ self.db.auth_destroy_token(self.context, token)
+ else:
+ #TODO(gundlach): Why not just return dict(id=token.user_id)?
+ user = self.auth.get_user(token.user_id)
+ return {'id': user.id}
+ return None
+
+ def _authorize_user(self, username, key):
+ """ Generates a new token and assigns it to a user """
+ user = self.auth.get_user_from_access_key(key)
+ if user and user.name == username:
+ token_hash = hashlib.sha1('%s%s%f' % (username, key,
+ time.time())).hexdigest()
+ token_dict = {}
+ token_dict['token_hash'] = token_hash
+ token_dict['cdn_management_url'] = ''
+ token_dict['server_management_url'] = self._get_server_mgmt_url()
+ token_dict['storage_url'] = ''
+ token_dict['user_id'] = user.id
+ token = self.db.auth_create_token(self.context, token_dict)
+ return token, user
+ return None, None
+
+ def _get_server_mgmt_url(self):
+ return 'https://%s/v1.0/' % self.host
+
diff --git a/nova/api/openstack/backup_schedules.py b/nova/api/openstack/backup_schedules.py
new file mode 100644
index 000000000..76ad6ef87
--- /dev/null
+++ b/nova/api/openstack/backup_schedules.py
@@ -0,0 +1,38 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import time
+from webob import exc
+
+from nova import wsgi
+from nova.api.openstack import faults
+import nova.image.service
+
+class Controller(wsgi.Controller):
+ def __init__(self):
+ pass
+
+ def index(self, req, server_id):
+ return faults.Fault(exc.HTTPNotFound())
+
+ def create(self, req, server_id):
+ """ No actual update method required, since the existing API allows
+ both create and update through a POST """
+ return faults.Fault(exc.HTTPNotFound())
+
+ def delete(self, req, server_id):
+ return faults.Fault(exc.HTTPNotFound())
diff --git a/nova/api/rackspace/base.py b/nova/api/openstack/base.py
index dd2c6543c..dd2c6543c 100644
--- a/nova/api/rackspace/base.py
+++ b/nova/api/openstack/base.py
diff --git a/nova/api/openstack/context.py b/nova/api/openstack/context.py
new file mode 100644
index 000000000..77394615b
--- /dev/null
+++ b/nova/api/openstack/context.py
@@ -0,0 +1,33 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+APIRequestContext
+"""
+
+import random
+
+class Project(object):
+ def __init__(self, user_id):
+ self.id = user_id
+
+class APIRequestContext(object):
+ """ This is an adapter class to get around all of the assumptions made in
+ the FlatNetworking """
+ def __init__(self, user_id):
+ self.user_id = user_id
+ self.project = Project(user_id)
diff --git a/nova/api/openstack/faults.py b/nova/api/openstack/faults.py
new file mode 100644
index 000000000..32e5c866f
--- /dev/null
+++ b/nova/api/openstack/faults.py
@@ -0,0 +1,62 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import webob.dec
+import webob.exc
+
+from nova import wsgi
+
+
+class Fault(webob.exc.HTTPException):
+
+ """An RS API fault response."""
+
+ _fault_names = {
+ 400: "badRequest",
+ 401: "unauthorized",
+ 403: "resizeNotAllowed",
+ 404: "itemNotFound",
+ 405: "badMethod",
+ 409: "inProgress",
+ 413: "overLimit",
+ 415: "badMediaType",
+ 501: "notImplemented",
+ 503: "serviceUnavailable"}
+
+ def __init__(self, exception):
+ """Create a Fault for the given webob.exc.exception."""
+ self.wrapped_exc = exception
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ """Generate a WSGI response based on the exception passed to ctor."""
+ # Replace the body with fault details.
+ code = self.wrapped_exc.status_int
+ fault_name = self._fault_names.get(code, "cloudServersFault")
+ fault_data = {
+ fault_name: {
+ 'code': code,
+ 'message': self.wrapped_exc.explanation}}
+ if code == 413:
+ retry = self.wrapped_exc.headers['Retry-After']
+ fault_data[fault_name]['retryAfter'] = retry
+ # 'code' is an attribute on the fault tag itself
+ metadata = {'application/xml': {'attributes': {fault_name: 'code'}}}
+ serializer = wsgi.Serializer(req.environ, metadata)
+ self.wrapped_exc.body = serializer.to_content_type(fault_data)
+ return self.wrapped_exc
diff --git a/nova/api/openstack/flavors.py b/nova/api/openstack/flavors.py
new file mode 100644
index 000000000..793984a5d
--- /dev/null
+++ b/nova/api/openstack/flavors.py
@@ -0,0 +1,58 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from webob import exc
+
+from nova.api.openstack import faults
+from nova.compute import instance_types
+from nova import wsgi
+import nova.api.openstack
+
+class Controller(wsgi.Controller):
+ """Flavor controller for the OpenStack API."""
+
+ _serialization_metadata = {
+ 'application/xml': {
+ "attributes": {
+ "flavor": [ "id", "name", "ram", "disk" ]
+ }
+ }
+ }
+
+ def index(self, req):
+ """Return all flavors in brief."""
+ return dict(flavors=[dict(id=flavor['id'], name=flavor['name'])
+ for flavor in self.detail(req)['flavors']])
+
+ def detail(self, req):
+ """Return all flavors in detail."""
+ items = [self.show(req, id)['flavor'] for id in self._all_ids()]
+ items = nova.api.openstack.limited(items, req)
+ return dict(flavors=items)
+
+ def show(self, req, id):
+ """Return data about the given flavor id."""
+ for name, val in instance_types.INSTANCE_TYPES.iteritems():
+ if val['flavorid'] == int(id):
+ item = dict(ram=val['memory_mb'], disk=val['local_gb'],
+ id=val['flavorid'], name=name)
+ return dict(flavor=item)
+ raise faults.Fault(exc.HTTPNotFound())
+
+ def _all_ids(self):
+ """Return the list of all flavorids."""
+ return [i['flavorid'] for i in instance_types.INSTANCE_TYPES.values()]
diff --git a/nova/api/openstack/images.py b/nova/api/openstack/images.py
new file mode 100644
index 000000000..aa438739c
--- /dev/null
+++ b/nova/api/openstack/images.py
@@ -0,0 +1,71 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from webob import exc
+
+from nova import flags
+from nova import utils
+from nova import wsgi
+import nova.api.openstack
+import nova.image.service
+from nova.api.openstack import faults
+
+
+FLAGS = flags.FLAGS
+
+class Controller(wsgi.Controller):
+
+ _serialization_metadata = {
+ 'application/xml': {
+ "attributes": {
+ "image": [ "id", "name", "updated", "created", "status",
+ "serverId", "progress" ]
+ }
+ }
+ }
+
+ def __init__(self):
+ self._service = utils.import_object(FLAGS.image_service)
+
+ def index(self, req):
+ """Return all public images in brief."""
+ return dict(images=[dict(id=img['id'], name=img['name'])
+ for img in self.detail(req)['images']])
+
+ def detail(self, req):
+ """Return all public images in detail."""
+ data = self._service.index()
+ data = nova.api.openstack.limited(data, req)
+ return dict(images=data)
+
+ def show(self, req, id):
+ """Return data about the given image id."""
+ return dict(image=self._service.show(id))
+
+ def delete(self, req, id):
+ # Only public images are supported for now.
+ raise faults.Fault(exc.HTTPNotFound())
+
+ def create(self, req):
+ # Only public images are supported for now, so a request to
+ # make a backup of a server cannot be supproted.
+ raise faults.Fault(exc.HTTPNotFound())
+
+ def update(self, req, id):
+ # Users may not modify public images, and that's all that
+ # we support for now.
+ raise faults.Fault(exc.HTTPNotFound())
diff --git a/nova/api/openstack/notes.txt b/nova/api/openstack/notes.txt
new file mode 100644
index 000000000..2330f1002
--- /dev/null
+++ b/nova/api/openstack/notes.txt
@@ -0,0 +1,23 @@
+We will need:
+
+ImageService
+a service that can do crud on image information. not user-specific. opaque
+image ids.
+
+GlanceImageService(ImageService):
+image ids are URIs.
+
+LocalImageService(ImageService):
+image ids are random strings.
+
+OpenstackAPITranslationStore:
+translates RS server/images/flavor/etc ids into formats required
+by a given ImageService strategy.
+
+api.openstack.images.Controller:
+uses an ImageService strategy behind the scenes to do its fetching; it just
+converts int image id into a strategy-specific image id.
+
+who maintains the mapping from user to [images he owns]? nobody, because
+we have no way of enforcing access to his images, without kryptex which
+won't be in Austin.
diff --git a/nova/api/openstack/ratelimiting/__init__.py b/nova/api/openstack/ratelimiting/__init__.py
new file mode 100644
index 000000000..f843bac0f
--- /dev/null
+++ b/nova/api/openstack/ratelimiting/__init__.py
@@ -0,0 +1,122 @@
+"""Rate limiting of arbitrary actions."""
+
+import httplib
+import time
+import urllib
+import webob.dec
+import webob.exc
+
+
+# Convenience constants for the limits dictionary passed to Limiter().
+PER_SECOND = 1
+PER_MINUTE = 60
+PER_HOUR = 60 * 60
+PER_DAY = 60 * 60 * 24
+
+class Limiter(object):
+
+ """Class providing rate limiting of arbitrary actions."""
+
+ def __init__(self, limits):
+ """Create a rate limiter.
+
+ limits: a dict mapping from action name to a tuple. The tuple contains
+ the number of times the action may be performed, and the time period
+ (in seconds) during which the number must not be exceeded for this
+ action. Example: dict(reboot=(10, ratelimiting.PER_MINUTE)) would
+ allow 10 'reboot' actions per minute.
+ """
+ self.limits = limits
+ self._levels = {}
+
+ def perform(self, action_name, username='nobody'):
+ """Attempt to perform an action by the given username.
+
+ action_name: the string name of the action to perform. This must
+ be a key in the limits dict passed to the ctor.
+
+ username: an optional string name of the user performing the action.
+ Each user has her own set of rate limiting counters. Defaults to
+ 'nobody' (so that if you never specify a username when calling
+ perform(), a single set of counters will be used.)
+
+ Return None if the action may proceed. If the action may not proceed
+ because it has been rate limited, return the float number of seconds
+ until the action would succeed.
+ """
+ # Think of rate limiting as a bucket leaking water at 1cc/second. The
+ # bucket can hold as many ccs as there are seconds in the rate
+ # limiting period (e.g. 3600 for per-hour ratelimits), and if you can
+ # perform N actions in that time, each action fills the bucket by
+ # 1/Nth of its volume. You may only perform an action if the bucket
+ # would not overflow.
+ now = time.time()
+ key = '%s:%s' % (username, action_name)
+ last_time_performed, water_level = self._levels.get(key, (now, 0))
+ # The bucket leaks 1cc/second.
+ water_level -= (now - last_time_performed)
+ if water_level < 0:
+ water_level = 0
+ num_allowed_per_period, period_in_secs = self.limits[action_name]
+ # Fill the bucket by 1/Nth its capacity, and hope it doesn't overflow.
+ capacity = period_in_secs
+ new_level = water_level + (capacity * 1.0 / num_allowed_per_period)
+ if new_level > capacity:
+ # Delay this many seconds.
+ return new_level - capacity
+ self._levels[key] = (now, new_level)
+ return None
+
+
+# If one instance of this WSGIApps is unable to handle your load, put a
+# sharding app in front that shards by username to one of many backends.
+
+class WSGIApp(object):
+
+ """Application that tracks rate limits in memory. Send requests to it of
+ this form:
+
+ POST /limiter/<username>/<urlencoded action>
+
+ and receive a 200 OK, or a 403 Forbidden with an X-Wait-Seconds header
+ containing the number of seconds to wait before the action would succeed.
+ """
+
+ def __init__(self, limiter):
+ """Create the WSGI application using the given Limiter instance."""
+ self.limiter = limiter
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ parts = req.path_info.split('/')
+ # format: /limiter/<username>/<urlencoded action>
+ if req.method != 'POST':
+ raise webob.exc.HTTPMethodNotAllowed()
+ if len(parts) != 4 or parts[1] != 'limiter':
+ raise webob.exc.HTTPNotFound()
+ username = parts[2]
+ action_name = urllib.unquote(parts[3])
+ delay = self.limiter.perform(action_name, username)
+ if delay:
+ return webob.exc.HTTPForbidden(
+ headers={'X-Wait-Seconds': "%.2f" % delay})
+ else:
+ return '' # 200 OK
+
+
+class WSGIAppProxy(object):
+
+ """Limiter lookalike that proxies to a ratelimiting.WSGIApp."""
+
+ def __init__(self, service_host):
+ """Creates a proxy pointing to a ratelimiting.WSGIApp at the given
+ host."""
+ self.service_host = service_host
+
+ def perform(self, action, username='nobody'):
+ conn = httplib.HTTPConnection(self.service_host)
+ conn.request('POST', '/limiter/%s/%s' % (username, action))
+ resp = conn.getresponse()
+ if resp.status == 200:
+ return None # no delay
+ return float(resp.getheader('X-Wait-Seconds'))
diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py
new file mode 100644
index 000000000..cb5132635
--- /dev/null
+++ b/nova/api/openstack/servers.py
@@ -0,0 +1,273 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import time
+
+import webob
+from webob import exc
+
+from nova import flags
+from nova import rpc
+from nova import utils
+from nova import wsgi
+from nova.api import cloud
+from nova.api.openstack import context
+from nova.api.openstack import faults
+from nova.compute import instance_types
+from nova.compute import power_state
+import nova.api.openstack
+import nova.image.service
+
+FLAGS = flags.FLAGS
+
+def _filter_params(inst_dict):
+ """ Extracts all updatable parameters for a server update request """
+ keys = dict(name='name', admin_pass='adminPass')
+ new_attrs = {}
+ for k, v in keys.items():
+ if inst_dict.has_key(v):
+ new_attrs[k] = inst_dict[v]
+ return new_attrs
+
+def _entity_list(entities):
+ """ Coerces a list of servers into proper dictionary format """
+ return dict(servers=entities)
+
+def _entity_detail(inst):
+ """ Maps everything to Rackspace-like attributes for return"""
+ power_mapping = {
+ power_state.NOSTATE: 'build',
+ power_state.RUNNING: 'active',
+ power_state.BLOCKED: 'active',
+ power_state.PAUSED: 'suspended',
+ power_state.SHUTDOWN: 'active',
+ power_state.SHUTOFF: 'active',
+ power_state.CRASHED: 'error'
+ }
+ inst_dict = {}
+
+ mapped_keys = dict(status='state', imageId='image_id',
+ flavorId='instance_type', name='server_name', id='id')
+
+ for k, v in mapped_keys.iteritems():
+ inst_dict[k] = inst[v]
+
+ inst_dict['status'] = power_mapping[inst_dict['status']]
+ inst_dict['addresses'] = dict(public=[], private=[])
+ inst_dict['metadata'] = {}
+ inst_dict['hostId'] = ''
+
+ return dict(server=inst_dict)
+
+def _entity_inst(inst):
+ """ Filters all model attributes save for id and name """
+ return dict(server=dict(id=inst['id'], name=inst['server_name']))
+
+class Controller(wsgi.Controller):
+ """ The Server API controller for the OpenStack API """
+
+ _serialization_metadata = {
+ 'application/xml': {
+ "attributes": {
+ "server": [ "id", "imageId", "name", "flavorId", "hostId",
+ "status", "progress", "progress" ]
+ }
+ }
+ }
+
+ def __init__(self, db_driver=None):
+ if not db_driver:
+ db_driver = FLAGS.db_driver
+ self.db_driver = utils.import_object(db_driver)
+ super(Controller, self).__init__()
+
+ def index(self, req):
+ """ Returns a list of server names and ids for a given user """
+ return self._items(req, entity_maker=_entity_inst)
+
+ def detail(self, req):
+ """ Returns a list of server details for a given user """
+ return self._items(req, entity_maker=_entity_detail)
+
+ def _items(self, req, entity_maker):
+ """Returns a list of servers for a given user.
+
+ entity_maker - either _entity_detail or _entity_inst
+ """
+ user_id = req.environ['nova.context']['user']['id']
+ instance_list = self.db_driver.instance_get_all_by_user(None, user_id)
+ limited_list = nova.api.openstack.limited(instance_list, req)
+ res = [entity_maker(inst)['server'] for inst in limited_list]
+ return _entity_list(res)
+
+ def show(self, req, id):
+ """ Returns server details by server id """
+ user_id = req.environ['nova.context']['user']['id']
+ inst = self.db_driver.instance_get_by_internal_id(None, int(id))
+ if inst:
+ if inst.user_id == user_id:
+ return _entity_detail(inst)
+ raise faults.Fault(exc.HTTPNotFound())
+
+ def delete(self, req, id):
+ """ Destroys a server """
+ user_id = req.environ['nova.context']['user']['id']
+ instance = self.db_driver.instance_get_by_internal_id(None, int(id))
+ if instance and instance['user_id'] == user_id:
+ self.db_driver.instance_destroy(None, id)
+ return faults.Fault(exc.HTTPAccepted())
+ return faults.Fault(exc.HTTPNotFound())
+
+ def create(self, req):
+ """ Creates a new server for a given user """
+
+ env = self._deserialize(req.body, req)
+ if not env:
+ return faults.Fault(exc.HTTPUnprocessableEntity())
+
+ #try:
+ inst = self._build_server_instance(req, env)
+ #except Exception, e:
+ # return faults.Fault(exc.HTTPUnprocessableEntity())
+
+ rpc.cast(
+ FLAGS.compute_topic, {
+ "method": "run_instance",
+ "args": {"instance_id": inst['id']}})
+ return _entity_inst(inst)
+
+ def update(self, req, id):
+ """ Updates the server name or password """
+ user_id = req.environ['nova.context']['user']['id']
+
+ inst_dict = self._deserialize(req.body, req)
+
+ if not inst_dict:
+ return faults.Fault(exc.HTTPUnprocessableEntity())
+
+ instance = self.db_driver.instance_get_by_internal_id(None, int(id))
+ if not instance or instance.user_id != user_id:
+ return faults.Fault(exc.HTTPNotFound())
+
+ self.db_driver.instance_update(None, int(id),
+ _filter_params(inst_dict['server']))
+ return faults.Fault(exc.HTTPNoContent())
+
+ def action(self, req, id):
+ """ multi-purpose method used to reboot, rebuild, and
+ resize a server """
+ user_id = req.environ['nova.context']['user']['id']
+ input_dict = self._deserialize(req.body, req)
+ try:
+ reboot_type = input_dict['reboot']['type']
+ except Exception:
+ raise faults.Fault(webob.exc.HTTPNotImplemented())
+ inst_ref = self.db.instance_get_by_internal_id(None, int(id))
+ if not inst_ref or (inst_ref and not inst_ref.user_id == user_id):
+ return faults.Fault(exc.HTTPUnprocessableEntity())
+ cloud.reboot(id)
+
+ def _build_server_instance(self, req, env):
+ """Build instance data structure and save it to the data store."""
+ ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
+ inst = {}
+
+ user_id = req.environ['nova.context']['user']['id']
+
+ flavor_id = env['server']['flavorId']
+
+ instance_type, flavor = [(k, v) for k, v in
+ instance_types.INSTANCE_TYPES.iteritems()
+ if v['flavorid'] == flavor_id][0]
+
+ image_id = env['server']['imageId']
+ img_service = utils.import_object(FLAGS.image_service)
+
+ image = img_service.show(image_id)
+
+ if not image:
+ raise Exception, "Image not found"
+
+ inst['server_name'] = env['server']['name']
+ inst['image_id'] = image_id
+ inst['user_id'] = user_id
+ inst['launch_time'] = ltime
+ inst['mac_address'] = utils.generate_mac()
+ inst['project_id'] = user_id
+
+ inst['state_description'] = 'scheduling'
+ inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel)
+ inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk)
+ inst['reservation_id'] = utils.generate_uid('r')
+
+ inst['display_name'] = env['server']['name']
+ inst['display_description'] = env['server']['name']
+
+ #TODO(dietz) this may be ill advised
+ key_pair_ref = self.db_driver.key_pair_get_all_by_user(
+ None, user_id)[0]
+
+ inst['key_data'] = key_pair_ref['public_key']
+ inst['key_name'] = key_pair_ref['name']
+
+ #TODO(dietz) stolen from ec2 api, see TODO there
+ inst['security_group'] = 'default'
+
+ # Flavor related attributes
+ inst['instance_type'] = instance_type
+ inst['memory_mb'] = flavor['memory_mb']
+ inst['vcpus'] = flavor['vcpus']
+ inst['local_gb'] = flavor['local_gb']
+
+ ref = self.db_driver.instance_create(None, inst)
+ inst['id'] = ref.internal_id
+ # TODO(dietz): this isn't explicitly necessary, but the networking
+ # calls depend on an object with a project_id property, and therefore
+ # should be cleaned up later
+ api_context = context.APIRequestContext(user_id)
+
+ inst['mac_address'] = utils.generate_mac()
+
+ #TODO(dietz) is this necessary?
+ inst['launch_index'] = 0
+
+ inst['hostname'] = str(ref.internal_id)
+ self.db_driver.instance_update(api_context, inst['id'], inst)
+
+ network_manager = utils.import_object(FLAGS.network_manager)
+ address = network_manager.allocate_fixed_ip(api_context,
+ inst['id'])
+
+ # TODO(vish): This probably should be done in the scheduler
+ # network is setup when host is assigned
+ network_topic = self._get_network_topic(api_context, network_manager)
+ rpc.call(network_topic,
+ {"method": "setup_fixed_ip",
+ "args": {"context": api_context,
+ "address": address}})
+ return inst
+
+ def _get_network_topic(self, context, network_manager):
+ """Retrieves the network host for a project"""
+ network_ref = network_manager.get_network(context)
+ host = network_ref['host']
+ if not host:
+ host = rpc.call(FLAGS.network_topic,
+ {"method": "set_network_host",
+ "args": {"context": context,
+ "network_id": network_ref['id']}})
+ return self.db_driver.queue_get_for(None, FLAGS.network_topic, host)
diff --git a/nova/api/rackspace/sharedipgroups.py b/nova/api/openstack/sharedipgroups.py
index 986f11434..4d2d0ede1 100644
--- a/nova/api/rackspace/sharedipgroups.py
+++ b/nova/api/openstack/sharedipgroups.py
@@ -15,4 +15,6 @@
# License for the specific language governing permissions and limitations
# under the License.
-class Controller(object): pass
+from nova import wsgi
+
+class Controller(wsgi.Controller): pass
diff --git a/nova/api/rackspace/__init__.py b/nova/api/rackspace/__init__.py
deleted file mode 100644
index 27e78f801..000000000
--- a/nova/api/rackspace/__init__.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-WSGI middleware for Rackspace API controllers.
-"""
-
-import json
-import time
-
-import routes
-import webob.dec
-import webob.exc
-
-from nova import flags
-from nova import wsgi
-from nova.api.rackspace import flavors
-from nova.api.rackspace import images
-from nova.api.rackspace import servers
-from nova.api.rackspace import sharedipgroups
-from nova.auth import manager
-
-
-class API(wsgi.Middleware):
- """WSGI entry point for all Rackspace API requests."""
-
- def __init__(self):
- app = AuthMiddleware(APIRouter())
- super(API, self).__init__(app)
-
-
-class AuthMiddleware(wsgi.Middleware):
- """Authorize the rackspace API request or return an HTTP Forbidden."""
-
- #TODO(gundlach): isn't this the old Nova API's auth? Should it be replaced
- #with correct RS API auth?
-
- @webob.dec.wsgify
- def __call__(self, req):
- context = {}
- if "HTTP_X_AUTH_TOKEN" in req.environ:
- context['user'] = manager.AuthManager().get_user_from_access_key(
- req.environ['HTTP_X_AUTH_TOKEN'])
- if context['user']:
- context['project'] = manager.AuthManager().get_project(
- context['user'].name)
- if "user" not in context:
- return webob.exc.HTTPForbidden()
- req.environ['nova.context'] = context
- return self.application
-
-
-class APIRouter(wsgi.Router):
- """
- Routes requests on the Rackspace API to the appropriate controller
- and method.
- """
-
- def __init__(self):
- mapper = routes.Mapper()
- mapper.resource("server", "servers", controller=servers.Controller())
- mapper.resource("image", "images", controller=images.Controller())
- mapper.resource("flavor", "flavors", controller=flavors.Controller())
- mapper.resource("sharedipgroup", "sharedipgroups",
- controller=sharedipgroups.Controller())
- super(APIRouter, self).__init__(mapper)
diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py
deleted file mode 100644
index 25d1fe9c8..000000000
--- a/nova/api/rackspace/servers.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 OpenStack LLC.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-from nova import rpc
-from nova.compute import model as compute
-from nova.api.rackspace import base
-
-
-class Controller(base.Controller):
- entity_name = 'servers'
-
- def index(self, **kwargs):
- instances = []
- for inst in compute.InstanceDirectory().all:
- instances.append(instance_details(inst))
-
- def show(self, **kwargs):
- instance_id = kwargs['id']
- return compute.InstanceDirectory().get(instance_id)
-
- def delete(self, **kwargs):
- instance_id = kwargs['id']
- instance = compute.InstanceDirectory().get(instance_id)
- if not instance:
- raise ServerNotFound("The requested server was not found")
- instance.destroy()
- return True
-
- def create(self, **kwargs):
- inst = self.build_server_instance(kwargs['server'])
- rpc.cast(
- FLAGS.compute_topic, {
- "method": "run_instance",
- "args": {"instance_id": inst.instance_id}})
-
- def update(self, **kwargs):
- instance_id = kwargs['id']
- instance = compute.InstanceDirectory().get(instance_id)
- if not instance:
- raise ServerNotFound("The requested server was not found")
- instance.update(kwargs['server'])
- instance.save()
-
- def build_server_instance(self, env):
- """Build instance data structure and save it to the data store."""
- reservation = utils.generate_uid('r')
- ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
- inst = self.instdir.new()
- inst['name'] = env['server']['name']
- inst['image_id'] = env['server']['imageId']
- inst['instance_type'] = env['server']['flavorId']
- inst['user_id'] = env['user']['id']
- inst['project_id'] = env['project']['id']
- inst['reservation_id'] = reservation
- inst['launch_time'] = ltime
- inst['mac_address'] = utils.generate_mac()
- address = self.network.allocate_ip(
- inst['user_id'],
- inst['project_id'],
- mac=inst['mac_address'])
- inst['private_dns_name'] = str(address)
- inst['bridge_name'] = network.BridgedNetwork.get_network_for_project(
- inst['user_id'],
- inst['project_id'],
- 'default')['bridge_name']
- # key_data, key_name, ami_launch_index
- # TODO(todd): key data or root password
- inst.save()
- return inst
diff --git a/nova/auth/dbdriver.py b/nova/auth/dbdriver.py
new file mode 100644
index 000000000..09d15018b
--- /dev/null
+++ b/nova/auth/dbdriver.py
@@ -0,0 +1,236 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Auth driver using the DB as its backend.
+"""
+
+import logging
+import sys
+
+from nova import exception
+from nova import db
+
+
+class DbDriver(object):
+ """DB Auth driver
+
+ Defines enter and exit and therefore supports the with/as syntax.
+ """
+
+ def __init__(self):
+ """Imports the LDAP module"""
+ pass
+ db
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ def get_user(self, uid):
+ """Retrieve user by id"""
+ return self._db_user_to_auth_user(db.user_get({}, uid))
+
+ def get_user_from_access_key(self, access):
+ """Retrieve user by access key"""
+ return self._db_user_to_auth_user(db.user_get_by_access_key({}, access))
+
+ def get_project(self, pid):
+ """Retrieve project by id"""
+ return self._db_project_to_auth_projectuser(db.project_get({}, pid))
+
+ def get_users(self):
+ """Retrieve list of users"""
+ return [self._db_user_to_auth_user(user) for user in db.user_get_all({})]
+
+ def get_projects(self, uid=None):
+ """Retrieve list of projects"""
+ if uid:
+ result = db.project_get_by_user({}, uid)
+ else:
+ result = db.project_get_all({})
+ return [self._db_project_to_auth_projectuser(proj) for proj in result]
+
+ def create_user(self, name, access_key, secret_key, is_admin):
+ """Create a user"""
+ values = { 'id' : name,
+ 'access_key' : access_key,
+ 'secret_key' : secret_key,
+ 'is_admin' : is_admin
+ }
+ try:
+ user_ref = db.user_create({}, values)
+ return self._db_user_to_auth_user(user_ref)
+ except exception.Duplicate, e:
+ raise exception.Duplicate('User %s already exists' % name)
+
+ def _db_user_to_auth_user(self, user_ref):
+ return { 'id' : user_ref['id'],
+ 'name' : user_ref['id'],
+ 'access' : user_ref['access_key'],
+ 'secret' : user_ref['secret_key'],
+ 'admin' : user_ref['is_admin'] }
+
+ def _db_project_to_auth_projectuser(self, project_ref):
+ return { 'id' : project_ref['id'],
+ 'name' : project_ref['name'],
+ 'project_manager_id' : project_ref['project_manager'],
+ 'description' : project_ref['description'],
+ 'member_ids' : [member['id'] for member in project_ref['members']] }
+
+ def create_project(self, name, manager_uid,
+ description=None, member_uids=None):
+ """Create a project"""
+ manager = db.user_get({}, manager_uid)
+ if not manager:
+ raise exception.NotFound("Project can't be created because "
+ "manager %s doesn't exist" % manager_uid)
+
+ # description is a required attribute
+ if description is None:
+ description = name
+
+ # First, we ensure that all the given users exist before we go
+ # on to create the project. This way we won't have to destroy
+ # the project again because a user turns out to be invalid.
+ members = set([manager])
+ if member_uids != None:
+ for member_uid in member_uids:
+ member = db.user_get({}, member_uid)
+ if not member:
+ raise exception.NotFound("Project can't be created "
+ "because user %s doesn't exist"
+ % member_uid)
+ members.add(member)
+
+ values = { 'id' : name,
+ 'name' : name,
+ 'project_manager' : manager['id'],
+ 'description': description }
+
+ try:
+ project = db.project_create({}, values)
+ except exception.Duplicate:
+ raise exception.Duplicate("Project can't be created because "
+ "project %s already exists" % name)
+
+ for member in members:
+ db.project_add_member({}, project['id'], member['id'])
+
+ # This looks silly, but ensures that the members element has been
+ # correctly populated
+ project_ref = db.project_get({}, project['id'])
+ return self._db_project_to_auth_projectuser(project_ref)
+
+ def modify_project(self, project_id, manager_uid=None, description=None):
+ """Modify an existing project"""
+ if not manager_uid and not description:
+ return
+ values = {}
+ if manager_uid:
+ manager = db.user_get({}, manager_uid)
+ if not manager:
+ raise exception.NotFound("Project can't be modified because "
+ "manager %s doesn't exist" %
+ manager_uid)
+ values['project_manager'] = manager['id']
+ if description:
+ values['description'] = description
+
+ db.project_update({}, project_id, values)
+
+ def add_to_project(self, uid, project_id):
+ """Add user to project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ db.project_add_member({}, project['id'], user['id'])
+
+ def remove_from_project(self, uid, project_id):
+ """Remove user from project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ db.project_remove_member({}, project['id'], user['id'])
+
+ def is_in_project(self, uid, project_id):
+ """Check if user is in project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ return user in project.members
+
+ def has_role(self, uid, role, project_id=None):
+ """Check if user has role
+
+ If project is specified, it checks for local role, otherwise it
+ checks for global role
+ """
+
+ return role in self.get_user_roles(uid, project_id)
+
+ def add_role(self, uid, role, project_id=None):
+ """Add role for user (or user and project)"""
+ if not project_id:
+ db.user_add_role({}, uid, role)
+ return
+ db.user_add_project_role({}, uid, project_id, role)
+
+ def remove_role(self, uid, role, project_id=None):
+ """Remove role for user (or user and project)"""
+ if not project_id:
+ db.user_remove_role({}, uid, role)
+ return
+ db.user_remove_project_role({}, uid, project_id, role)
+
+ def get_user_roles(self, uid, project_id=None):
+ """Retrieve list of roles for user (or user and project)"""
+ if project_id is None:
+ roles = db.user_get_roles({}, uid)
+ return roles
+ else:
+ roles = db.user_get_roles_for_project({}, uid, project_id)
+ return roles
+
+ def delete_user(self, id):
+ """Delete a user"""
+ user = db.user_get({}, id)
+ db.user_delete({}, user['id'])
+
+ def delete_project(self, project_id):
+ """Delete a project"""
+ db.project_delete({}, project_id)
+
+ def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
+ """Modify an existing user"""
+ if not access_key and not secret_key and admin is None:
+ return
+ values = {}
+ if access_key:
+ values['access_key'] = access_key
+ if secret_key:
+ values['secret_key'] = secret_key
+ if admin is not None:
+ values['is_admin'] = admin
+ db.user_update({}, uid, values)
+
+ def _validate_user_and_project(self, user_id, project_id):
+ user = db.user_get({}, user_id)
+ if not user:
+ raise exception.NotFound('User "%s" not found' % user_id)
+ project = db.project_get({}, project_id)
+ if not project:
+ raise exception.NotFound('Project "%s" not found' % project_id)
+ return user, project
+
diff --git a/nova/auth/fakeldap.py b/nova/auth/fakeldap.py
index bc744fa01..2791dfde6 100644
--- a/nova/auth/fakeldap.py
+++ b/nova/auth/fakeldap.py
@@ -30,20 +30,24 @@ from nova import datastore
SCOPE_BASE = 0
SCOPE_ONELEVEL = 1 # not implemented
-SCOPE_SUBTREE = 2
+SCOPE_SUBTREE = 2
MOD_ADD = 0
MOD_DELETE = 1
+MOD_REPLACE = 2
-class NO_SUCH_OBJECT(Exception):
+class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
+ """Duplicate exception class from real LDAP module."""
pass
-class OBJECT_CLASS_VIOLATION(Exception):
+class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103
+ """Duplicate exception class from real LDAP module."""
pass
-def initialize(uri):
+def initialize(_uri):
+ """Opens a fake connection with an LDAP server."""
return FakeLDAP()
@@ -68,7 +72,7 @@ def _match_query(query, attrs):
# cut off the ! and the nested parentheses
return not _match_query(query[2:-1], attrs)
- (k, sep, v) = inner.partition('=')
+ (k, _sep, v) = inner.partition('=')
return _match(k, v, attrs)
@@ -85,20 +89,20 @@ def _paren_groups(source):
if source[pos] == ')':
count -= 1
if count == 0:
- result.append(source[start:pos+1])
+ result.append(source[start:pos + 1])
return result
-def _match(k, v, attrs):
+def _match(key, value, attrs):
"""Match a given key and value against an attribute list."""
- if k not in attrs:
+ if key not in attrs:
return False
- if k != "objectclass":
- return v in attrs[k]
+ if key != "objectclass":
+ return value in attrs[key]
# it is an objectclass check, so check subclasses
- values = _subs(v)
- for value in values:
- if value in attrs[k]:
+ values = _subs(value)
+ for v in values:
+ if v in attrs[key]:
return True
return False
@@ -145,6 +149,7 @@ def _to_json(unencoded):
class FakeLDAP(object):
#TODO(vish): refactor this class to use a wrapper instead of accessing
# redis directly
+ """Fake LDAP connection."""
def simple_bind_s(self, dn, password):
"""This method is ignored, but provided for compatibility."""
@@ -171,7 +176,7 @@ class FakeLDAP(object):
Args:
dn -- a dn
attrs -- a list of tuples in the following form:
- ([MOD_ADD | MOD_DELETE], attribute, value)
+ ([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
"""
redis = datastore.Redis.instance()
@@ -181,6 +186,8 @@ class FakeLDAP(object):
values = _from_json(redis.hget(key, k))
if cmd == MOD_ADD:
values.append(v)
+ elif cmd == MOD_REPLACE:
+ values = [v]
else:
values.remove(v)
values = redis.hset(key, k, _to_json(values))
@@ -207,6 +214,7 @@ class FakeLDAP(object):
# get the attributes from redis
attrs = redis.hgetall(key)
# turn the values from redis into lists
+ # pylint: disable-msg=E1103
attrs = dict([(k, _from_json(v))
for k, v in attrs.iteritems()])
# filter the objects by query
@@ -215,12 +223,12 @@ class FakeLDAP(object):
attrs = dict([(k, v) for k, v in attrs.iteritems()
if not fields or k in fields])
objects.append((key[len(self.__redis_prefix):], attrs))
+ # pylint: enable-msg=E1103
if objects == []:
raise NO_SUCH_OBJECT()
return objects
@property
- def __redis_prefix(self):
+ def __redis_prefix(self): # pylint: disable-msg=R0201
+ """Get the prefix to use for all redis keys."""
return 'ldap:'
-
-
diff --git a/nova/auth/ldapdriver.py b/nova/auth/ldapdriver.py
index 6bf7fcd1e..640ea169e 100644
--- a/nova/auth/ldapdriver.py
+++ b/nova/auth/ldapdriver.py
@@ -34,7 +34,7 @@ from nova import flags
FLAGS = flags.FLAGS
flags.DEFINE_string('ldap_url', 'ldap://localhost',
'Point this at your ldap server')
-flags.DEFINE_string('ldap_password', 'changeme', 'LDAP password')
+flags.DEFINE_string('ldap_password', 'changeme', 'LDAP password')
flags.DEFINE_string('ldap_user_dn', 'cn=Manager,dc=example,dc=com',
'DN of admin user')
flags.DEFINE_string('ldap_user_unit', 'Users', 'OID for Users')
@@ -63,14 +63,18 @@ flags.DEFINE_string('ldap_developer',
# to define a set interface for AuthDrivers. I'm delaying
# creating this now because I'm expecting an auth refactor
# in which we may want to change the interface a bit more.
+
+
class LdapDriver(object):
"""Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
+
def __init__(self):
"""Imports the LDAP module"""
self.ldap = __import__('ldap')
+ self.conn = None
def __enter__(self):
"""Creates the connection to LDAP"""
@@ -78,7 +82,7 @@ class LdapDriver(object):
self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password)
return self
- def __exit__(self, type, value, traceback):
+ def __exit__(self, exc_type, exc_value, traceback):
"""Destroys the connection to LDAP"""
self.conn.unbind_s()
return False
@@ -95,13 +99,6 @@ class LdapDriver(object):
dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query))
- def get_key_pair(self, uid, key_name):
- """Retrieve key pair by uid and key name"""
- dn = 'cn=%s,%s' % (key_name,
- self.__uid_to_dn(uid))
- attr = self.__find_object(dn, '(objectclass=novaKeyPair)')
- return self.__to_key_pair(uid, attr)
-
def get_project(self, pid):
"""Retrieve project by id"""
dn = 'cn=%s,%s' % (pid,
@@ -115,19 +112,13 @@ class LdapDriver(object):
'(objectclass=novaUser)')
return [self.__to_user(attr) for attr in attrs]
- def get_key_pairs(self, uid):
- """Retrieve list of key pairs"""
- attrs = self.__find_objects(self.__uid_to_dn(uid),
- '(objectclass=novaKeyPair)')
- return [self.__to_key_pair(uid, attr) for attr in attrs]
-
def get_projects(self, uid=None):
"""Retrieve list of projects"""
- filter = '(objectclass=novaProject)'
+ pattern = '(objectclass=novaProject)'
if uid:
- filter = "(&%s(member=%s))" % (filter, self.__uid_to_dn(uid))
+ pattern = "(&%s(member=%s))" % (pattern, self.__uid_to_dn(uid))
attrs = self.__find_objects(FLAGS.ldap_project_subtree,
- filter)
+ pattern)
return [self.__to_project(attr) for attr in attrs]
def create_user(self, name, access_key, secret_key, is_admin):
@@ -150,21 +141,6 @@ class LdapDriver(object):
self.conn.add_s(self.__uid_to_dn(name), attr)
return self.__to_user(dict(attr))
- def create_key_pair(self, uid, key_name, public_key, fingerprint):
- """Create a key pair"""
- # TODO(vish): possibly refactor this to store keys in their own ou
- # and put dn reference in the user object
- attr = [
- ('objectclass', ['novaKeyPair']),
- ('cn', [key_name]),
- ('sshPublicKey', [public_key]),
- ('keyFingerprint', [fingerprint]),
- ]
- self.conn.add_s('cn=%s,%s' % (key_name,
- self.__uid_to_dn(uid)),
- attr)
- return self.__to_key_pair(uid, dict(attr))
-
def create_project(self, name, manager_uid,
description=None, member_uids=None):
"""Create a project"""
@@ -194,11 +170,28 @@ class LdapDriver(object):
('cn', [name]),
('description', [description]),
('projectManager', [manager_dn]),
- ('member', members)
- ]
+ ('member', members)]
self.conn.add_s('cn=%s,%s' % (name, FLAGS.ldap_project_subtree), attr)
return self.__to_project(dict(attr))
+ def modify_project(self, project_id, manager_uid=None, description=None):
+ """Modify an existing project"""
+ if not manager_uid and not description:
+ return
+ attr = []
+ if manager_uid:
+ if not self.__user_exists(manager_uid):
+ raise exception.NotFound("Project can't be modified because "
+ "manager %s doesn't exist" %
+ manager_uid)
+ manager_dn = self.__uid_to_dn(manager_uid)
+ attr.append((self.ldap.MOD_REPLACE, 'projectManager', manager_dn))
+ if description:
+ attr.append((self.ldap.MOD_REPLACE, 'description', description))
+ self.conn.modify_s('cn=%s,%s' % (project_id,
+ FLAGS.ldap_project_subtree),
+ attr)
+
def add_to_project(self, uid, project_id):
"""Add user to project"""
dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
@@ -262,18 +255,8 @@ class LdapDriver(object):
"""Delete a user"""
if not self.__user_exists(uid):
raise exception.NotFound("User %s doesn't exist" % uid)
- self.__delete_key_pairs(uid)
self.__remove_from_all(uid)
- self.conn.delete_s('uid=%s,%s' % (uid,
- FLAGS.ldap_user_subtree))
-
- def delete_key_pair(self, uid, key_name):
- """Delete a key pair"""
- if not self.__key_pair_exists(uid, key_name):
- raise exception.NotFound("Key Pair %s doesn't exist for user %s" %
- (key_name, uid))
- self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid,
- FLAGS.ldap_user_subtree))
+ self.conn.delete_s(self.__uid_to_dn(uid))
def delete_project(self, project_id):
"""Delete a project"""
@@ -281,15 +264,23 @@ class LdapDriver(object):
self.__delete_roles(project_dn)
self.__delete_group(project_dn)
+ def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
+ """Modify an existing project"""
+ if not access_key and not secret_key and admin is None:
+ return
+ attr = []
+ if access_key:
+ attr.append((self.ldap.MOD_REPLACE, 'accessKey', access_key))
+ if secret_key:
+ attr.append((self.ldap.MOD_REPLACE, 'secretKey', secret_key))
+ if admin is not None:
+ attr.append((self.ldap.MOD_REPLACE, 'isAdmin', str(admin).upper()))
+ self.conn.modify_s(self.__uid_to_dn(uid), attr)
+
def __user_exists(self, uid):
"""Check if user exists"""
return self.get_user(uid) != None
- def __key_pair_exists(self, uid, key_name):
- """Check if key pair exists"""
- return self.get_user(uid) != None
- return self.get_key_pair(uid, key_name) != None
-
def __project_exists(self, project_id):
"""Check if project exists"""
return self.get_project(project_id) != None
@@ -310,7 +301,7 @@ class LdapDriver(object):
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the DNs
- return [dn for dn, attributes in res]
+ return [dn for dn, _attributes in res]
def __find_objects(self, dn, query=None, scope=None):
"""Find objects by query"""
@@ -339,14 +330,8 @@ class LdapDriver(object):
"""Check if group exists"""
return self.__find_object(dn, '(objectclass=groupOfNames)') != None
- def __delete_key_pairs(self, uid):
- """Delete all key pairs for user"""
- keys = self.get_key_pairs(uid)
- if keys != None:
- for key in keys:
- self.delete_key_pair(uid, key['name'])
-
- def __role_to_dn(self, role, project_id=None):
+ @staticmethod
+ def __role_to_dn(role, project_id=None):
"""Convert role to corresponding dn"""
if project_id == None:
return FLAGS.__getitem__("ldap_%s" % role).value
@@ -356,7 +341,7 @@ class LdapDriver(object):
FLAGS.ldap_project_subtree)
def __create_group(self, group_dn, name, uid,
- description, member_uids = None):
+ description, member_uids=None):
"""Create a group"""
if self.__group_exists(group_dn):
raise exception.Duplicate("Group can't be created because "
@@ -375,8 +360,7 @@ class LdapDriver(object):
('objectclass', ['groupOfNames']),
('cn', [name]),
('description', [description]),
- ('member', members)
- ]
+ ('member', members)]
self.conn.add_s(group_dn, attr)
def __is_in_group(self, uid, group_dn):
@@ -402,9 +386,7 @@ class LdapDriver(object):
if self.__is_in_group(uid, group_dn):
raise exception.Duplicate("User %s is already a member of "
"the group %s" % (uid, group_dn))
- attr = [
- (self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))
- ]
+ attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))]
self.conn.modify_s(group_dn, attr)
def __remove_from_group(self, uid, group_dn):
@@ -432,7 +414,7 @@ class LdapDriver(object):
self.conn.modify_s(group_dn, attr)
except self.ldap.OBJECT_CLASS_VIOLATION:
logging.debug("Attempted to remove the last member of a group. "
- "Deleting the group at %s instead." % group_dn )
+ "Deleting the group at %s instead.", group_dn)
self.__delete_group(group_dn)
def __remove_from_all(self, uid):
@@ -440,7 +422,6 @@ class LdapDriver(object):
if not self.__user_exists(uid):
raise exception.NotFound("User %s can't be removed from all "
"because the user doesn't exist" % (uid,))
- dn = self.__uid_to_dn(uid)
role_dns = self.__find_group_dns_with_member(
FLAGS.role_project_subtree, uid)
for role_dn in role_dns:
@@ -448,7 +429,7 @@ class LdapDriver(object):
project_dns = self.__find_group_dns_with_member(
FLAGS.ldap_project_subtree, uid)
for project_dn in project_dns:
- self.__safe_remove_from_group(uid, role_dn)
+ self.__safe_remove_from_group(uid, project_dn)
def __delete_group(self, group_dn):
"""Delete Group"""
@@ -461,7 +442,8 @@ class LdapDriver(object):
for role_dn in self.__find_role_dns(project_dn):
self.__delete_group(role_dn)
- def __to_user(self, attr):
+ @staticmethod
+ def __to_user(attr):
"""Convert ldap attributes to User object"""
if attr == None:
return None
@@ -470,20 +452,7 @@ class LdapDriver(object):
'name': attr['cn'][0],
'access': attr['accessKey'][0],
'secret': attr['secretKey'][0],
- 'admin': (attr['isAdmin'][0] == 'TRUE')
- }
-
- def __to_key_pair(self, owner, attr):
- """Convert ldap attributes to KeyPair object"""
- if attr == None:
- return None
- return {
- 'id': attr['cn'][0],
- 'name': attr['cn'][0],
- 'owner_id': owner,
- 'public_key': attr['sshPublicKey'][0],
- 'fingerprint': attr['keyFingerprint'][0],
- }
+ 'admin': (attr['isAdmin'][0] == 'TRUE')}
def __to_project(self, attr):
"""Convert ldap attributes to Project object"""
@@ -495,21 +464,22 @@ class LdapDriver(object):
'name': attr['cn'][0],
'project_manager_id': self.__dn_to_uid(attr['projectManager'][0]),
'description': attr.get('description', [None])[0],
- 'member_ids': [self.__dn_to_uid(x) for x in member_dns]
- }
+ 'member_ids': [self.__dn_to_uid(x) for x in member_dns]}
- def __dn_to_uid(self, dn):
+ @staticmethod
+ def __dn_to_uid(dn):
"""Convert user dn to uid"""
return dn.split(',')[0].split('=')[1]
- def __uid_to_dn(self, dn):
+ @staticmethod
+ def __uid_to_dn(dn):
"""Convert uid to dn"""
return 'uid=%s,%s' % (dn, FLAGS.ldap_user_subtree)
class FakeLdapDriver(LdapDriver):
"""Fake Ldap Auth driver"""
- def __init__(self):
+
+ def __init__(self): # pylint: disable-msg=W0231
__import__('nova.auth.fakeldap')
self.ldap = sys.modules['nova.auth.fakeldap']
-
diff --git a/nova/auth/manager.py b/nova/auth/manager.py
index 80ee78896..9c499c98d 100644
--- a/nova/auth/manager.py
+++ b/nova/auth/manager.py
@@ -23,17 +23,17 @@ Nova authentication management
import logging
import os
import shutil
-import string
+import string # pylint: disable-msg=W0402
import tempfile
import uuid
import zipfile
from nova import crypto
+from nova import db
from nova import exception
from nova import flags
from nova import utils
from nova.auth import signer
-from nova.network import vpn
FLAGS = flags.FLAGS
@@ -44,7 +44,7 @@ flags.DEFINE_list('allowed_roles',
# NOTE(vish): a user with one of these roles will be a superuser and
# have access to all api commands
flags.DEFINE_list('superuser_roles', ['cloudadmin'],
- 'Roles that ignore rbac checking completely')
+ 'Roles that ignore authorization checking completely')
# NOTE(vish): a user with one of these roles will have it for every
# project, even if he or she is not a member of the project
@@ -69,7 +69,7 @@ flags.DEFINE_string('credential_cert_subject',
'/C=US/ST=California/L=MountainView/O=AnsoLabs/'
'OU=NovaDev/CN=%s-%s',
'Subject for certificate for users')
-flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver',
+flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver',
'Driver that auth manager uses')
@@ -128,24 +128,6 @@ class User(AuthBase):
def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project)
- def generate_key_pair(self, name):
- return AuthManager().generate_key_pair(self.id, name)
-
- def create_key_pair(self, name, public_key, fingerprint):
- return AuthManager().create_key_pair(self.id,
- name,
- public_key,
- fingerprint)
-
- def get_key_pair(self, name):
- return AuthManager().get_key_pair(self.id, name)
-
- def delete_key_pair(self, name):
- return AuthManager().delete_key_pair(self.id, name)
-
- def get_key_pairs(self):
- return AuthManager().get_key_pairs(self.id)
-
def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
@@ -154,29 +136,6 @@ class User(AuthBase):
self.admin)
-class KeyPair(AuthBase):
- """Represents an ssh key returned from the datastore
-
- Even though this object is named KeyPair, only the public key and
- fingerprint is stored. The user's private key is not saved.
- """
-
- def __init__(self, id, name, owner_id, public_key, fingerprint):
- AuthBase.__init__(self)
- self.id = id
- self.name = name
- self.owner_id = owner_id
- self.public_key = public_key
- self.fingerprint = fingerprint
-
- def __repr__(self):
- return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id,
- self.name,
- self.owner_id,
- self.public_key,
- self.fingerprint)
-
-
class Project(AuthBase):
"""Represents a Project returned from the datastore"""
@@ -194,12 +153,12 @@ class Project(AuthBase):
@property
def vpn_ip(self):
- ip, port = AuthManager().get_project_vpn_data(self)
+ ip, _port = AuthManager().get_project_vpn_data(self)
return ip
@property
def vpn_port(self):
- ip, port = AuthManager().get_project_vpn_data(self)
+ _ip, port = AuthManager().get_project_vpn_data(self)
return port
def has_manager(self, user):
@@ -221,11 +180,9 @@ class Project(AuthBase):
return AuthManager().get_credentials(user, self)
def __repr__(self):
- return "Project('%s', '%s', '%s', '%s', %s)" % (self.id,
- self.name,
- self.project_manager_id,
- self.description,
- self.member_ids)
+ return "Project('%s', '%s', '%s', '%s', %s)" % \
+ (self.id, self.name, self.project_manager_id, self.description,
+ self.member_ids)
class AuthManager(object):
@@ -254,6 +211,7 @@ class AuthManager(object):
__init__ is run every time AuthManager() is called, so we only
reset the driver if it is not set or a new driver is specified.
"""
+ self.network_manager = utils.import_object(FLAGS.network_manager)
if driver or not getattr(self, 'driver', None):
self.driver = utils.import_class(driver or FLAGS.auth_driver)
@@ -297,7 +255,7 @@ class AuthManager(object):
@return: User and project that the request represents.
"""
# TODO(vish): check for valid timestamp
- (access_key, sep, project_id) = access.partition(':')
+ (access_key, _sep, project_id) = access.partition(':')
logging.info('Looking up user: %r', access_key)
user = self.get_user_from_access_key(access_key)
@@ -308,7 +266,7 @@ class AuthManager(object):
# NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user
- if project_id is '':
+ if project_id == '':
project_id = user.name
project = self.get_project(project_id)
@@ -320,7 +278,8 @@ class AuthManager(object):
raise exception.NotFound('User %s is not a member of project %s' %
(user.id, project.id))
if check_type == 's3':
- expected_signature = signer.Signer(user.secret.encode()).s3_authorization(headers, verb, path)
+ sign = signer.Signer(user.secret.encode())
+ expected_signature = sign.s3_authorization(headers, verb, path)
logging.debug('user.secret: %s', user.secret)
logging.debug('expected_signature: %s', expected_signature)
logging.debug('signature: %s', signature)
@@ -345,7 +304,7 @@ class AuthManager(object):
return "%s:%s" % (user.access, Project.safe_id(project))
def is_superuser(self, user):
- """Checks for superuser status, allowing user to bypass rbac
+ """Checks for superuser status, allowing user to bypass authorization
@type user: User or uid
@param user: User to check.
@@ -465,7 +424,8 @@ class AuthManager(object):
with self.driver() as drv:
drv.remove_role(User.safe_id(user), role, Project.safe_id(project))
- def get_roles(self, project_roles=True):
+ @staticmethod
+ def get_roles(project_roles=True):
"""Get list of allowed roles"""
if project_roles:
return list(set(FLAGS.allowed_roles) - set(FLAGS.global_roles))
@@ -493,8 +453,8 @@ class AuthManager(object):
return []
return [Project(**project_dict) for project_dict in project_list]
- def create_project(self, name, manager_user,
- description=None, member_users=None):
+ def create_project(self, name, manager_user, description=None,
+ member_users=None, context=None):
"""Create a project
@type name: str
@@ -518,12 +478,33 @@ class AuthManager(object):
if member_users:
member_users = [User.safe_id(u) for u in member_users]
with self.driver() as drv:
- project_dict = drv.create_project(name,
- User.safe_id(manager_user),
- description,
- member_users)
+ project_dict = drv.create_project(name,
+ User.safe_id(manager_user),
+ description,
+ member_users)
if project_dict:
- return Project(**project_dict)
+ project = Project(**project_dict)
+ return project
+
+ def modify_project(self, project, manager_user=None, description=None):
+ """Modify a project
+
+ @type name: Project or project_id
+ @param project: The project to modify.
+
+ @type manager_user: User or uid
+ @param manager_user: This user will be the new project manager.
+
+ @type description: str
+ @param project: This will be the new description of the project.
+
+ """
+ if manager_user:
+ manager_user = User.safe_id(manager_user)
+ with self.driver() as drv:
+ drv.modify_project(Project.safe_id(project),
+ manager_user,
+ description)
def add_to_project(self, user, project):
"""Add user to project"""
@@ -549,7 +530,8 @@ class AuthManager(object):
return drv.remove_from_project(User.safe_id(user),
Project.safe_id(project))
- def get_project_vpn_data(self, project):
+ @staticmethod
+ def get_project_vpn_data(project, context=None):
"""Gets vpn ip and port for project
@type project: Project or project_id
@@ -559,15 +541,19 @@ class AuthManager(object):
@return: A tuple containing (ip, port) or None, None if vpn has
not been allocated for user.
"""
- network_data = vpn.NetworkData.lookup(Project.safe_id(project))
- if not network_data:
+
+ network_ref = db.project_get_network(context,
+ Project.safe_id(project))
+
+ if not network_ref['vpn_public_port']:
raise exception.NotFound('project network data has not been set')
- return (network_data.ip, network_data.port)
+ return (network_ref['vpn_public_address'],
+ network_ref['vpn_public_port'])
- def delete_project(self, project):
+ def delete_project(self, project, context=None):
"""Deletes a project"""
with self.driver() as drv:
- return drv.delete_project(Project.safe_id(project))
+ drv.delete_project(Project.safe_id(project))
def get_user(self, uid):
"""Retrieves a user by id"""
@@ -613,75 +599,29 @@ class AuthManager(object):
@rtype: User
@return: The new user.
"""
- if access == None: access = str(uuid.uuid4())
- if secret == None: secret = str(uuid.uuid4())
+ if access == None:
+ access = str(uuid.uuid4())
+ if secret == None:
+ secret = str(uuid.uuid4())
with self.driver() as drv:
user_dict = drv.create_user(name, access, secret, admin)
if user_dict:
return User(**user_dict)
def delete_user(self, user):
- """Deletes a user"""
- with self.driver() as drv:
- drv.delete_user(User.safe_id(user))
-
- def generate_key_pair(self, user, key_name):
- """Generates a key pair for a user
-
- Generates a public and private key, stores the public key using the
- key_name, and returns the private key and fingerprint.
+ """Deletes a user
- @type user: User or uid
- @param user: User for which to create key pair.
-
- @type key_name: str
- @param key_name: Name to use for the generated KeyPair.
-
- @rtype: tuple (private_key, fingerprint)
- @return: A tuple containing the private_key and fingerprint.
- """
- # NOTE(vish): generating key pair is slow so check for legal
- # creation before creating keypair
+ Additionally deletes all users key_pairs"""
uid = User.safe_id(user)
+ db.key_pair_destroy_all_by_user(None, uid)
with self.driver() as drv:
- if not drv.get_user(uid):
- raise exception.NotFound("User %s doesn't exist" % user)
- if drv.get_key_pair(uid, key_name):
- raise exception.Duplicate("The keypair %s already exists"
- % key_name)
- private_key, public_key, fingerprint = crypto.generate_key_pair()
- self.create_key_pair(uid, key_name, public_key, fingerprint)
- return private_key, fingerprint
-
- def create_key_pair(self, user, key_name, public_key, fingerprint):
- """Creates a key pair for user"""
- with self.driver() as drv:
- kp_dict = drv.create_key_pair(User.safe_id(user),
- key_name,
- public_key,
- fingerprint)
- if kp_dict:
- return KeyPair(**kp_dict)
-
- def get_key_pair(self, user, key_name):
- """Retrieves a key pair for user"""
- with self.driver() as drv:
- kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
- if kp_dict:
- return KeyPair(**kp_dict)
+ drv.delete_user(uid)
- def get_key_pairs(self, user):
- """Retrieves all key pairs for user"""
- with self.driver() as drv:
- kp_list = drv.get_key_pairs(User.safe_id(user))
- if not kp_list:
- return []
- return [KeyPair(**kp_dict) for kp_dict in kp_list]
-
- def delete_key_pair(self, user, key_name):
- """Deletes a key pair for user"""
+ def modify_user(self, user, access_key=None, secret_key=None, admin=None):
+ """Modify credentials for a user"""
+ uid = User.safe_id(user)
with self.driver() as drv:
- drv.delete_key_pair(User.safe_id(user), key_name)
+ drv.modify_user(uid, access_key, secret_key, admin)
def get_credentials(self, user, project=None):
"""Get credential zip for user in project"""
@@ -700,15 +640,18 @@ class AuthManager(object):
zippy.writestr(FLAGS.credential_key_file, private_key)
zippy.writestr(FLAGS.credential_cert_file, signed_cert)
- network_data = vpn.NetworkData.lookup(pid)
- if network_data:
- configfile = open(FLAGS.vpn_client_template,"r")
+ try:
+ (vpn_ip, vpn_port) = self.get_project_vpn_data(project)
+ except exception.NotFound:
+ vpn_ip = None
+ if vpn_ip:
+ configfile = open(FLAGS.vpn_client_template, "r")
s = string.Template(configfile.read())
configfile.close()
config = s.substitute(keyfile=FLAGS.credential_key_file,
certfile=FLAGS.credential_cert_file,
- ip=network_data.ip,
- port=network_data.port)
+ ip=vpn_ip,
+ port=vpn_port)
zippy.writestr(FLAGS.credential_vpn_file, config)
else:
logging.warn("No vpn data for project %s" %
@@ -717,10 +660,10 @@ class AuthManager(object):
zippy.writestr(FLAGS.ca_file, crypto.fetch_ca(user.id))
zippy.close()
with open(zf, 'rb') as f:
- buffer = f.read()
+ read_buffer = f.read()
shutil.rmtree(tmpdir)
- return buffer
+ return read_buffer
def get_environment_rc(self, user, project=None):
"""Get credential zip for user in project"""
@@ -731,18 +674,18 @@ class AuthManager(object):
pid = Project.safe_id(project)
return self.__generate_rc(user.access, user.secret, pid)
- def __generate_rc(self, access, secret, pid):
+ @staticmethod
+ def __generate_rc(access, secret, pid):
"""Generate rc file for user"""
rc = open(FLAGS.credentials_template).read()
- rc = rc % { 'access': access,
- 'project': pid,
- 'secret': secret,
- 'ec2': FLAGS.ec2_url,
- 's3': 'http://%s:%s' % (FLAGS.s3_host, FLAGS.s3_port),
- 'nova': FLAGS.ca_file,
- 'cert': FLAGS.credential_cert_file,
- 'key': FLAGS.credential_key_file,
- }
+ rc = rc % {'access': access,
+ 'project': pid,
+ 'secret': secret,
+ 'ec2': FLAGS.ec2_url,
+ 's3': 'http://%s:%s' % (FLAGS.s3_host, FLAGS.s3_port),
+ 'nova': FLAGS.ca_file,
+ 'cert': FLAGS.credential_cert_file,
+ 'key': FLAGS.credential_key_file}
return rc
def _generate_x509_cert(self, uid, pid):
@@ -753,6 +696,7 @@ class AuthManager(object):
signed_cert = crypto.sign_csr(csr, pid)
return (private_key, signed_cert)
- def __cert_subject(self, uid):
+ @staticmethod
+ def __cert_subject(uid):
"""Helper to generate cert subject"""
return FLAGS.credential_cert_subject % (uid, utils.isotime())
diff --git a/nova/auth/rbac.py b/nova/auth/rbac.py
deleted file mode 100644
index 1446e4e27..000000000
--- a/nova/auth/rbac.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-from nova import exception
-from nova.auth import manager
-
-
-def allow(*roles):
- def wrap(f):
- def wrapped_f(self, context, *args, **kwargs):
- if context.user.is_superuser():
- return f(self, context, *args, **kwargs)
- for role in roles:
- if __matches_role(context, role):
- return f(self, context, *args, **kwargs)
- raise exception.NotAuthorized()
- return wrapped_f
- return wrap
-
-
-def deny(*roles):
- def wrap(f):
- def wrapped_f(self, context, *args, **kwargs):
- if context.user.is_superuser():
- return f(self, context, *args, **kwargs)
- for role in roles:
- if __matches_role(context, role):
- raise exception.NotAuthorized()
- return f(self, context, *args, **kwargs)
- return wrapped_f
- return wrap
-
-
-def __matches_role(context, role):
- if role == 'all':
- return True
- if role == 'none':
- return False
- return context.project.has_role(context.user.id, role)
-
diff --git a/nova/auth/signer.py b/nova/auth/signer.py
index 8334806d2..f7d29f534 100644
--- a/nova/auth/signer.py
+++ b/nova/auth/signer.py
@@ -50,15 +50,15 @@ import logging
import urllib
# NOTE(vish): for new boto
-import boto
+import boto
# NOTE(vish): for old boto
-import boto.utils
+import boto.utils
from nova.exception import Error
class Signer(object):
- """ hacked up code from boto/connection.py """
+ """Hacked up code from boto/connection.py"""
def __init__(self, secret_key):
self.hmac = hmac.new(secret_key, digestmod=hashlib.sha1)
@@ -66,22 +66,27 @@ class Signer(object):
self.hmac_256 = hmac.new(secret_key, digestmod=hashlib.sha256)
def s3_authorization(self, headers, verb, path):
+ """Generate S3 authorization string."""
c_string = boto.utils.canonical_string(verb, path, headers)
- hmac = self.hmac.copy()
- hmac.update(c_string)
- b64_hmac = base64.encodestring(hmac.digest()).strip()
+ hmac_copy = self.hmac.copy()
+ hmac_copy.update(c_string)
+ b64_hmac = base64.encodestring(hmac_copy.digest()).strip()
return b64_hmac
def generate(self, params, verb, server_string, path):
+ """Generate auth string according to what SignatureVersion is given."""
if params['SignatureVersion'] == '0':
return self._calc_signature_0(params)
if params['SignatureVersion'] == '1':
return self._calc_signature_1(params)
if params['SignatureVersion'] == '2':
return self._calc_signature_2(params, verb, server_string, path)
- raise Error('Unknown Signature Version: %s' % self.SignatureVersion)
+ raise Error('Unknown Signature Version: %s' %
+ params['SignatureVersion'])
- def _get_utf8_value(self, value):
+ @staticmethod
+ def _get_utf8_value(value):
+ """Get the UTF8-encoded version of a value."""
if not isinstance(value, str) and not isinstance(value, unicode):
value = str(value)
if isinstance(value, unicode):
@@ -90,10 +95,11 @@ class Signer(object):
return value
def _calc_signature_0(self, params):
+ """Generate AWS signature version 0 string."""
s = params['Action'] + params['Timestamp']
self.hmac.update(s)
keys = params.keys()
- keys.sort(cmp = lambda x, y: cmp(x.lower(), y.lower()))
+ keys.sort(cmp=lambda x, y: cmp(x.lower(), y.lower()))
pairs = []
for key in keys:
val = self._get_utf8_value(params[key])
@@ -101,8 +107,9 @@ class Signer(object):
return base64.b64encode(self.hmac.digest())
def _calc_signature_1(self, params):
+ """Generate AWS signature version 1 string."""
keys = params.keys()
- keys.sort(cmp = lambda x, y: cmp(x.lower(), y.lower()))
+ keys.sort(cmp=lambda x, y: cmp(x.lower(), y.lower()))
pairs = []
for key in keys:
self.hmac.update(key)
@@ -112,30 +119,34 @@ class Signer(object):
return base64.b64encode(self.hmac.digest())
def _calc_signature_2(self, params, verb, server_string, path):
+ """Generate AWS signature version 2 string."""
logging.debug('using _calc_signature_2')
string_to_sign = '%s\n%s\n%s\n' % (verb, server_string, path)
if self.hmac_256:
- hmac = self.hmac_256
+ current_hmac = self.hmac_256
params['SignatureMethod'] = 'HmacSHA256'
else:
- hmac = self.hmac
+ current_hmac = self.hmac
params['SignatureMethod'] = 'HmacSHA1'
keys = params.keys()
keys.sort()
pairs = []
for key in keys:
val = self._get_utf8_value(params[key])
- pairs.append(urllib.quote(key, safe='') + '=' + urllib.quote(val, safe='-_~'))
+ val = urllib.quote(val, safe='-_~')
+ pairs.append(urllib.quote(key, safe='') + '=' + val)
qs = '&'.join(pairs)
- logging.debug('query string: %s' % qs)
+ logging.debug('query string: %s', qs)
string_to_sign += qs
- logging.debug('string_to_sign: %s' % string_to_sign)
- hmac.update(string_to_sign)
- b64 = base64.b64encode(hmac.digest())
- logging.debug('len(b64)=%d' % len(b64))
- logging.debug('base64 encoded digest: %s' % b64)
+ logging.debug('string_to_sign: %s', string_to_sign)
+ current_hmac.update(string_to_sign)
+ b64 = base64.b64encode(current_hmac.digest())
+ logging.debug('len(b64)=%d', len(b64))
+ logging.debug('base64 encoded digest: %s', b64)
return b64
if __name__ == '__main__':
- print Signer('foo').generate({"SignatureMethod": 'HmacSHA256', 'SignatureVersion': '2'}, "get", "server", "/foo")
+ print Signer('foo').generate({'SignatureMethod': 'HmacSHA256',
+ 'SignatureVersion': '2'},
+ 'get', 'server', '/foo')
diff --git a/nova/cloudpipe/api.py b/nova/cloudpipe/api.py
deleted file mode 100644
index 56aa89834..000000000
--- a/nova/cloudpipe/api.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Tornado REST API Request Handlers for CloudPipe
-"""
-
-import logging
-import urllib
-
-import tornado.web
-
-from nova import crypto
-from nova.auth import manager
-
-
-_log = logging.getLogger("api")
-_log.setLevel(logging.DEBUG)
-
-
-class CloudPipeRequestHandler(tornado.web.RequestHandler):
- def get(self, path):
- path = self.request.path
- _log.debug( "Cloudpipe path is %s" % path)
- if path.endswith("/getca/"):
- self.send_root_ca()
- self.finish()
-
- def get_project_id_from_ip(self, ip):
- cc = self.application.controllers['Cloud']
- instance = cc.get_instance_by_ip(ip)
- instance['project_id']
-
- def send_root_ca(self):
- _log.debug( "Getting root ca")
- project_id = self.get_project_id_from_ip(self.request.remote_ip)
- self.set_header("Content-Type", "text/plain")
- self.write(crypto.fetch_ca(project_id))
-
- def post(self, *args, **kwargs):
- project_id = self.get_project_id_from_ip(self.request.remote_ip)
- cert = self.get_argument('cert', '')
- self.write(crypto.sign_csr(urllib.unquote(cert), project_id))
- self.finish()
diff --git a/nova/cloudpipe/pipelib.py b/nova/cloudpipe/pipelib.py
index 2867bcb21..706a175d9 100644
--- a/nova/cloudpipe/pipelib.py
+++ b/nova/cloudpipe/pipelib.py
@@ -32,7 +32,9 @@ from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
-from nova.endpoint import api
+# TODO(eday): Eventually changes these to something not ec2-specific
+from nova.api.ec2 import cloud
+from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -42,8 +44,8 @@ flags.DEFINE_string('boot_script_template',
class CloudPipe(object):
- def __init__(self, cloud_controller):
- self.controller = cloud_controller
+ def __init__(self):
+ self.controller = cloud.CloudController()
self.manager = manager.AuthManager()
def launch_vpn_instance(self, project_id):
@@ -58,9 +60,9 @@ class CloudPipe(object):
z.write(FLAGS.boot_script_template,'autorun.sh')
z.close()
- key_name = self.setup_keypair(project.project_manager_id, project_id)
+ key_name = self.setup_key_pair(project.project_manager_id, project_id)
zippy = open(zippath, "r")
- context = api.APIRequestContext(handler=None, user=project.project_manager, project=project)
+ context = context.APIRequestContext(user=project.project_manager, project=project)
reservation = self.controller.run_instances(context,
# run instances expects encoded userdata, it is decoded in the get_metadata_call
@@ -74,7 +76,7 @@ class CloudPipe(object):
security_groups=["vpn-secgroup"])
zippy.close()
- def setup_keypair(self, user_id, project_id):
+ def setup_key_pair(self, user_id, project_id):
key_name = '%s%s' % (project_id, FLAGS.vpn_key_suffix)
try:
private_key, fingerprint = self.manager.generate_key_pair(user_id, key_name)
diff --git a/nova/compute/instance_types.py b/nova/compute/instance_types.py
index 439be3c7d..0102bae54 100644
--- a/nova/compute/instance_types.py
+++ b/nova/compute/instance_types.py
@@ -21,10 +21,10 @@
The built-in instance properties.
"""
-INSTANCE_TYPES = {}
-INSTANCE_TYPES['m1.tiny'] = {'memory_mb': 512, 'vcpus': 1, 'local_gb': 0}
-INSTANCE_TYPES['m1.small'] = {'memory_mb': 1024, 'vcpus': 1, 'local_gb': 10}
-INSTANCE_TYPES['m1.medium'] = {'memory_mb': 2048, 'vcpus': 2, 'local_gb': 10}
-INSTANCE_TYPES['m1.large'] = {'memory_mb': 4096, 'vcpus': 4, 'local_gb': 10}
-INSTANCE_TYPES['m1.xlarge'] = {'memory_mb': 8192, 'vcpus': 4, 'local_gb': 10}
-INSTANCE_TYPES['c1.medium'] = {'memory_mb': 2048, 'vcpus': 4, 'local_gb': 10}
+INSTANCE_TYPES = {
+ 'm1.tiny': dict(memory_mb=512, vcpus=1, local_gb=0, flavorid=1),
+ 'm1.small': dict(memory_mb=1024, vcpus=1, local_gb=10, flavorid=2),
+ 'm1.medium': dict(memory_mb=2048, vcpus=2, local_gb=10, flavorid=3),
+ 'm1.large': dict(memory_mb=4096, vcpus=4, local_gb=10, flavorid=4),
+ 'm1.xlarge': dict(memory_mb=8192, vcpus=4, local_gb=10, flavorid=5),
+ 'c1.medium': dict(memory_mb=2048, vcpus=4, local_gb=10, flavorid=6)}
diff --git a/nova/compute/manager.py b/nova/compute/manager.py
new file mode 100644
index 000000000..94c95038f
--- /dev/null
+++ b/nova/compute/manager.py
@@ -0,0 +1,180 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Handles all code relating to instances (guest vms)
+"""
+
+import base64
+import datetime
+import logging
+import os
+
+from twisted.internet import defer
+
+from nova import exception
+from nova import flags
+from nova import manager
+from nova import utils
+from nova.compute import power_state
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('instances_path', utils.abspath('../instances'),
+ 'where instances are stored on disk')
+flags.DEFINE_string('compute_driver', 'nova.virt.connection.get_connection',
+ 'Driver to use for volume creation')
+
+
+class ComputeManager(manager.Manager):
+ """
+ Manages the running instances.
+ """
+ def __init__(self, compute_driver=None, *args, **kwargs):
+ """Load configuration options and connect to the hypervisor."""
+ # TODO(vish): sync driver creation logic with the rest of the system
+ if not compute_driver:
+ compute_driver = FLAGS.compute_driver
+ self.driver = utils.import_object(compute_driver)
+ self.network_manager = utils.import_object(FLAGS.network_manager)
+ self.volume_manager = utils.import_object(FLAGS.volume_manager)
+ super(ComputeManager, self).__init__(*args, **kwargs)
+
+ def _update_state(self, context, instance_id):
+ """Update the state of an instance from the driver info"""
+ # FIXME(ja): include other fields from state?
+ instance_ref = self.db.instance_get(context, instance_id)
+ state = self.driver.get_info(instance_ref.name)['state']
+ self.db.instance_set_state(context, instance_id, state)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def refresh_security_group(self, context, security_group_id, **_kwargs):
+ yield self.driver.refresh_security_group(security_group_id)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def run_instance(self, context, instance_id, **_kwargs):
+ """Launch a new instance with specified options."""
+ instance_ref = self.db.instance_get(context, instance_id)
+ if instance_ref['name'] in self.driver.list_instances():
+ raise exception.Error("Instance has already been created")
+ logging.debug("instance %s: starting...", instance_id)
+ project_id = instance_ref['project_id']
+ self.network_manager.setup_compute_network(context, instance_id)
+ self.db.instance_update(context,
+ instance_id,
+ {'host': self.host})
+
+ # TODO(vish) check to make sure the availability zone matches
+ self.db.instance_set_state(context,
+ instance_id,
+ power_state.NOSTATE,
+ 'spawning')
+
+ try:
+ yield self.driver.spawn(instance_ref)
+ now = datetime.datetime.utcnow()
+ self.db.instance_update(context,
+ instance_id,
+ {'launched_at': now})
+ except Exception: # pylint: disable-msg=W0702
+ logging.exception("instance %s: Failed to spawn",
+ instance_ref['name'])
+ self.db.instance_set_state(context,
+ instance_id,
+ power_state.SHUTDOWN)
+
+ self._update_state(context, instance_id)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def terminate_instance(self, context, instance_id):
+ """Terminate an instance on this machine."""
+ logging.debug("instance %s: terminating", instance_id)
+
+ instance_ref = self.db.instance_get(context, instance_id)
+ if instance_ref['state'] == power_state.SHUTOFF:
+ self.db.instance_destroy(context, instance_id)
+ raise exception.Error('trying to destroy already destroyed'
+ ' instance: %s' % instance_id)
+
+ yield self.driver.destroy(instance_ref)
+
+ # TODO(ja): should we keep it in a terminated state for a bit?
+ self.db.instance_destroy(context, instance_id)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def reboot_instance(self, context, instance_id):
+ """Reboot an instance on this server."""
+ self._update_state(context, instance_id)
+ instance_ref = self.db.instance_get(context, instance_id)
+
+ if instance_ref['state'] != power_state.RUNNING:
+ raise exception.Error(
+ 'trying to reboot a non-running'
+ 'instance: %s (state: %s excepted: %s)' %
+ (instance_ref['internal_id'],
+ instance_ref['state'],
+ power_state.RUNNING))
+
+ logging.debug('instance %s: rebooting', instance_ref['name'])
+ self.db.instance_set_state(context,
+ instance_id,
+ power_state.NOSTATE,
+ 'rebooting')
+ yield self.driver.reboot(instance_ref)
+ self._update_state(context, instance_id)
+
+ @exception.wrap_exception
+ def get_console_output(self, context, instance_id):
+ """Send the console output for an instance."""
+ logging.debug("instance %s: getting console output", instance_id)
+ instance_ref = self.db.instance_get(context, instance_id)
+
+ return self.driver.get_console_output(instance_ref)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def attach_volume(self, context, instance_id, volume_id, mountpoint):
+ """Attach a volume to an instance."""
+ logging.debug("instance %s: attaching volume %s to %s", instance_id,
+ volume_id, mountpoint)
+ instance_ref = self.db.instance_get(context, instance_id)
+ dev_path = yield self.volume_manager.setup_compute_volume(context,
+ volume_id)
+ yield self.driver.attach_volume(instance_ref['ec2_id'],
+ dev_path,
+ mountpoint)
+ self.db.volume_attached(context, volume_id, instance_id, mountpoint)
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def detach_volume(self, context, instance_id, volume_id):
+ """Detach a volume from an instance."""
+ logging.debug("instance %s: detaching volume %s",
+ instance_id,
+ volume_id)
+ instance_ref = self.db.instance_get(context, instance_id)
+ volume_ref = self.db.volume_get(context, volume_id)
+ yield self.driver.detach_volume(instance_ref['ec2_id'],
+ volume_ref['mountpoint'])
+ self.db.volume_detached(context, volume_id)
+ defer.returnValue(True)
diff --git a/nova/compute/model.py b/nova/compute/model.py
deleted file mode 100644
index 84432b55f..000000000
--- a/nova/compute/model.py
+++ /dev/null
@@ -1,314 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Datastore Model objects for Compute Instances, with
-InstanceDirectory manager.
-
-# Create a new instance?
->>> InstDir = InstanceDirectory()
->>> inst = InstDir.new()
->>> inst.destroy()
-True
->>> inst = InstDir['i-123']
->>> inst['ip'] = "192.168.0.3"
->>> inst['project_id'] = "projectA"
->>> inst.save()
-True
-
->>> InstDir['i-123']
-<Instance:i-123>
->>> InstDir.all.next()
-<Instance:i-123>
-
->>> inst.destroy()
-True
-"""
-
-import datetime
-import uuid
-
-from nova import datastore
-from nova import exception
-from nova import flags
-from nova import utils
-
-
-FLAGS = flags.FLAGS
-
-
-# TODO(todd): Implement this at the class level for Instance
-class InstanceDirectory(object):
- """an api for interacting with the global state of instances"""
-
- def get(self, instance_id):
- """returns an instance object for a given id"""
- return Instance(instance_id)
-
- def __getitem__(self, item):
- return self.get(item)
-
- @datastore.absorb_connection_error
- def by_project(self, project):
- """returns a list of instance objects for a project"""
- for instance_id in datastore.Redis.instance().smembers('project:%s:instances' % project):
- yield Instance(instance_id)
-
- @datastore.absorb_connection_error
- def by_node(self, node):
- """returns a list of instances for a node"""
- for instance_id in datastore.Redis.instance().smembers('node:%s:instances' % node):
- yield Instance(instance_id)
-
- def by_ip(self, ip):
- """returns an instance object that is using the IP"""
- # NOTE(vish): The ip association should be just a single value, but
- # to maintain consistency it is using the standard
- # association and the ugly method for retrieving
- # the first item in the set below.
- result = datastore.Redis.instance().smembers('ip:%s:instances' % ip)
- if not result:
- return None
- return Instance(list(result)[0])
-
- def by_volume(self, volume_id):
- """returns the instance a volume is attached to"""
- pass
-
- @datastore.absorb_connection_error
- def exists(self, instance_id):
- return datastore.Redis.instance().sismember('instances', instance_id)
-
- @property
- @datastore.absorb_connection_error
- def all(self):
- """returns a list of all instances"""
- for instance_id in datastore.Redis.instance().smembers('instances'):
- yield Instance(instance_id)
-
- def new(self):
- """returns an empty Instance object, with ID"""
- instance_id = utils.generate_uid('i')
- return self.get(instance_id)
-
-
-class Instance(datastore.BasicModel):
- """Wrapper around stored properties of an instance"""
-
- def __init__(self, instance_id):
- """loads an instance from the datastore if exists"""
- # set instance data before super call since it uses default_state
- self.instance_id = instance_id
- super(Instance, self).__init__()
-
- def default_state(self):
- return {'state': 0,
- 'state_description': 'pending',
- 'instance_id': self.instance_id,
- 'node_name': 'unassigned',
- 'project_id': 'unassigned',
- 'user_id': 'unassigned',
- 'private_dns_name': 'unassigned'}
-
- @property
- def identifier(self):
- return self.instance_id
-
- @property
- def project(self):
- if self.state.get('project_id', None):
- return self.state['project_id']
- return self.state.get('owner_id', 'unassigned')
-
- @property
- def volumes(self):
- """returns a list of attached volumes"""
- pass
-
- @property
- def reservation(self):
- """Returns a reservation object"""
- pass
-
- def save(self):
- """Call into superclass to save object, then save associations"""
- # NOTE(todd): doesn't track migration between projects/nodes,
- # it just adds the first one
- is_new = self.is_new_record()
- node_set = (self.state['node_name'] != 'unassigned' and
- self.initial_state.get('node_name', 'unassigned')
- == 'unassigned')
- success = super(Instance, self).save()
- if success and is_new:
- self.associate_with("project", self.project)
- self.associate_with("ip", self.state['private_dns_name'])
- if success and node_set:
- self.associate_with("node", self.state['node_name'])
- return True
-
- def destroy(self):
- """Destroy associations, then destroy the object"""
- self.unassociate_with("project", self.project)
- self.unassociate_with("node", self.state['node_name'])
- self.unassociate_with("ip", self.state['private_dns_name'])
- return super(Instance, self).destroy()
-
-
-class Host(datastore.BasicModel):
- """A Host is the machine where a Daemon is running."""
-
- def __init__(self, hostname):
- """loads an instance from the datastore if exists"""
- # set instance data before super call since it uses default_state
- self.hostname = hostname
- super(Host, self).__init__()
-
- def default_state(self):
- return {"hostname": self.hostname}
-
- @property
- def identifier(self):
- return self.hostname
-
-
-class Daemon(datastore.BasicModel):
- """A Daemon is a job (compute, api, network, ...) that runs on a host."""
-
- def __init__(self, host_or_combined, binpath=None):
- """loads an instance from the datastore if exists"""
- # set instance data before super call since it uses default_state
- # since loading from datastore expects a combined key that
- # is equivilent to identifier, we need to expect that, while
- # maintaining meaningful semantics (2 arguments) when creating
- # from within other code like the bin/nova-* scripts
- if binpath:
- self.hostname = host_or_combined
- self.binary = binpath
- else:
- self.hostname, self.binary = host_or_combined.split(":")
- super(Daemon, self).__init__()
-
- def default_state(self):
- return {"hostname": self.hostname,
- "binary": self.binary,
- "updated_at": utils.isotime()
- }
-
- @property
- def identifier(self):
- return "%s:%s" % (self.hostname, self.binary)
-
- def save(self):
- """Call into superclass to save object, then save associations"""
- # NOTE(todd): this makes no attempt to destroy itsself,
- # so after termination a record w/ old timestmap remains
- success = super(Daemon, self).save()
- if success:
- self.associate_with("host", self.hostname)
- return True
-
- def destroy(self):
- """Destroy associations, then destroy the object"""
- self.unassociate_with("host", self.hostname)
- return super(Daemon, self).destroy()
-
- def heartbeat(self):
- self['updated_at'] = utils.isotime()
- return self.save()
-
- @classmethod
- def by_host(cls, hostname):
- for x in cls.associated_to("host", hostname):
- yield x
-
-
-class SessionToken(datastore.BasicModel):
- """This is a short-lived auth token that is passed through web requests"""
-
- def __init__(self, session_token):
- self.token = session_token
- self.default_ttl = FLAGS.auth_token_ttl
- super(SessionToken, self).__init__()
-
- @property
- def identifier(self):
- return self.token
-
- def default_state(self):
- now = datetime.datetime.utcnow()
- diff = datetime.timedelta(seconds=self.default_ttl)
- expires = now + diff
- return {'user': None, 'session_type': None, 'token': self.token,
- 'expiry': expires.strftime(utils.TIME_FORMAT)}
-
- def save(self):
- """Call into superclass to save object, then save associations"""
- if not self['user']:
- raise exception.Invalid("SessionToken requires a User association")
- success = super(SessionToken, self).save()
- if success:
- self.associate_with("user", self['user'])
- return True
-
- @classmethod
- def lookup(cls, key):
- token = super(SessionToken, cls).lookup(key)
- if token:
- expires_at = utils.parse_isotime(token['expiry'])
- if datetime.datetime.utcnow() >= expires_at:
- token.destroy()
- return None
- return token
-
- @classmethod
- def generate(cls, userid, session_type=None):
- """make a new token for the given user"""
- token = str(uuid.uuid4())
- while cls.lookup(token):
- token = str(uuid.uuid4())
- instance = cls(token)
- instance['user'] = userid
- instance['session_type'] = session_type
- instance.save()
- return instance
-
- def update_expiry(self, **kwargs):
- """updates the expirty attribute, but doesn't save"""
- if not kwargs:
- kwargs['seconds'] = self.default_ttl
- time = datetime.datetime.utcnow()
- diff = datetime.timedelta(**kwargs)
- expires = time + diff
- self['expiry'] = expires.strftime(utils.TIME_FORMAT)
-
- def is_expired(self):
- now = datetime.datetime.utcnow()
- expires = utils.parse_isotime(self['expiry'])
- return expires <= now
-
- def ttl(self):
- """number of seconds remaining before expiration"""
- now = datetime.datetime.utcnow()
- expires = utils.parse_isotime(self['expiry'])
- delta = expires - now
- return (delta.seconds + (delta.days * 24 * 3600))
-
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod()
diff --git a/nova/compute/service.py b/nova/compute/service.py
deleted file mode 100644
index e59f3fb34..000000000
--- a/nova/compute/service.py
+++ /dev/null
@@ -1,367 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Compute Service:
-
- Runs on each compute host, managing the
- hypervisor using the virt module.
-
-"""
-
-import base64
-import json
-import logging
-import os
-import sys
-
-from twisted.internet import defer
-from twisted.internet import task
-
-from nova import exception
-from nova import flags
-from nova import process
-from nova import service
-from nova import utils
-from nova.compute import disk
-from nova.compute import model
-from nova.compute import power_state
-from nova.compute.instance_types import INSTANCE_TYPES
-from nova.network import service as network_service
-from nova.objectstore import image # for image_path flag
-from nova.virt import connection as virt_connection
-from nova.volume import service as volume_service
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('instances_path', utils.abspath('../instances'),
- 'where instances are stored on disk')
-
-
-class ComputeService(service.Service):
- """
- Manages the running instances.
- """
- def __init__(self):
- """ load configuration options for this node and connect to the hypervisor"""
- super(ComputeService, self).__init__()
- self._instances = {}
- self._conn = virt_connection.get_connection()
- self.instdir = model.InstanceDirectory()
- # TODO(joshua): This needs to ensure system state, specifically: modprobe aoe
-
- def noop(self):
- """ simple test of an AMQP message call """
- return defer.succeed('PONG')
-
- def get_instance(self, instance_id):
- # inst = self.instdir.get(instance_id)
- # return inst
- if self.instdir.exists(instance_id):
- return Instance.fromName(self._conn, instance_id)
- return None
-
- @exception.wrap_exception
- def adopt_instances(self):
- """ if there are instances already running, adopt them """
- return defer.succeed(0)
- instance_names = self._conn.list_instances()
- for name in instance_names:
- try:
- new_inst = Instance.fromName(self._conn, name)
- new_inst.update_state()
- except:
- pass
- return defer.succeed(len(self._instances))
-
- @exception.wrap_exception
- def describe_instances(self):
- retval = {}
- for inst in self.instdir.by_node(FLAGS.node_name):
- retval[inst['instance_id']] = (
- Instance.fromName(self._conn, inst['instance_id']))
- return retval
-
- @defer.inlineCallbacks
- def report_state(self, nodename, daemon):
- # TODO(termie): make this pattern be more elegant. -todd
- try:
- record = model.Daemon(nodename, daemon)
- record.heartbeat()
- if getattr(self, "model_disconnected", False):
- self.model_disconnected = False
- logging.error("Recovered model server connection!")
-
- except model.ConnectionError, ex:
- if not getattr(self, "model_disconnected", False):
- self.model_disconnected = True
- logging.exception("model server went away")
- yield
-
- @exception.wrap_exception
- def run_instance(self, instance_id, **_kwargs):
- """ launch a new instance with specified options """
- logging.debug("Starting instance %s..." % (instance_id))
- inst = self.instdir.get(instance_id)
- # TODO: Get the real security group of launch in here
- security_group = "default"
- # NOTE(vish): passing network type allows us to express the
- # network without making a call to network to find
- # out which type of network to setup
- network_service.setup_compute_network(
- inst.get('network_type', 'vlan'),
- inst['user_id'],
- inst['project_id'],
- security_group)
-
- inst['node_name'] = FLAGS.node_name
- inst.save()
- # TODO(vish) check to make sure the availability zone matches
- new_inst = Instance(self._conn, name=instance_id, data=inst)
- logging.info("Instances current state is %s", new_inst.state)
- if new_inst.is_running():
- raise exception.Error("Instance is already running")
- new_inst.spawn()
-
- @exception.wrap_exception
- def terminate_instance(self, instance_id):
- """ terminate an instance on this machine """
- logging.debug("Got told to terminate instance %s" % instance_id)
- instance = self.get_instance(instance_id)
- # inst = self.instdir.get(instance_id)
- if not instance:
- raise exception.Error(
- 'trying to terminate unknown instance: %s' % instance_id)
- d = instance.destroy()
- # d.addCallback(lambda x: inst.destroy())
- return d
-
- @exception.wrap_exception
- def reboot_instance(self, instance_id):
- """ reboot an instance on this server
- KVM doesn't support reboot, so we terminate and restart """
- instance = self.get_instance(instance_id)
- if not instance:
- raise exception.Error(
- 'trying to reboot unknown instance: %s' % instance_id)
- return instance.reboot()
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def get_console_output(self, instance_id):
- """ send the console output for an instance """
- logging.debug("Getting console output for %s" % (instance_id))
- inst = self.instdir.get(instance_id)
- instance = self.get_instance(instance_id)
- if not instance:
- raise exception.Error(
- 'trying to get console log for unknown: %s' % instance_id)
- rv = yield instance.console_output()
- # TODO(termie): this stuff belongs in the API layer, no need to
- # munge the data we send to ourselves
- output = {"InstanceId" : instance_id,
- "Timestamp" : "2",
- "output" : base64.b64encode(rv)}
- defer.returnValue(output)
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def attach_volume(self, instance_id = None,
- volume_id = None, mountpoint = None):
- volume = volume_service.get_volume(volume_id)
- yield self._init_aoe()
- yield process.simple_execute(
- "sudo virsh attach-disk %s /dev/etherd/%s %s" %
- (instance_id,
- volume['aoe_device'],
- mountpoint.rpartition('/dev/')[2]))
- volume.finish_attach()
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- def _init_aoe(self):
- yield process.simple_execute("sudo aoe-discover")
- yield process.simple_execute("sudo aoe-stat")
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def detach_volume(self, instance_id, volume_id):
- """ detach a volume from an instance """
- # despite the documentation, virsh detach-disk just wants the device
- # name without the leading /dev/
- volume = volume_service.get_volume(volume_id)
- target = volume['mountpoint'].rpartition('/dev/')[2]
- yield process.simple_execute(
- "sudo virsh detach-disk %s %s " % (instance_id, target))
- volume.finish_detach()
- defer.returnValue(True)
-
-
-class Group(object):
- def __init__(self, group_id):
- self.group_id = group_id
-
-
-class ProductCode(object):
- def __init__(self, product_code):
- self.product_code = product_code
-
-
-class Instance(object):
-
- NOSTATE = 0x00
- RUNNING = 0x01
- BLOCKED = 0x02
- PAUSED = 0x03
- SHUTDOWN = 0x04
- SHUTOFF = 0x05
- CRASHED = 0x06
-
- def __init__(self, conn, name, data):
- """ spawn an instance with a given name """
- self._conn = conn
- # TODO(vish): this can be removed after data has been updated
- # data doesn't seem to have a working iterator so in doesn't work
- if data.get('owner_id', None) is not None:
- data['user_id'] = data['owner_id']
- data['project_id'] = data['owner_id']
- self.datamodel = data
-
- size = data.get('instance_type', FLAGS.default_instance_type)
- if size not in INSTANCE_TYPES:
- raise exception.Error('invalid instance type: %s' % size)
-
- self.datamodel.update(INSTANCE_TYPES[size])
-
- self.datamodel['name'] = name
- self.datamodel['instance_id'] = name
- self.datamodel['basepath'] = data.get(
- 'basepath', os.path.abspath(
- os.path.join(FLAGS.instances_path, self.name)))
- self.datamodel['memory_kb'] = int(self.datamodel['memory_mb']) * 1024
- self.datamodel.setdefault('image_id', FLAGS.default_image)
- self.datamodel.setdefault('kernel_id', FLAGS.default_kernel)
- self.datamodel.setdefault('ramdisk_id', FLAGS.default_ramdisk)
- self.datamodel.setdefault('project_id', self.datamodel['user_id'])
- self.datamodel.setdefault('bridge_name', None)
- #self.datamodel.setdefault('key_data', None)
- #self.datamodel.setdefault('key_name', None)
- #self.datamodel.setdefault('addressing_type', None)
-
- # TODO(joshua) - The ugly non-flat ones
- self.datamodel['groups'] = data.get('security_group', 'default')
- # TODO(joshua): Support product codes somehow
- self.datamodel.setdefault('product_codes', None)
-
- self.datamodel.save()
- logging.debug("Finished init of Instance with id of %s" % name)
-
- @classmethod
- def fromName(cls, conn, name):
- """ use the saved data for reloading the instance """
- instdir = model.InstanceDirectory()
- instance = instdir.get(name)
- return cls(conn=conn, name=name, data=instance)
-
- def set_state(self, state_code, state_description=None):
- self.datamodel['state'] = state_code
- if not state_description:
- state_description = power_state.name(state_code)
- self.datamodel['state_description'] = state_description
- self.datamodel.save()
-
- @property
- def state(self):
- # it is a string in datamodel
- return int(self.datamodel['state'])
-
- @property
- def name(self):
- return self.datamodel['name']
-
- def is_pending(self):
- return (self.state == power_state.NOSTATE or self.state == 'pending')
-
- def is_destroyed(self):
- return self.state == power_state.SHUTOFF
-
- def is_running(self):
- logging.debug("Instance state is: %s" % self.state)
- return (self.state == power_state.RUNNING or self.state == 'running')
-
- def describe(self):
- return self.datamodel
-
- def info(self):
- result = self._conn.get_info(self.name)
- result['node_name'] = FLAGS.node_name
- return result
-
- def update_state(self):
- self.datamodel.update(self.info())
- self.set_state(self.state)
- self.datamodel.save() # Extra, but harmless
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def destroy(self):
- if self.is_destroyed():
- self.datamodel.destroy()
- raise exception.Error('trying to destroy already destroyed'
- ' instance: %s' % self.name)
-
- self.set_state(power_state.NOSTATE, 'shutting_down')
- yield self._conn.destroy(self)
- self.datamodel.destroy()
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def reboot(self):
- if not self.is_running():
- raise exception.Error(
- 'trying to reboot a non-running'
- 'instance: %s (state: %s)' % (self.name, self.state))
-
- logging.debug('rebooting instance %s' % self.name)
- self.set_state(power_state.NOSTATE, 'rebooting')
- yield self._conn.reboot(self)
- self.update_state()
-
- @defer.inlineCallbacks
- @exception.wrap_exception
- def spawn(self):
- self.set_state(power_state.NOSTATE, 'spawning')
- logging.debug("Starting spawn in Instance")
- try:
- yield self._conn.spawn(self)
- except Exception, ex:
- logging.debug(ex)
- self.set_state(power_state.SHUTDOWN)
- self.update_state()
-
- @exception.wrap_exception
- def console_output(self):
- # FIXME: Abstract this for Xen
- if FLAGS.connection_type == 'libvirt':
- fname = os.path.abspath(
- os.path.join(self.datamodel['basepath'], 'console.log'))
- with open(fname, 'r') as f:
- console = f.read()
- else:
- console = 'FAKE CONSOLE OUTPUT'
- return defer.succeed(console)
diff --git a/nova/crypto.py b/nova/crypto.py
index b05548ea1..1c6fe57ad 100644
--- a/nova/crypto.py
+++ b/nova/crypto.py
@@ -18,7 +18,7 @@
"""
Wrappers around standard crypto, including root and intermediate CAs,
-SSH keypairs and x509 certificates.
+SSH key_pairs and x509 certificates.
"""
import base64
diff --git a/nova/datastore.py b/nova/datastore.py
index 5dc6ed107..8e2519429 100644
--- a/nova/datastore.py
+++ b/nova/datastore.py
@@ -26,10 +26,7 @@ before trying to run this.
import logging
import redis
-from nova import exception
from nova import flags
-from nova import utils
-
FLAGS = flags.FLAGS
flags.DEFINE_string('redis_host', '127.0.0.1',
@@ -54,209 +51,3 @@ class Redis(object):
return cls._instance
-class ConnectionError(exception.Error):
- pass
-
-
-def absorb_connection_error(fn):
- def _wrapper(*args, **kwargs):
- try:
- return fn(*args, **kwargs)
- except redis.exceptions.ConnectionError, ce:
- raise ConnectionError(str(ce))
- return _wrapper
-
-
-class BasicModel(object):
- """
- All Redis-backed data derives from this class.
-
- You MUST specify an identifier() property that returns a unique string
- per instance.
-
- You MUST have an initializer that takes a single argument that is a value
- returned by identifier() to load a new class with.
-
- You may want to specify a dictionary for default_state().
-
- You may also specify override_type at the class left to use a key other
- than __class__.__name__.
-
- You override save and destroy calls to automatically build and destroy
- associations.
- """
-
- override_type = None
-
- @absorb_connection_error
- def __init__(self):
- state = Redis.instance().hgetall(self.__redis_key)
- if state:
- self.initial_state = state
- self.state = dict(self.initial_state)
- else:
- self.initial_state = {}
- self.state = self.default_state()
-
-
- def default_state(self):
- """You probably want to define this in your subclass"""
- return {}
-
- @classmethod
- def _redis_name(cls):
- return cls.override_type or cls.__name__.lower()
-
- @classmethod
- def lookup(cls, identifier):
- rv = cls(identifier)
- if rv.is_new_record():
- return None
- else:
- return rv
-
- @classmethod
- @absorb_connection_error
- def all(cls):
- """yields all objects in the store"""
- redis_set = cls._redis_set_name(cls.__name__)
- for identifier in Redis.instance().smembers(redis_set):
- yield cls(identifier)
-
- @classmethod
- def associated_to(cls, foreign_type, foreign_id):
- for identifier in cls.associated_keys(foreign_type, foreign_id):
- yield cls(identifier)
-
- @classmethod
- @absorb_connection_error
- def associated_keys(cls, foreign_type, foreign_id):
- redis_set = cls._redis_association_name(foreign_type, foreign_id)
- return Redis.instance().smembers(redis_set) or []
-
- @classmethod
- def _redis_set_name(cls, kls_name):
- # stupidly pluralize (for compatiblity with previous codebase)
- return kls_name.lower() + "s"
-
- @classmethod
- def _redis_association_name(cls, foreign_type, foreign_id):
- return cls._redis_set_name("%s:%s:%s" %
- (foreign_type, foreign_id, cls._redis_name()))
-
- @property
- def identifier(self):
- """You DEFINITELY want to define this in your subclass"""
- raise NotImplementedError("Your subclass should define identifier")
-
- @property
- def __redis_key(self):
- return '%s:%s' % (self._redis_name(), self.identifier)
-
- def __repr__(self):
- return "<%s:%s>" % (self.__class__.__name__, self.identifier)
-
- def keys(self):
- return self.state.keys()
-
- def copy(self):
- copyDict = {}
- for item in self.keys():
- copyDict[item] = self[item]
- return copyDict
-
- def get(self, item, default):
- return self.state.get(item, default)
-
- def update(self, update_dict):
- return self.state.update(update_dict)
-
- def setdefault(self, item, default):
- return self.state.setdefault(item, default)
-
- def __contains__(self, item):
- return item in self.state
-
- def __getitem__(self, item):
- return self.state[item]
-
- def __setitem__(self, item, val):
- self.state[item] = val
- return self.state[item]
-
- def __delitem__(self, item):
- """We don't support this"""
- raise Exception("Silly monkey, models NEED all their properties.")
-
- def is_new_record(self):
- return self.initial_state == {}
-
- @absorb_connection_error
- def add_to_index(self):
- """Each insance of Foo has its id tracked int the set named Foos"""
- set_name = self.__class__._redis_set_name(self.__class__.__name__)
- Redis.instance().sadd(set_name, self.identifier)
-
- @absorb_connection_error
- def remove_from_index(self):
- """Remove id of this instance from the set tracking ids of this type"""
- set_name = self.__class__._redis_set_name(self.__class__.__name__)
- Redis.instance().srem(set_name, self.identifier)
-
- @absorb_connection_error
- def associate_with(self, foreign_type, foreign_id):
- """Add this class id into the set foreign_type:foreign_id:this_types"""
- # note the extra 's' on the end is for plurality
- # to match the old data without requiring a migration of any sort
- self.add_associated_model_to_its_set(foreign_type, foreign_id)
- redis_set = self.__class__._redis_association_name(foreign_type,
- foreign_id)
- Redis.instance().sadd(redis_set, self.identifier)
-
- @absorb_connection_error
- def unassociate_with(self, foreign_type, foreign_id):
- """Delete from foreign_type:foreign_id:this_types set"""
- redis_set = self.__class__._redis_association_name(foreign_type,
- foreign_id)
- Redis.instance().srem(redis_set, self.identifier)
-
- def add_associated_model_to_its_set(self, model_type, model_id):
- """
- When associating an X to a Y, save Y for newer timestamp, etc, and to
- make sure to save it if Y is a new record.
- If the model_type isn't found as a usable class, ignore it, this can
- happen when associating to things stored in LDAP (user, project, ...).
- """
- table = globals()
- klsname = model_type.capitalize()
- if table.has_key(klsname):
- model_class = table[klsname]
- model_inst = model_class(model_id)
- model_inst.save()
-
- @absorb_connection_error
- def save(self):
- """
- update the directory with the state from this model
- also add it to the index of items of the same type
- then set the initial_state = state so new changes are tracked
- """
- # TODO(ja): implement hmset in redis-py and use it
- # instead of multiple calls to hset
- if self.is_new_record():
- self["create_time"] = utils.isotime()
- for key, val in self.state.iteritems():
- Redis.instance().hset(self.__redis_key, key, val)
- self.add_to_index()
- self.initial_state = dict(self.state)
- return True
-
- @absorb_connection_error
- def destroy(self):
- """deletes all related records from datastore."""
- logging.info("Destroying datamodel for %s %s",
- self.__class__.__name__, self.identifier)
- Redis.instance().delete(self.__redis_key)
- self.remove_from_index()
- return True
-
diff --git a/nova/api/rackspace/images.py b/nova/db/__init__.py
index 986f11434..054b7ac94 100644
--- a/nova/api/rackspace/images.py
+++ b/nova/db/__init__.py
@@ -1,6 +1,8 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
-# Copyright 2010 OpenStack LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -14,5 +16,8 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
+"""
+DB abstraction for Nova
+"""
-class Controller(object): pass
+from nova.db.api import *
diff --git a/nova/db/api.py b/nova/db/api.py
new file mode 100644
index 000000000..6dbf3b809
--- /dev/null
+++ b/nova/db/api.py
@@ -0,0 +1,770 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Defines interface for DB access
+"""
+
+from nova import exception
+from nova import flags
+from nova import utils
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('db_backend', 'sqlalchemy',
+ 'The backend to use for db')
+
+
+IMPL = utils.LazyPluggable(FLAGS['db_backend'],
+ sqlalchemy='nova.db.sqlalchemy.api')
+
+
+class NoMoreAddresses(exception.Error):
+ """No more available addresses"""
+ pass
+
+
+class NoMoreBlades(exception.Error):
+ """No more available blades"""
+ pass
+
+
+class NoMoreNetworks(exception.Error):
+ """No more available networks"""
+ pass
+
+
+###################
+
+
+def service_destroy(context, instance_id):
+ """Destroy the service or raise if it does not exist."""
+ return IMPL.service_destroy(context, instance_id)
+
+
+def service_get(context, service_id):
+ """Get an service or raise if it does not exist."""
+ return IMPL.service_get(context, service_id)
+
+
+def service_get_all_by_topic(context, topic):
+ """Get all compute services for a given topic """
+ return IMPL.service_get_all_by_topic(context, topic)
+
+
+def service_get_all_compute_sorted(context):
+ """Get all compute services sorted by instance count
+
+ Returns a list of (Service, instance_count) tuples
+ """
+ return IMPL.service_get_all_compute_sorted(context)
+
+
+def service_get_all_network_sorted(context):
+ """Get all network services sorted by network count
+
+ Returns a list of (Service, network_count) tuples
+ """
+ return IMPL.service_get_all_network_sorted(context)
+
+
+def service_get_all_volume_sorted(context):
+ """Get all volume services sorted by volume count
+
+ Returns a list of (Service, volume_count) tuples
+ """
+ return IMPL.service_get_all_volume_sorted(context)
+
+
+def service_get_by_args(context, host, binary):
+ """Get the state of an service by node name and binary."""
+ return IMPL.service_get_by_args(context, host, binary)
+
+
+def service_create(context, values):
+ """Create a service from the values dictionary."""
+ return IMPL.service_create(context, values)
+
+
+def service_update(context, service_id, values):
+ """Set the given properties on an service and update it.
+
+ Raises NotFound if service does not exist.
+
+ """
+ return IMPL.service_update(context, service_id, values)
+
+
+###################
+
+
+def floating_ip_allocate_address(context, host, project_id):
+ """Allocate free floating ip and return the address.
+
+ Raises if one is not available.
+ """
+ return IMPL.floating_ip_allocate_address(context, host, project_id)
+
+
+def floating_ip_create(context, values):
+ """Create a floating ip from the values dictionary."""
+ return IMPL.floating_ip_create(context, values)
+
+
+def floating_ip_count_by_project(context, project_id):
+ """Count floating ips used by project."""
+ return IMPL.floating_ip_count_by_project(context, project_id)
+
+
+def floating_ip_deallocate(context, address):
+ """Deallocate an floating ip by address"""
+ return IMPL.floating_ip_deallocate(context, address)
+
+
+def floating_ip_destroy(context, address):
+ """Destroy the floating_ip or raise if it does not exist."""
+ return IMPL.floating_ip_destroy(context, address)
+
+
+def floating_ip_disassociate(context, address):
+ """Disassociate an floating ip from a fixed ip by address.
+
+ Returns the address of the existing fixed ip.
+ """
+ return IMPL.floating_ip_disassociate(context, address)
+
+
+def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
+ """Associate an floating ip to a fixed_ip by address."""
+ return IMPL.floating_ip_fixed_ip_associate(context,
+ floating_address,
+ fixed_address)
+
+
+def floating_ip_get_all(context):
+ """Get all floating ips."""
+ return IMPL.floating_ip_get_all(context)
+
+
+def floating_ip_get_all_by_host(context, host):
+ """Get all floating ips by host."""
+ return IMPL.floating_ip_get_all_by_host(context, host)
+
+
+def floating_ip_get_all_by_project(context, project_id):
+ """Get all floating ips by project."""
+ return IMPL.floating_ip_get_all_by_project(context, project_id)
+
+
+def floating_ip_get_by_address(context, address):
+ """Get a floating ip by address or raise if it doesn't exist."""
+ return IMPL.floating_ip_get_by_address(context, address)
+
+
+####################
+
+
+def fixed_ip_associate(context, address, instance_id):
+ """Associate fixed ip to instance.
+
+ Raises if fixed ip is not available.
+ """
+ return IMPL.fixed_ip_associate(context, address, instance_id)
+
+
+def fixed_ip_associate_pool(context, network_id, instance_id):
+ """Find free ip in network and associate it to instance.
+
+ Raises if one is not available.
+ """
+ return IMPL.fixed_ip_associate_pool(context, network_id, instance_id)
+
+
+def fixed_ip_create(context, values):
+ """Create a fixed ip from the values dictionary."""
+ return IMPL.fixed_ip_create(context, values)
+
+
+def fixed_ip_disassociate(context, address):
+ """Disassociate a fixed ip from an instance by address."""
+ return IMPL.fixed_ip_disassociate(context, address)
+
+
+def fixed_ip_disassociate_all_by_timeout(context, host, time):
+ """Disassociate old fixed ips from host"""
+ return IMPL.fixed_ip_disassociate_all_by_timeout(context, host, time)
+
+
+def fixed_ip_get_by_address(context, address):
+ """Get a fixed ip by address or raise if it does not exist."""
+ return IMPL.fixed_ip_get_by_address(context, address)
+
+
+def fixed_ip_get_instance(context, address):
+ """Get an instance for a fixed ip by address."""
+ return IMPL.fixed_ip_get_instance(context, address)
+
+
+def fixed_ip_get_network(context, address):
+ """Get a network for a fixed ip by address."""
+ return IMPL.fixed_ip_get_network(context, address)
+
+
+def fixed_ip_update(context, address, values):
+ """Create a fixed ip from the values dictionary."""
+ return IMPL.fixed_ip_update(context, address, values)
+
+
+####################
+
+
+def instance_create(context, values):
+ """Create an instance from the values dictionary."""
+ return IMPL.instance_create(context, values)
+
+
+def instance_data_get_for_project(context, project_id):
+ """Get (instance_count, core_count) for project."""
+ return IMPL.instance_data_get_for_project(context, project_id)
+
+
+def instance_destroy(context, instance_id):
+ """Destroy the instance or raise if it does not exist."""
+ return IMPL.instance_destroy(context, instance_id)
+
+
+def instance_get(context, instance_id):
+ """Get an instance or raise if it does not exist."""
+ return IMPL.instance_get(context, instance_id)
+
+
+def instance_get_all(context):
+ """Get all instances."""
+ return IMPL.instance_get_all(context)
+
+def instance_get_all_by_user(context, user_id):
+ """Get all instances."""
+ return IMPL.instance_get_all_by_user(context, user_id)
+
+def instance_get_all_by_project(context, project_id):
+ """Get all instance belonging to a project."""
+ return IMPL.instance_get_all_by_project(context, project_id)
+
+
+def instance_get_all_by_reservation(context, reservation_id):
+ """Get all instance belonging to a reservation."""
+ return IMPL.instance_get_all_by_reservation(context, reservation_id)
+
+
+def instance_get_fixed_address(context, instance_id):
+ """Get the fixed ip address of an instance."""
+ return IMPL.instance_get_fixed_address(context, instance_id)
+
+
+def instance_get_floating_address(context, instance_id):
+ """Get the first floating ip address of an instance."""
+ return IMPL.instance_get_floating_address(context, instance_id)
+
+
+def instance_get_by_internal_id(context, internal_id):
+ """Get an instance by ec2 id."""
+ return IMPL.instance_get_by_internal_id(context, internal_id)
+
+
+def instance_is_vpn(context, instance_id):
+ """True if instance is a vpn."""
+ return IMPL.instance_is_vpn(context, instance_id)
+
+
+def instance_set_state(context, instance_id, state, description=None):
+ """Set the state of an instance."""
+ return IMPL.instance_set_state(context, instance_id, state, description)
+
+
+def instance_update(context, instance_id, values):
+ """Set the given properties on an instance and update it.
+
+ Raises NotFound if instance does not exist.
+
+ """
+ return IMPL.instance_update(context, instance_id, values)
+
+
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ return IMPL.instance_add_security_group(context, instance_id, security_group_id)
+
+
+###################
+
+
+def key_pair_create(context, values):
+ """Create a key_pair from the values dictionary."""
+ return IMPL.key_pair_create(context, values)
+
+
+def key_pair_destroy(context, user_id, name):
+ """Destroy the key_pair or raise if it does not exist."""
+ return IMPL.key_pair_destroy(context, user_id, name)
+
+
+def key_pair_destroy_all_by_user(context, user_id):
+ """Destroy all key_pairs by user."""
+ return IMPL.key_pair_destroy_all_by_user(context, user_id)
+
+
+def key_pair_get(context, user_id, name):
+ """Get a key_pair or raise if it does not exist."""
+ return IMPL.key_pair_get(context, user_id, name)
+
+
+def key_pair_get_all_by_user(context, user_id):
+ """Get all key_pairs by user."""
+ return IMPL.key_pair_get_all_by_user(context, user_id)
+
+
+####################
+
+
+def network_associate(context, project_id):
+ """Associate a free network to a project."""
+ return IMPL.network_associate(context, project_id)
+
+
+def network_count(context):
+ """Return the number of networks."""
+ return IMPL.network_count(context)
+
+
+def network_count_allocated_ips(context, network_id):
+ """Return the number of allocated non-reserved ips in the network."""
+ return IMPL.network_count_allocated_ips(context, network_id)
+
+
+def network_count_available_ips(context, network_id):
+ """Return the number of available ips in the network."""
+ return IMPL.network_count_available_ips(context, network_id)
+
+
+def network_count_reserved_ips(context, network_id):
+ """Return the number of reserved ips in the network."""
+ return IMPL.network_count_reserved_ips(context, network_id)
+
+
+def network_create_safe(context, values):
+ """Create a network from the values dict
+
+ The network is only returned if the create succeeds. If the create violates
+ constraints because the network already exists, no exception is raised."""
+ return IMPL.network_create_safe(context, values)
+
+
+def network_create_fixed_ips(context, network_id, num_vpn_clients):
+ """Create the ips for the network, reserving sepecified ips."""
+ return IMPL.network_create_fixed_ips(context, network_id, num_vpn_clients)
+
+
+def network_disassociate(context, network_id):
+ """Disassociate the network from project or raise if it does not exist."""
+ return IMPL.network_disassociate(context, network_id)
+
+
+def network_disassociate_all(context):
+ """Disassociate all networks from projects."""
+ return IMPL.network_disassociate_all(context)
+
+
+def network_get(context, network_id):
+ """Get an network or raise if it does not exist."""
+ return IMPL.network_get(context, network_id)
+
+
+# pylint: disable-msg=C0103
+def network_get_associated_fixed_ips(context, network_id):
+ """Get all network's ips that have been associated."""
+ return IMPL.network_get_associated_fixed_ips(context, network_id)
+
+
+def network_get_by_bridge(context, bridge):
+ """Get a network by bridge or raise if it does not exist."""
+ return IMPL.network_get_by_bridge(context, bridge)
+
+
+def network_get_by_instance(context, instance_id):
+ """Get a network by instance id or raise if it does not exist."""
+ return IMPL.network_get_by_instance(context, instance_id)
+
+
+def network_get_index(context, network_id):
+ """Get non-conflicting index for network"""
+ return IMPL.network_get_index(context, network_id)
+
+
+def network_get_vpn_ip(context, network_id):
+ """Get non-conflicting index for network"""
+ return IMPL.network_get_vpn_ip(context, network_id)
+
+
+def network_set_cidr(context, network_id, cidr):
+ """Set the Classless Inner Domain Routing for the network"""
+ return IMPL.network_set_cidr(context, network_id, cidr)
+
+
+def network_set_host(context, network_id, host_id):
+ """Safely set the host for network"""
+ return IMPL.network_set_host(context, network_id, host_id)
+
+
+def network_update(context, network_id, values):
+ """Set the given properties on an network and update it.
+
+ Raises NotFound if network does not exist.
+
+ """
+ return IMPL.network_update(context, network_id, values)
+
+
+###################
+
+
+def project_get_network(context, project_id):
+ """Return the network associated with the project.
+
+ Raises NotFound if no such network can be found.
+
+ """
+ return IMPL.project_get_network(context, project_id)
+
+
+###################
+
+
+def queue_get_for(context, topic, physical_node_id):
+ """Return a channel to send a message to a node with a topic."""
+ return IMPL.queue_get_for(context, topic, physical_node_id)
+
+
+###################
+
+
+def export_device_count(context):
+ """Return count of export devices."""
+ return IMPL.export_device_count(context)
+
+
+def export_device_create_safe(context, values):
+ """Create an export_device from the values dictionary.
+
+ The device is not returned. If the create violates the unique
+ constraints because the shelf_id and blade_id already exist,
+ no exception is raised."""
+ return IMPL.export_device_create_safe(context, values)
+
+
+###################
+
+
+def auth_destroy_token(context, token):
+ """Destroy an auth token"""
+ return IMPL.auth_destroy_token(context, token)
+
+def auth_get_token(context, token_hash):
+ """Retrieves a token given the hash representing it"""
+ return IMPL.auth_get_token(context, token_hash)
+
+def auth_create_token(context, token):
+ """Creates a new token"""
+ return IMPL.auth_create_token(context, token)
+
+
+###################
+
+
+def quota_create(context, values):
+ """Create a quota from the values dictionary."""
+ return IMPL.quota_create(context, values)
+
+
+def quota_get(context, project_id):
+ """Retrieve a quota or raise if it does not exist."""
+ return IMPL.quota_get(context, project_id)
+
+
+def quota_update(context, project_id, values):
+ """Update a quota from the values dictionary."""
+ return IMPL.quota_update(context, project_id, values)
+
+
+def quota_destroy(context, project_id):
+ """Destroy the quota or raise if it does not exist."""
+ return IMPL.quota_destroy(context, project_id)
+
+
+###################
+
+
+def volume_allocate_shelf_and_blade(context, volume_id):
+ """Atomically allocate a free shelf and blade from the pool."""
+ return IMPL.volume_allocate_shelf_and_blade(context, volume_id)
+
+
+def volume_attached(context, volume_id, instance_id, mountpoint):
+ """Ensure that a volume is set as attached."""
+ return IMPL.volume_attached(context, volume_id, instance_id, mountpoint)
+
+
+def volume_create(context, values):
+ """Create a volume from the values dictionary."""
+ return IMPL.volume_create(context, values)
+
+
+def volume_data_get_for_project(context, project_id):
+ """Get (volume_count, gigabytes) for project."""
+ return IMPL.volume_data_get_for_project(context, project_id)
+
+
+def volume_destroy(context, volume_id):
+ """Destroy the volume or raise if it does not exist."""
+ return IMPL.volume_destroy(context, volume_id)
+
+
+def volume_detached(context, volume_id):
+ """Ensure that a volume is set as detached."""
+ return IMPL.volume_detached(context, volume_id)
+
+
+def volume_get(context, volume_id):
+ """Get a volume or raise if it does not exist."""
+ return IMPL.volume_get(context, volume_id)
+
+
+def volume_get_all(context):
+ """Get all volumes."""
+ return IMPL.volume_get_all(context)
+
+
+def volume_get_instance(context, volume_id):
+ """Get the instance that a volume is attached to."""
+ return IMPL.volume_get_instance(context, volume_id)
+
+
+def volume_get_all_by_project(context, project_id):
+ """Get all volumes belonging to a project."""
+ return IMPL.volume_get_all_by_project(context, project_id)
+
+
+def volume_get_by_ec2_id(context, ec2_id):
+ """Get a volume by ec2 id."""
+ return IMPL.volume_get_by_ec2_id(context, ec2_id)
+
+
+def volume_get_shelf_and_blade(context, volume_id):
+ """Get the shelf and blade allocated to the volume."""
+ return IMPL.volume_get_shelf_and_blade(context, volume_id)
+
+
+def volume_update(context, volume_id, values):
+ """Set the given properties on an volume and update it.
+
+ Raises NotFound if volume does not exist.
+
+ """
+ return IMPL.volume_update(context, volume_id, values)
+
+
+####################
+
+
+def security_group_get_all(context):
+ """Get all security groups"""
+ return IMPL.security_group_get_all(context)
+
+
+def security_group_get(context, security_group_id):
+ """Get security group by its internal id"""
+ return IMPL.security_group_get(context, security_group_id)
+
+
+def security_group_get_by_name(context, project_id, group_name):
+ """Returns a security group with the specified name from a project"""
+ return IMPL.security_group_get_by_name(context, project_id, group_name)
+
+
+def security_group_get_by_project(context, project_id):
+ """Get all security groups belonging to a project"""
+ return IMPL.security_group_get_by_project(context, project_id)
+
+
+def security_group_get_by_instance(context, instance_id):
+ """Get security groups to which the instance is assigned"""
+ return IMPL.security_group_get_by_instance(context, instance_id)
+
+
+def security_group_exists(context, project_id, group_name):
+ """Indicates if a group name exists in a project"""
+ return IMPL.security_group_exists(context, project_id, group_name)
+
+
+def security_group_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_create(context, values)
+
+
+def security_group_destroy(context, security_group_id):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy(context, security_group_id)
+
+
+def security_group_destroy_all(context):
+ """Deletes a security group"""
+ return IMPL.security_group_destroy_all(context)
+
+
+####################
+
+
+def security_group_rule_create(context, values):
+ """Create a new security group"""
+ return IMPL.security_group_rule_create(context, values)
+
+
+def security_group_rule_get_by_security_group(context, security_group_id):
+ """Get all rules for a a given security group"""
+ return IMPL.security_group_rule_get_by_security_group(context, security_group_id)
+
+def security_group_rule_destroy(context, security_group_rule_id):
+ """Deletes a security group rule"""
+ return IMPL.security_group_rule_destroy(context, security_group_rule_id)
+
+
+###################
+
+
+def user_get(context, id):
+ """Get user by id"""
+ return IMPL.user_get(context, id)
+
+
+def user_get_by_uid(context, uid):
+ """Get user by uid"""
+ return IMPL.user_get_by_uid(context, uid)
+
+
+def user_get_by_access_key(context, access_key):
+ """Get user by access key"""
+ return IMPL.user_get_by_access_key(context, access_key)
+
+
+def user_create(context, values):
+ """Create a new user"""
+ return IMPL.user_create(context, values)
+
+
+def user_delete(context, id):
+ """Delete a user"""
+ return IMPL.user_delete(context, id)
+
+
+def user_get_all(context):
+ """Create a new user"""
+ return IMPL.user_get_all(context)
+
+
+def user_add_role(context, user_id, role):
+ """Add another global role for user"""
+ return IMPL.user_add_role(context, user_id, role)
+
+
+def user_remove_role(context, user_id, role):
+ """Remove global role from user"""
+ return IMPL.user_remove_role(context, user_id, role)
+
+
+def user_get_roles(context, user_id):
+ """Get global roles for user"""
+ return IMPL.user_get_roles(context, user_id)
+
+
+def user_add_project_role(context, user_id, project_id, role):
+ """Add project role for user"""
+ return IMPL.user_add_project_role(context, user_id, project_id, role)
+
+
+def user_remove_project_role(context, user_id, project_id, role):
+ """Remove project role from user"""
+ return IMPL.user_remove_project_role(context, user_id, project_id, role)
+
+
+def user_get_roles_for_project(context, user_id, project_id):
+ """Return list of roles a user holds on project"""
+ return IMPL.user_get_roles_for_project(context, user_id, project_id)
+
+
+def user_update(context, user_id, values):
+ """Update user"""
+ return IMPL.user_update(context, user_id, values)
+
+
+def project_get(context, id):
+ """Get project by id"""
+ return IMPL.project_get(context, id)
+
+
+def project_create(context, values):
+ """Create a new project"""
+ return IMPL.project_create(context, values)
+
+
+def project_add_member(context, project_id, user_id):
+ """Add user to project"""
+ return IMPL.project_add_member(context, project_id, user_id)
+
+
+def project_get_all(context):
+ """Get all projects"""
+ return IMPL.project_get_all(context)
+
+
+def project_get_by_user(context, user_id):
+ """Get all projects of which the given user is a member"""
+ return IMPL.project_get_by_user(context, user_id)
+
+
+def project_remove_member(context, project_id, user_id):
+ """Remove the given user from the given project"""
+ return IMPL.project_remove_member(context, project_id, user_id)
+
+
+def project_update(context, project_id, values):
+ """Update Remove the given user from the given project"""
+ return IMPL.project_update(context, project_id, values)
+
+
+def project_delete(context, project_id):
+ """Delete project"""
+ return IMPL.project_delete(context, project_id)
+
+
+###################
+
+
+def host_get_networks(context, host):
+ """Return all networks for which the given host is the designated
+ network host
+ """
+ return IMPL.host_get_networks(context, host)
+
diff --git a/nova/api/rackspace/flavors.py b/nova/db/sqlalchemy/__init__.py
index 986f11434..3288ebd20 100644
--- a/nova/api/rackspace/flavors.py
+++ b/nova/db/sqlalchemy/__init__.py
@@ -1,6 +1,7 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
-# Copyright 2010 OpenStack LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -15,4 +16,9 @@
# License for the specific language governing permissions and limitations
# under the License.
-class Controller(object): pass
+"""
+SQLAlchemy database backend
+"""
+from nova.db.sqlalchemy import models
+
+models.register_models()
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
new file mode 100644
index 000000000..f4a746cab
--- /dev/null
+++ b/nova/db/sqlalchemy/api.py
@@ -0,0 +1,1680 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Implementation of SQLAlchemy backend
+"""
+
+import warnings
+
+from nova import db
+from nova import exception
+from nova import flags
+from nova import utils
+from nova.db.sqlalchemy import models
+from nova.db.sqlalchemy.session import get_session
+from sqlalchemy import or_
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload_all
+from sqlalchemy.sql import exists, func
+from sqlalchemy.orm.exc import NoResultFound
+
+FLAGS = flags.FLAGS
+
+
+def is_admin_context(context):
+ """Indicates if the request context is an administrator."""
+ if not context:
+ warnings.warn('Use of empty request context is deprecated',
+ DeprecationWarning)
+ return True
+ return context.is_admin
+
+
+def is_user_context(context):
+ """Indicates if the request context is a normal user."""
+ if not context:
+ return False
+ if not context.user or not context.project:
+ return False
+ return True
+
+
+def authorize_project_context(context, project_id):
+ """Ensures that the request context has permission to access the
+ given project.
+ """
+ if is_user_context(context):
+ if not context.project:
+ raise exception.NotAuthorized()
+ elif context.project.id != project_id:
+ raise exception.NotAuthorized()
+
+
+def authorize_user_context(context, user_id):
+ """Ensures that the request context has permission to access the
+ given user.
+ """
+ if is_user_context(context):
+ if not context.user:
+ raise exception.NotAuthorized()
+ elif context.user.id != user_id:
+ raise exception.NotAuthorized()
+
+
+def can_read_deleted(context):
+ """Indicates if the context has access to deleted objects."""
+ if not context:
+ return False
+ return context.read_deleted
+
+
+def require_admin_context(f):
+ """Decorator used to indicate that the method requires an
+ administrator context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+def require_context(f):
+ """Decorator used to indicate that the method requires either
+ an administrator or normal user context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]) and not is_user_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+###################
+
+@require_admin_context
+def service_destroy(context, service_id):
+ session = get_session()
+ with session.begin():
+ service_ref = service_get(context, service_id, session=session)
+ service_ref.delete(session=session)
+
+
+@require_admin_context
+def service_get(context, service_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Service
+ ).filter_by(id=service_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No service for id %s' % service_id)
+
+ return result
+
+
+@require_admin_context
+def service_get_all_by_topic(context, topic):
+ session = get_session()
+ return session.query(models.Service
+ ).filter_by(deleted=False
+ ).filter_by(disabled=False
+ ).filter_by(topic=topic
+ ).all()
+
+
+@require_admin_context
+def _service_get_all_topic_subquery(context, session, topic, subq, label):
+ sort_value = getattr(subq.c, label)
+ return session.query(models.Service, func.coalesce(sort_value, 0)
+ ).filter_by(topic=topic
+ ).filter_by(deleted=False
+ ).filter_by(disabled=False
+ ).outerjoin((subq, models.Service.host == subq.c.host)
+ ).order_by(sort_value
+ ).all()
+
+
+@require_admin_context
+def service_get_all_compute_sorted(context):
+ session = get_session()
+ with session.begin():
+ # NOTE(vish): The intended query is below
+ # SELECT services.*, COALESCE(inst_cores.instance_cores,
+ # 0)
+ # FROM services LEFT OUTER JOIN
+ # (SELECT host, SUM(instances.vcpus) AS instance_cores
+ # FROM instances GROUP BY host) AS inst_cores
+ # ON services.host = inst_cores.host
+ topic = 'compute'
+ label = 'instance_cores'
+ subq = session.query(models.Instance.host,
+ func.sum(models.Instance.vcpus).label(label)
+ ).filter_by(deleted=False
+ ).group_by(models.Instance.host
+ ).subquery()
+ return _service_get_all_topic_subquery(context,
+ session,
+ topic,
+ subq,
+ label)
+
+
+@require_admin_context
+def service_get_all_network_sorted(context):
+ session = get_session()
+ with session.begin():
+ topic = 'network'
+ label = 'network_count'
+ subq = session.query(models.Network.host,
+ func.count(models.Network.id).label(label)
+ ).filter_by(deleted=False
+ ).group_by(models.Network.host
+ ).subquery()
+ return _service_get_all_topic_subquery(context,
+ session,
+ topic,
+ subq,
+ label)
+
+
+@require_admin_context
+def service_get_all_volume_sorted(context):
+ session = get_session()
+ with session.begin():
+ topic = 'volume'
+ label = 'volume_gigabytes'
+ subq = session.query(models.Volume.host,
+ func.sum(models.Volume.size).label(label)
+ ).filter_by(deleted=False
+ ).group_by(models.Volume.host
+ ).subquery()
+ return _service_get_all_topic_subquery(context,
+ session,
+ topic,
+ subq,
+ label)
+
+
+@require_admin_context
+def service_get_by_args(context, host, binary):
+ session = get_session()
+ result = session.query(models.Service
+ ).filter_by(host=host
+ ).filter_by(binary=binary
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No service for %s, %s' % (host, binary))
+
+ return result
+
+
+@require_admin_context
+def service_create(context, values):
+ service_ref = models.Service()
+ for (key, value) in values.iteritems():
+ service_ref[key] = value
+ service_ref.save()
+ return service_ref
+
+
+@require_admin_context
+def service_update(context, service_id, values):
+ session = get_session()
+ with session.begin():
+ service_ref = service_get(context, service_id, session=session)
+ for (key, value) in values.iteritems():
+ service_ref[key] = value
+ service_ref.save(session=session)
+
+
+###################
+
+
+@require_context
+def floating_ip_allocate_address(context, host, project_id):
+ authorize_project_context(context, project_id)
+ session = get_session()
+ with session.begin():
+ floating_ip_ref = session.query(models.FloatingIp
+ ).filter_by(host=host
+ ).filter_by(fixed_ip_id=None
+ ).filter_by(project_id=None
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not floating_ip_ref:
+ raise db.NoMoreAddresses()
+ floating_ip_ref['project_id'] = project_id
+ session.add(floating_ip_ref)
+ return floating_ip_ref['address']
+
+
+@require_context
+def floating_ip_create(context, values):
+ floating_ip_ref = models.FloatingIp()
+ for (key, value) in values.iteritems():
+ floating_ip_ref[key] = value
+ floating_ip_ref.save()
+ return floating_ip_ref['address']
+
+
+@require_context
+def floating_ip_count_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+ session = get_session()
+ return session.query(models.FloatingIp
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).count()
+
+
+@require_context
+def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
+ session = get_session()
+ with session.begin():
+ # TODO(devcamcar): How to ensure floating_id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ floating_address,
+ session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ fixed_address,
+ session=session)
+ floating_ip_ref.fixed_ip = fixed_ip_ref
+ floating_ip_ref.save(session=session)
+
+
+@require_context
+def floating_ip_deallocate(context, address):
+ session = get_session()
+ with session.begin():
+ # TODO(devcamcar): How to ensure floating id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
+ floating_ip_ref['project_id'] = None
+ floating_ip_ref.save(session=session)
+
+
+@require_context
+def floating_ip_destroy(context, address):
+ session = get_session()
+ with session.begin():
+ # TODO(devcamcar): Ensure address belongs to user.
+ floating_ip_ref = get_floating_ip_by_address(context,
+ address,
+ session=session)
+ floating_ip_ref.delete(session=session)
+
+
+@require_context
+def floating_ip_disassociate(context, address):
+ session = get_session()
+ with session.begin():
+ # TODO(devcamcar): Ensure address belongs to user.
+ # Does get_floating_ip_by_address handle this?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
+ fixed_ip_ref = floating_ip_ref.fixed_ip
+ if fixed_ip_ref:
+ fixed_ip_address = fixed_ip_ref['address']
+ else:
+ fixed_ip_address = None
+ floating_ip_ref.fixed_ip = None
+ floating_ip_ref.save(session=session)
+ return fixed_ip_address
+
+
+@require_admin_context
+def floating_ip_get_all(context):
+ session = get_session()
+ return session.query(models.FloatingIp
+ ).options(joinedload_all('fixed_ip.instance')
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_admin_context
+def floating_ip_get_all_by_host(context, host):
+ session = get_session()
+ return session.query(models.FloatingIp
+ ).options(joinedload_all('fixed_ip.instance')
+ ).filter_by(host=host
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def floating_ip_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+ session = get_session()
+ return session.query(models.FloatingIp
+ ).options(joinedload_all('fixed_ip.instance')
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def floating_ip_get_by_address(context, address, session=None):
+ # TODO(devcamcar): Ensure the address belongs to user.
+ if not session:
+ session = get_session()
+
+ result = session.query(models.FloatingIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No fixed ip for address %s' % address)
+
+ return result
+
+
+###################
+
+
+@require_context
+def fixed_ip_associate(context, address, instance_id):
+ session = get_session()
+ with session.begin():
+ instance = instance_get(context, instance_id, session=session)
+ fixed_ip_ref = session.query(models.FixedIp
+ ).filter_by(address=address
+ ).filter_by(deleted=False
+ ).filter_by(instance=None
+ ).with_lockmode('update'
+ ).first()
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not fixed_ip_ref:
+ raise db.NoMoreAddresses()
+ fixed_ip_ref.instance = instance
+ session.add(fixed_ip_ref)
+
+
+@require_admin_context
+def fixed_ip_associate_pool(context, network_id, instance_id):
+ session = get_session()
+ with session.begin():
+ network_or_none = or_(models.FixedIp.network_id == network_id,
+ models.FixedIp.network_id == None)
+ fixed_ip_ref = session.query(models.FixedIp
+ ).filter(network_or_none
+ ).filter_by(reserved=False
+ ).filter_by(deleted=False
+ ).filter_by(instance=None
+ ).with_lockmode('update'
+ ).first()
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not fixed_ip_ref:
+ raise db.NoMoreAddresses()
+ if not fixed_ip_ref.network:
+ fixed_ip_ref.network = network_get(context,
+ network_id,
+ session=session)
+ fixed_ip_ref.instance = instance_get(context,
+ instance_id,
+ session=session)
+ session.add(fixed_ip_ref)
+ return fixed_ip_ref['address']
+
+
+@require_context
+def fixed_ip_create(_context, values):
+ fixed_ip_ref = models.FixedIp()
+ for (key, value) in values.iteritems():
+ fixed_ip_ref[key] = value
+ fixed_ip_ref.save()
+ return fixed_ip_ref['address']
+
+@require_context
+def fixed_ip_disassociate(context, address):
+ session = get_session()
+ with session.begin():
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
+ fixed_ip_ref.instance = None
+ fixed_ip_ref.save(session=session)
+
+@require_admin_context
+def fixed_ip_disassociate_all_by_timeout(_context, host, time):
+ session = get_session()
+ # NOTE(vish): The nested select is because sqlite doesn't support
+ # JOINs in UPDATEs.
+ result = session.execute('UPDATE fixed_ips SET instance_id = NULL, '
+ 'leased = 0 '
+ 'WHERE network_id IN (SELECT id FROM networks '
+ 'WHERE host = :host) '
+ 'AND updated_at < :time '
+ 'AND instance_id IS NOT NULL '
+ 'AND allocated = 0',
+ {'host': host,
+ 'time': time.isoformat()})
+ return result.rowcount
+
+
+@require_context
+def fixed_ip_get_by_address(context, address, session=None):
+ if not session:
+ session = get_session()
+ result = session.query(models.FixedIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('network')
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('No floating ip for address %s' % address)
+
+ if is_user_context(context):
+ authorize_project_context(context, result.instance.project_id)
+
+ return result
+
+
+@require_context
+def fixed_ip_get_instance(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.instance
+
+
+@require_admin_context
+def fixed_ip_get_network(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.network
+
+
+@require_context
+def fixed_ip_update(context, address, values):
+ session = get_session()
+ with session.begin():
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
+ for (key, value) in values.iteritems():
+ fixed_ip_ref[key] = value
+ fixed_ip_ref.save(session=session)
+
+
+###################
+
+
+#TODO(gundlach): instance_create and volume_create are nearly identical
+#and should be refactored. I expect there are other copy-and-paste
+#functions between the two of them as well.
+@require_context
+def instance_create(context, values):
+ instance_ref = models.Instance()
+ for (key, value) in values.iteritems():
+ instance_ref[key] = value
+
+ session = get_session()
+ with session.begin():
+ while instance_ref.internal_id == None:
+ internal_id = utils.generate_uid(instance_ref.__prefix__)
+ if not instance_internal_id_exists(context, internal_id,
+ session=session):
+ instance_ref.internal_id = internal_id
+ instance_ref.save(session=session)
+ return instance_ref
+
+
+@require_admin_context
+def instance_data_get_for_project(context, project_id):
+ session = get_session()
+ result = session.query(func.count(models.Instance.id),
+ func.sum(models.Instance.vcpus)
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).first()
+ # NOTE(vish): convert None to 0
+ return (result[0] or 0, result[1] or 0)
+
+
+@require_context
+def instance_destroy(context, instance_id):
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ instance_ref.delete(session=session)
+
+
+@require_context
+def instance_get(context, instance_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).options(joinedload('security_groups')
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).options(joinedload('security_groups')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No instance for id %s' % instance_id)
+
+ return result
+
+
+@require_admin_context
+def instance_get_all(context):
+ session = get_session()
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+
+@require_admin_context
+def instance_get_all_by_user(context, user_id):
+ session = get_session()
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
+ ).filter_by(deleted=can_read_deleted(context)
+ ).filter_by(user_id=user_id
+ ).all()
+
+
+@require_context
+def instance_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
+ session = get_session()
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+
+@require_context
+def instance_get_all_by_reservation(context, reservation_id):
+ session = get_session()
+
+ if is_admin_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+ elif is_user_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).options(joinedload('security_groups')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def instance_get_by_internal_id(context, internal_id):
+ session = get_session()
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).options(joinedload('security_groups')
+ ).filter_by(internal_id=internal_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).options(joinedload('security_groups')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(internal_id=internal_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('Instance %s not found' % (internal_id))
+
+ return result
+
+
+@require_context
+def instance_internal_id_exists(context, internal_id, session=None):
+ if not session:
+ session = get_session()
+ return session.query(
+ exists().where(models.Instance.internal_id==internal_id)
+ ).one()[0]
+
+
+@require_context
+def instance_get_fixed_address(context, instance_id):
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ if not instance_ref.fixed_ip:
+ return None
+ return instance_ref.fixed_ip['address']
+
+
+@require_context
+def instance_get_floating_address(context, instance_id):
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ if not instance_ref.fixed_ip:
+ return None
+ if not instance_ref.fixed_ip.floating_ips:
+ return None
+ # NOTE(vish): this just returns the first floating ip
+ return instance_ref.fixed_ip.floating_ips[0]['address']
+
+
+@require_admin_context
+def instance_is_vpn(context, instance_id):
+ # TODO(vish): Move this into image code somewhere
+ instance_ref = instance_get(context, instance_id)
+ return instance_ref['image_id'] == FLAGS.vpn_image_id
+
+
+@require_admin_context
+def instance_set_state(context, instance_id, state, description=None):
+ # TODO(devcamcar): Move this out of models and into driver
+ from nova.compute import power_state
+ if not description:
+ description = power_state.name(state)
+ db.instance_update(context,
+ instance_id,
+ {'state': state,
+ 'state_description': description})
+
+
+@require_context
+def instance_update(context, instance_id, values):
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ for (key, value) in values.iteritems():
+ instance_ref[key] = value
+ instance_ref.save(session=session)
+
+
+def instance_add_security_group(context, instance_id, security_group_id):
+ """Associate the given security group with the given instance"""
+ session = get_session()
+ with session.begin():
+ instance_ref = instance_get(context, instance_id, session=session)
+ security_group_ref = security_group_get(context,
+ security_group_id,
+ session=session)
+ instance_ref.security_groups += [security_group_ref]
+ instance_ref.save(session=session)
+
+
+###################
+
+
+@require_context
+def key_pair_create(context, values):
+ key_pair_ref = models.KeyPair()
+ for (key, value) in values.iteritems():
+ key_pair_ref[key] = value
+ key_pair_ref.save()
+ return key_pair_ref
+
+
+@require_context
+def key_pair_destroy(context, user_id, name):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ with session.begin():
+ key_pair_ref = key_pair_get(context, user_id, name, session=session)
+ key_pair_ref.delete(session=session)
+
+
+@require_context
+def key_pair_destroy_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update key_pairs set deleted=1 where user_id=:id',
+ {'id': user_id})
+
+
+@require_context
+def key_pair_get(context, user_id, name, session=None):
+ authorize_user_context(context, user_id)
+
+ if not session:
+ session = get_session()
+
+ result = session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(name=name
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('no keypair for user %s, name %s' %
+ (user_id, name))
+ return result
+
+
+@require_context
+def key_pair_get_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ return session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+###################
+
+
+@require_admin_context
+def network_associate(context, project_id):
+ session = get_session()
+ with session.begin():
+ network_ref = session.query(models.Network
+ ).filter_by(deleted=False
+ ).filter_by(project_id=None
+ ).with_lockmode('update'
+ ).first()
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not network_ref:
+ raise db.NoMoreNetworks()
+ network_ref['project_id'] = project_id
+ session.add(network_ref)
+ return network_ref
+
+
+@require_admin_context
+def network_count(context):
+ session = get_session()
+ return session.query(models.Network
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
+
+
+@require_admin_context
+def network_count_allocated_ips(context, network_id):
+ session = get_session()
+ return session.query(models.FixedIp
+ ).filter_by(network_id=network_id
+ ).filter_by(allocated=True
+ ).filter_by(deleted=False
+ ).count()
+
+
+@require_admin_context
+def network_count_available_ips(context, network_id):
+ session = get_session()
+ return session.query(models.FixedIp
+ ).filter_by(network_id=network_id
+ ).filter_by(allocated=False
+ ).filter_by(reserved=False
+ ).filter_by(deleted=False
+ ).count()
+
+
+@require_admin_context
+def network_count_reserved_ips(context, network_id):
+ session = get_session()
+ return session.query(models.FixedIp
+ ).filter_by(network_id=network_id
+ ).filter_by(reserved=True
+ ).filter_by(deleted=False
+ ).count()
+
+
+@require_admin_context
+def network_create_safe(context, values):
+ network_ref = models.Network()
+ for (key, value) in values.iteritems():
+ network_ref[key] = value
+ try:
+ network_ref.save()
+ return network_ref
+ except IntegrityError:
+ return None
+
+
+@require_admin_context
+def network_disassociate(context, network_id):
+ network_update(context, network_id, {'project_id': None})
+
+
+@require_admin_context
+def network_disassociate_all(context):
+ session = get_session()
+ session.execute('update networks set project_id=NULL')
+
+
+@require_context
+def network_get(context, network_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Network
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ return result
+
+
+# NOTE(vish): pylint complains because of the long method name, but
+# it fits with the names of the rest of the methods
+# pylint: disable-msg=C0103
+@require_admin_context
+def network_get_associated_fixed_ips(context, network_id):
+ session = get_session()
+ return session.query(models.FixedIp
+ ).options(joinedload_all('instance')
+ ).filter_by(network_id=network_id
+ ).filter(models.FixedIp.instance_id != None
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_admin_context
+def network_get_by_bridge(context, bridge):
+ session = get_session()
+ result = session.query(models.Network
+ ).filter_by(bridge=bridge
+ ).filter_by(deleted=False
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No network for bridge %s' % bridge)
+ return result
+
+
+@require_admin_context
+def network_get_by_instance(_context, instance_id):
+ session = get_session()
+ rv = session.query(models.Network
+ ).filter_by(deleted=False
+ ).join(models.Network.fixed_ips
+ ).filter_by(instance_id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not rv:
+ raise exception.NotFound('No network for instance %s' % instance_id)
+ return rv
+
+
+@require_admin_context
+def network_set_host(context, network_id, host_id):
+ session = get_session()
+ with session.begin():
+ network_ref = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ if not network_ref:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not network_ref['host']:
+ network_ref['host'] = host_id
+ session.add(network_ref)
+
+ return network_ref['host']
+
+
+@require_context
+def network_update(context, network_id, values):
+ session = get_session()
+ with session.begin():
+ network_ref = network_get(context, network_id, session=session)
+ for (key, value) in values.iteritems():
+ network_ref[key] = value
+ network_ref.save(session=session)
+
+
+###################
+
+
+@require_context
+def project_get_network(context, project_id):
+ session = get_session()
+ rv = session.query(models.Network
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).first()
+ if not rv:
+ try:
+ return network_associate(context, project_id)
+ except IntegrityError:
+ # NOTE(vish): We hit this if there is a race and two
+ # processes are attempting to allocate the
+ # network at the same time
+ rv = session.query(models.Network
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).first()
+ return rv
+
+
+###################
+
+
+def queue_get_for(_context, topic, physical_node_id):
+ # FIXME(ja): this should be servername?
+ return "%s.%s" % (topic, physical_node_id)
+
+
+###################
+
+
+@require_admin_context
+def export_device_count(context):
+ session = get_session()
+ return session.query(models.ExportDevice
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
+
+
+@require_admin_context
+def export_device_create_safe(context, values):
+ export_device_ref = models.ExportDevice()
+ for (key, value) in values.iteritems():
+ export_device_ref[key] = value
+ try:
+ export_device_ref.save()
+ return export_device_ref
+ except IntegrityError:
+ return None
+
+
+###################
+
+
+def auth_destroy_token(_context, token):
+ session = get_session()
+ session.delete(token)
+
+def auth_get_token(_context, token_hash):
+ session = get_session()
+ tk = session.query(models.AuthToken
+ ).filter_by(token_hash=token_hash
+ ).first()
+ if not tk:
+ raise exception.NotFound('Token %s does not exist' % token_hash)
+ return tk
+
+def auth_create_token(_context, token):
+ tk = models.AuthToken()
+ for k,v in token.iteritems():
+ tk[k] = v
+ tk.save()
+ return tk
+
+
+###################
+
+
+@require_admin_context
+def quota_get(context, project_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Quota
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No quota for project_id %s' % project_id)
+
+ return result
+
+
+@require_admin_context
+def quota_create(context, values):
+ quota_ref = models.Quota()
+ for (key, value) in values.iteritems():
+ quota_ref[key] = value
+ quota_ref.save()
+ return quota_ref
+
+
+@require_admin_context
+def quota_update(context, project_id, values):
+ session = get_session()
+ with session.begin():
+ quota_ref = quota_get(context, project_id, session=session)
+ for (key, value) in values.iteritems():
+ quota_ref[key] = value
+ quota_ref.save(session=session)
+
+
+@require_admin_context
+def quota_destroy(context, project_id):
+ session = get_session()
+ with session.begin():
+ quota_ref = quota_get(context, project_id, session=session)
+ quota_ref.delete(session=session)
+
+
+###################
+
+
+@require_admin_context
+def volume_allocate_shelf_and_blade(context, volume_id):
+ session = get_session()
+ with session.begin():
+ export_device = session.query(models.ExportDevice
+ ).filter_by(volume=None
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ # NOTE(vish): if with_lockmode isn't supported, as in sqlite,
+ # then this has concurrency issues
+ if not export_device:
+ raise db.NoMoreBlades()
+ export_device.volume_id = volume_id
+ session.add(export_device)
+ return (export_device.shelf_id, export_device.blade_id)
+
+
+@require_admin_context
+def volume_attached(context, volume_id, instance_id, mountpoint):
+ session = get_session()
+ with session.begin():
+ volume_ref = volume_get(context, volume_id, session=session)
+ volume_ref['status'] = 'in-use'
+ volume_ref['mountpoint'] = mountpoint
+ volume_ref['attach_status'] = 'attached'
+ volume_ref.instance = instance_get(context, instance_id, session=session)
+ volume_ref.save(session=session)
+
+
+@require_context
+def volume_create(context, values):
+ volume_ref = models.Volume()
+ for (key, value) in values.iteritems():
+ volume_ref[key] = value
+
+ session = get_session()
+ with session.begin():
+ while volume_ref.ec2_id == None:
+ ec2_id = utils.generate_uid(volume_ref.__prefix__)
+ if not volume_ec2_id_exists(context, ec2_id, session=session):
+ volume_ref.ec2_id = ec2_id
+ volume_ref.save(session=session)
+ return volume_ref
+
+
+@require_admin_context
+def volume_data_get_for_project(context, project_id):
+ session = get_session()
+ result = session.query(func.count(models.Volume.id),
+ func.sum(models.Volume.size)
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).first()
+ # NOTE(vish): convert None to 0
+ return (result[0] or 0, result[1] or 0)
+
+
+@require_admin_context
+def volume_destroy(context, volume_id):
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update volumes set deleted=1 where id=:id',
+ {'id': volume_id})
+ session.execute('update export_devices set volume_id=NULL '
+ 'where volume_id=:id',
+ {'id': volume_id})
+
+
+@require_admin_context
+def volume_detached(context, volume_id):
+ session = get_session()
+ with session.begin():
+ volume_ref = volume_get(context, volume_id, session=session)
+ volume_ref['status'] = 'available'
+ volume_ref['mountpoint'] = None
+ volume_ref['attach_status'] = 'detached'
+ volume_ref.instance = None
+ volume_ref.save(session=session)
+
+
+@require_context
+def volume_get(context, volume_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No volume for id %s' % volume_id)
+
+ return result
+
+
+@require_admin_context
+def volume_get_all(context):
+ session = get_session()
+ return session.query(models.Volume
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+@require_context
+def volume_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
+ session = get_session()
+ return session.query(models.Volume
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+
+@require_context
+def volume_get_by_ec2_id(context, ec2_id):
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=False
+ ).first()
+ else:
+ raise exception.NotAuthorized()
+
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result
+
+
+@require_context
+def volume_ec2_id_exists(context, ec2_id, session=None):
+ if not session:
+ session = get_session()
+
+ return session.query(exists(
+ ).where(models.Volume.id==ec2_id)
+ ).one()[0]
+
+
+@require_admin_context
+def volume_get_instance(context, volume_id):
+ session = get_session()
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result.instance
+
+
+@require_admin_context
+def volume_get_shelf_and_blade(context, volume_id):
+ session = get_session()
+ result = session.query(models.ExportDevice
+ ).filter_by(volume_id=volume_id
+ ).first()
+ if not result:
+ raise exception.NotFound('No export device found for volume %s' %
+ volume_id)
+
+ return (result.shelf_id, result.blade_id)
+
+
+@require_context
+def volume_update(context, volume_id, values):
+ session = get_session()
+ with session.begin():
+ volume_ref = volume_get(context, volume_id, session=session)
+ for (key, value) in values.iteritems():
+ volume_ref[key] = value
+ volume_ref.save(session=session)
+
+
+###################
+
+
+@require_context
+def security_group_get_all(context):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('rules')
+ ).all()
+
+
+@require_context
+def security_group_get(context, security_group_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=can_read_deleted(context),
+ ).filter_by(id=security_group_id
+ ).options(joinedload_all('rules')
+ ).first()
+ else:
+ result = session.query(models.SecurityGroup
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_id
+ ).filter_by(project_id=context.project_id
+ ).options(joinedload_all('rules')
+ ).first()
+ if not result:
+ raise exception.NotFound("No secuity group with id %s" %
+ security_group_id)
+ return result
+
+
+@require_context
+def security_group_get_by_name(context, project_id, group_name):
+ session = get_session()
+ result = session.query(models.SecurityGroup
+ ).filter_by(project_id=project_id
+ ).filter_by(name=group_name
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).options(joinedload_all('instances')
+ ).first()
+ if not result:
+ raise exception.NotFound(
+ 'No security group named %s for project: %s' \
+ % (group_name, project_id))
+ return result
+
+
+@require_context
+def security_group_get_by_project(context, project_id):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).all()
+
+
+@require_context
+def security_group_get_by_instance(context, instance_id):
+ session = get_session()
+ return session.query(models.SecurityGroup
+ ).filter_by(deleted=False
+ ).options(joinedload_all('rules')
+ ).join(models.SecurityGroup.instances
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def security_group_exists(context, project_id, group_name):
+ try:
+ group = security_group_get_by_name(context, project_id, group_name)
+ return group != None
+ except exception.NotFound:
+ return False
+
+
+@require_context
+def security_group_create(context, values):
+ security_group_ref = models.SecurityGroup()
+ # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception
+ # once save() is called. This will get cleaned up in next orm pass.
+ security_group_ref.rules
+ for (key, value) in values.iteritems():
+ security_group_ref[key] = value
+ security_group_ref.save()
+ return security_group_ref
+
+
+@require_context
+def security_group_destroy(context, security_group_id):
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_groups set deleted=1 where id=:id',
+ {'id': security_group_id})
+ session.execute('update security_group_rules set deleted=1 '
+ 'where group_id=:id',
+ {'id': security_group_id})
+
+@require_context
+def security_group_destroy_all(context, session=None):
+ if not session:
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update security_groups set deleted=1')
+ session.execute('update security_group_rules set deleted=1')
+
+
+###################
+
+
+@require_context
+def security_group_rule_get(context, security_group_rule_id, session=None):
+ if not session:
+ session = get_session()
+ if is_admin_context(context):
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=can_read_deleted(context)
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ else:
+ # TODO(vish): Join to group and check for project_id
+ result = session.query(models.SecurityGroupIngressRule
+ ).filter_by(deleted=False
+ ).filter_by(id=security_group_rule_id
+ ).first()
+ if not result:
+ raise exception.NotFound("No secuity group rule with id %s" %
+ security_group_rule_id)
+ return result
+
+
+@require_context
+def security_group_rule_create(context, values):
+ security_group_rule_ref = models.SecurityGroupIngressRule()
+ for (key, value) in values.iteritems():
+ security_group_rule_ref[key] = value
+ security_group_rule_ref.save()
+ return security_group_rule_ref
+
+@require_context
+def security_group_rule_destroy(context, security_group_rule_id):
+ session = get_session()
+ with session.begin():
+ security_group_rule = security_group_rule_get(context,
+ security_group_rule_id,
+ session=session)
+ security_group_rule.delete(session=session)
+
+
+###################
+
+@require_admin_context
+def user_get(context, id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.User
+ ).filter_by(id=id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No user for id %s' % id)
+
+ return result
+
+
+@require_admin_context
+def user_get_by_access_key(context, access_key, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.User
+ ).filter_by(access_key=access_key
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No user for access key %s' % access_key)
+
+ return result
+
+
+@require_admin_context
+def user_create(_context, values):
+ user_ref = models.User()
+ for (key, value) in values.iteritems():
+ user_ref[key] = value
+ user_ref.save()
+ return user_ref
+
+
+@require_admin_context
+def user_delete(context, id):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_association where user_id=:id',
+ {'id': id})
+ session.execute('delete from user_role_association where user_id=:id',
+ {'id': id})
+ session.execute('delete from user_project_role_association where user_id=:id',
+ {'id': id})
+ user_ref = user_get(context, id, session=session)
+ session.delete(user_ref)
+
+
+def user_get_all(context):
+ session = get_session()
+ return session.query(models.User
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+
+def project_create(_context, values):
+ project_ref = models.Project()
+ for (key, value) in values.iteritems():
+ project_ref[key] = value
+ project_ref.save()
+ return project_ref
+
+
+def project_add_member(context, project_id, user_id):
+ session = get_session()
+ with session.begin():
+ project_ref = project_get(context, project_id, session=session)
+ user_ref = user_get(context, user_id, session=session)
+
+ project_ref.members += [user_ref]
+ project_ref.save(session=session)
+
+
+def project_get(context, id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Project
+ ).filter_by(deleted=False
+ ).filter_by(id=id
+ ).options(joinedload_all('members')
+ ).first()
+
+ if not result:
+ raise exception.NotFound("No project with id %s" % id)
+
+ return result
+
+
+def project_get_all(context):
+ session = get_session()
+ return session.query(models.Project
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('members')
+ ).all()
+
+
+def project_get_by_user(context, user_id):
+ session = get_session()
+ user = session.query(models.User
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('projects')
+ ).first()
+ return user.projects
+
+
+def project_remove_member(context, project_id, user_id):
+ session = get_session()
+ project = project_get(context, project_id, session=session)
+ user = user_get(context, user_id, session=session)
+
+ if user in project.members:
+ project.members.remove(user)
+ project.save(session=session)
+
+
+def user_update(context, user_id, values):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ for (key, value) in values.iteritems():
+ user_ref[key] = value
+ user_ref.save(session=session)
+
+
+def project_update(context, project_id, values):
+ session = get_session()
+ with session.begin():
+ project_ref = project_get(context, project_id, session=session)
+ for (key, value) in values.iteritems():
+ project_ref[key] = value
+ project_ref.save(session=session)
+
+
+def project_delete(context, id):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_association where project_id=:id',
+ {'id': id})
+ session.execute('delete from user_project_role_association where project_id=:id',
+ {'id': id})
+ project_ref = project_get(context, id, session=session)
+ session.delete(project_ref)
+
+
+def user_get_roles(context, user_id):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ return [role.role for role in user_ref['roles']]
+
+
+def user_get_roles_for_project(context, user_id, project_id):
+ session = get_session()
+ with session.begin():
+ res = session.query(models.UserProjectRoleAssociation
+ ).filter_by(user_id=user_id
+ ).filter_by(project_id=project_id
+ ).all()
+ return [association.role for association in res]
+
+def user_remove_project_role(context, user_id, project_id, role):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_role_association where ' + \
+ 'user_id=:user_id and project_id=:project_id and ' + \
+ 'role=:role', { 'user_id' : user_id,
+ 'project_id' : project_id,
+ 'role' : role })
+
+
+def user_remove_role(context, user_id, role):
+ session = get_session()
+ with session.begin():
+ res = session.query(models.UserRoleAssociation
+ ).filter_by(user_id=user_id
+ ).filter_by(role=role
+ ).all()
+ for role in res:
+ session.delete(role)
+
+
+def user_add_role(context, user_id, role):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ models.UserRoleAssociation(user=user_ref, role=role).save(session=session)
+
+
+def user_add_project_role(context, user_id, project_id, role):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ project_ref = project_get(context, project_id, session=session)
+ models.UserProjectRoleAssociation(user_id=user_ref['id'],
+ project_id=project_ref['id'],
+ role=role).save(session=session)
+
+
+###################
+
+
+
+@require_admin_context
+def host_get_networks(context, host):
+ session = get_session()
+ with session.begin():
+ return session.query(models.Network
+ ).filter_by(deleted=False
+ ).filter_by(host=host
+ ).all()
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
new file mode 100644
index 000000000..a63bca2b0
--- /dev/null
+++ b/nova/db/sqlalchemy/models.py
@@ -0,0 +1,513 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+SQLAlchemy models for nova data
+"""
+
+import sys
+import datetime
+
+# TODO(vish): clean up these imports
+from sqlalchemy.orm import relationship, backref, exc, object_mapper
+from sqlalchemy import Column, Integer, String, schema
+from sqlalchemy import ForeignKey, DateTime, Boolean, Text
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.schema import ForeignKeyConstraint
+
+from nova.db.sqlalchemy.session import get_session
+
+from nova import auth
+from nova import exception
+from nova import flags
+
+FLAGS = flags.FLAGS
+
+BASE = declarative_base()
+
+
+class NovaBase(object):
+ """Base class for Nova Models"""
+ __table_args__ = {'mysql_engine': 'InnoDB'}
+ __table_initialized__ = False
+ __prefix__ = 'none'
+ created_at = Column(DateTime, default=datetime.datetime.utcnow)
+ updated_at = Column(DateTime, onupdate=datetime.datetime.utcnow)
+ deleted_at = Column(DateTime)
+ deleted = Column(Boolean, default=False)
+
+ @property
+ def str_id(self):
+ """Get string id of object (generally prefix + '-' + id)"""
+ return "%s-%s" % (self.__prefix__, self.id)
+
+ def save(self, session=None):
+ """Save this object"""
+ if not session:
+ session = get_session()
+ session.add(self)
+ try:
+ session.flush()
+ except IntegrityError, e:
+ if str(e).endswith('is not unique'):
+ raise exception.Duplicate(str(e))
+ else:
+ raise
+
+ def delete(self, session=None):
+ """Delete this object"""
+ self.deleted = True
+ self.deleted_at = datetime.datetime.utcnow()
+ self.save(session=session)
+
+ def __setitem__(self, key, value):
+ setattr(self, key, value)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __iter__(self):
+ self._i = iter(object_mapper(self).columns)
+ return self
+
+ def next(self):
+ n = self._i.next().name
+ return n, getattr(self, n)
+
+# TODO(vish): Store images in the database instead of file system
+#class Image(BASE, NovaBase):
+# """Represents an image in the datastore"""
+# __tablename__ = 'images'
+# __prefix__ = 'ami'
+# id = Column(Integer, primary_key=True)
+# ec2_id = Column(String(12), unique=True)
+# user_id = Column(String(255))
+# project_id = Column(String(255))
+# image_type = Column(String(255))
+# public = Column(Boolean, default=False)
+# state = Column(String(255))
+# location = Column(String(255))
+# arch = Column(String(255))
+# default_kernel_id = Column(String(255))
+# default_ramdisk_id = Column(String(255))
+#
+# @validates('image_type')
+# def validate_image_type(self, key, image_type):
+# assert(image_type in ['machine', 'kernel', 'ramdisk', 'raw'])
+#
+# @validates('state')
+# def validate_state(self, key, state):
+# assert(state in ['available', 'pending', 'disabled'])
+#
+# @validates('default_kernel_id')
+# def validate_kernel_id(self, key, val):
+# if val != 'machine':
+# assert(val is None)
+#
+# @validates('default_ramdisk_id')
+# def validate_ramdisk_id(self, key, val):
+# if val != 'machine':
+# assert(val is None)
+#
+#
+# TODO(vish): To make this into its own table, we need a good place to
+# create the host entries. In config somwhere? Or the first
+# time any object sets host? This only becomes particularly
+# important if we need to store per-host data.
+#class Host(BASE, NovaBase):
+# """Represents a host where services are running"""
+# __tablename__ = 'hosts'
+# id = Column(String(255), primary_key=True)
+#
+#
+class Service(BASE, NovaBase):
+ """Represents a running service on a host"""
+ __tablename__ = 'services'
+ id = Column(Integer, primary_key=True)
+ host = Column(String(255)) # , ForeignKey('hosts.id'))
+ binary = Column(String(255))
+ topic = Column(String(255))
+ report_count = Column(Integer, nullable=False, default=0)
+ disabled = Column(Boolean, default=False)
+
+
+class Instance(BASE, NovaBase):
+ """Represents a guest vm"""
+ __tablename__ = 'instances'
+ __prefix__ = 'i'
+ id = Column(Integer, primary_key=True)
+ internal_id = Column(Integer, unique=True)
+
+ admin_pass = Column(String(255))
+
+ user_id = Column(String(255))
+ project_id = Column(String(255))
+
+ @property
+ def user(self):
+ return auth.manager.AuthManager().get_user(self.user_id)
+
+ @property
+ def project(self):
+ return auth.manager.AuthManager().get_project(self.project_id)
+
+ @property
+ def name(self):
+ return "instance-%d" % self.internal_id
+
+ image_id = Column(String(255))
+ kernel_id = Column(String(255))
+ ramdisk_id = Column(String(255))
+
+ server_name = Column(String(255))
+
+# image_id = Column(Integer, ForeignKey('images.id'), nullable=True)
+# kernel_id = Column(Integer, ForeignKey('images.id'), nullable=True)
+# ramdisk_id = Column(Integer, ForeignKey('images.id'), nullable=True)
+# ramdisk = relationship(Ramdisk, backref=backref('instances', order_by=id))
+# kernel = relationship(Kernel, backref=backref('instances', order_by=id))
+# project = relationship(Project, backref=backref('instances', order_by=id))
+
+ launch_index = Column(Integer)
+ key_name = Column(String(255))
+ key_data = Column(Text)
+
+ state = Column(Integer)
+ state_description = Column(String(255))
+
+ memory_mb = Column(Integer)
+ vcpus = Column(Integer)
+ local_gb = Column(Integer)
+
+ hostname = Column(String(255))
+ host = Column(String(255)) # , ForeignKey('hosts.id'))
+
+ instance_type = Column(String(255))
+
+ user_data = Column(Text)
+
+ reservation_id = Column(String(255))
+ mac_address = Column(String(255))
+
+ scheduled_at = Column(DateTime)
+ launched_at = Column(DateTime)
+ terminated_at = Column(DateTime)
+
+ display_name = Column(String(255))
+ display_description = Column(String(255))
+
+ # TODO(vish): see Ewan's email about state improvements, probably
+ # should be in a driver base class or some such
+ # vmstate_state = running, halted, suspended, paused
+ # power_state = what we have
+ # task_state = transitory and may trigger power state transition
+
+ #@validates('state')
+ #def validate_state(self, key, state):
+ # assert(state in ['nostate', 'running', 'blocked', 'paused',
+ # 'shutdown', 'shutoff', 'crashed'])
+
+
+class Volume(BASE, NovaBase):
+ """Represents a block storage device that can be attached to a vm"""
+ __tablename__ = 'volumes'
+ __prefix__ = 'vol'
+ id = Column(Integer, primary_key=True)
+ ec2_id = Column(String(12), unique=True)
+
+ user_id = Column(String(255))
+ project_id = Column(String(255))
+
+ host = Column(String(255)) # , ForeignKey('hosts.id'))
+ size = Column(Integer)
+ availability_zone = Column(String(255)) # TODO(vish): foreign key?
+ instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
+ instance = relationship(Instance,
+ backref=backref('volumes'),
+ foreign_keys=instance_id,
+ primaryjoin='and_(Volume.instance_id==Instance.id,'
+ 'Volume.deleted==False)')
+ mountpoint = Column(String(255))
+ attach_time = Column(String(255)) # TODO(vish): datetime
+ status = Column(String(255)) # TODO(vish): enum?
+ attach_status = Column(String(255)) # TODO(vish): enum
+
+ scheduled_at = Column(DateTime)
+ launched_at = Column(DateTime)
+ terminated_at = Column(DateTime)
+
+ display_name = Column(String(255))
+ display_description = Column(String(255))
+
+
+class Quota(BASE, NovaBase):
+ """Represents quota overrides for a project"""
+ __tablename__ = 'quotas'
+ id = Column(Integer, primary_key=True)
+
+ project_id = Column(String(255))
+
+ instances = Column(Integer)
+ cores = Column(Integer)
+ volumes = Column(Integer)
+ gigabytes = Column(Integer)
+ floating_ips = Column(Integer)
+
+ @property
+ def str_id(self):
+ return self.project_id
+
+
+class ExportDevice(BASE, NovaBase):
+ """Represates a shelf and blade that a volume can be exported on"""
+ __tablename__ = 'export_devices'
+ __table_args__ = (schema.UniqueConstraint("shelf_id", "blade_id"), {'mysql_engine': 'InnoDB'})
+ id = Column(Integer, primary_key=True)
+ shelf_id = Column(Integer)
+ blade_id = Column(Integer)
+ volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True)
+ volume = relationship(Volume,
+ backref=backref('export_device', uselist=False),
+ foreign_keys=volume_id,
+ primaryjoin='and_(ExportDevice.volume_id==Volume.id,'
+ 'ExportDevice.deleted==False)')
+
+
+class SecurityGroupInstanceAssociation(BASE, NovaBase):
+ __tablename__ = 'security_group_instance_association'
+ id = Column(Integer, primary_key=True)
+ security_group_id = Column(Integer, ForeignKey('security_groups.id'))
+ instance_id = Column(Integer, ForeignKey('instances.id'))
+
+
+class SecurityGroup(BASE, NovaBase):
+ """Represents a security group"""
+ __tablename__ = 'security_groups'
+ id = Column(Integer, primary_key=True)
+
+ name = Column(String(255))
+ description = Column(String(255))
+ user_id = Column(String(255))
+ project_id = Column(String(255))
+
+ instances = relationship(Instance,
+ secondary="security_group_instance_association",
+ primaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id,"
+ "SecurityGroup.deleted == False)",
+ secondaryjoin="and_(SecurityGroupInstanceAssociation.instance_id == Instance.id,"
+ "Instance.deleted == False)",
+ backref='security_groups')
+
+ @property
+ def user(self):
+ return auth.manager.AuthManager().get_user(self.user_id)
+
+ @property
+ def project(self):
+ return auth.manager.AuthManager().get_project(self.project_id)
+
+
+class SecurityGroupIngressRule(BASE, NovaBase):
+ """Represents a rule in a security group"""
+ __tablename__ = 'security_group_rules'
+ id = Column(Integer, primary_key=True)
+
+ parent_group_id = Column(Integer, ForeignKey('security_groups.id'))
+ parent_group = relationship("SecurityGroup", backref="rules",
+ foreign_keys=parent_group_id,
+ primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id,"
+ "SecurityGroupIngressRule.deleted == False)")
+
+ protocol = Column(String(5)) # "tcp", "udp", or "icmp"
+ from_port = Column(Integer)
+ to_port = Column(Integer)
+ cidr = Column(String(255))
+
+ # Note: This is not the parent SecurityGroup. It's SecurityGroup we're
+ # granting access for.
+ group_id = Column(Integer, ForeignKey('security_groups.id'))
+
+
+class KeyPair(BASE, NovaBase):
+ """Represents a public key pair for ssh"""
+ __tablename__ = 'key_pairs'
+ id = Column(Integer, primary_key=True)
+
+ name = Column(String(255))
+
+ user_id = Column(String(255))
+
+ fingerprint = Column(String(255))
+ public_key = Column(Text)
+
+ @property
+ def str_id(self):
+ return '%s.%s' % (self.user_id, self.name)
+
+
+class Network(BASE, NovaBase):
+ """Represents a network"""
+ __tablename__ = 'networks'
+ __table_args__ = (schema.UniqueConstraint("vpn_public_address",
+ "vpn_public_port"),
+ {'mysql_engine': 'InnoDB'})
+ id = Column(Integer, primary_key=True)
+
+ injected = Column(Boolean, default=False)
+ cidr = Column(String(255), unique=True)
+ netmask = Column(String(255))
+ bridge = Column(String(255))
+ gateway = Column(String(255))
+ broadcast = Column(String(255))
+ dns = Column(String(255))
+
+ vlan = Column(Integer)
+ vpn_public_address = Column(String(255))
+ vpn_public_port = Column(Integer)
+ vpn_private_address = Column(String(255))
+ dhcp_start = Column(String(255))
+
+ # NOTE(vish): The unique constraint below helps avoid a race condition
+ # when associating a network, but it also means that we
+ # can't associate two networks with one project.
+ project_id = Column(String(255), unique=True)
+ host = Column(String(255)) # , ForeignKey('hosts.id'))
+
+
+class AuthToken(BASE, NovaBase):
+ """Represents an authorization token for all API transactions. Fields
+ are a string representing the actual token and a user id for mapping
+ to the actual user"""
+ __tablename__ = 'auth_tokens'
+ token_hash = Column(String(255), primary_key=True)
+ user_id = Column(Integer)
+ server_manageent_url = Column(String(255))
+ storage_url = Column(String(255))
+ cdn_management_url = Column(String(255))
+
+
+# TODO(vish): can these both come from the same baseclass?
+class FixedIp(BASE, NovaBase):
+ """Represents a fixed ip for an instance"""
+ __tablename__ = 'fixed_ips'
+ id = Column(Integer, primary_key=True)
+ address = Column(String(255))
+ network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
+ network = relationship(Network, backref=backref('fixed_ips'))
+ instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
+ instance = relationship(Instance,
+ backref=backref('fixed_ip', uselist=False),
+ foreign_keys=instance_id,
+ primaryjoin='and_(FixedIp.instance_id==Instance.id,'
+ 'FixedIp.deleted==False)')
+ allocated = Column(Boolean, default=False)
+ leased = Column(Boolean, default=False)
+ reserved = Column(Boolean, default=False)
+
+ @property
+ def str_id(self):
+ return self.address
+
+
+class User(BASE, NovaBase):
+ """Represents a user"""
+ __tablename__ = 'users'
+ id = Column(String(255), primary_key=True)
+
+ name = Column(String(255))
+ access_key = Column(String(255))
+ secret_key = Column(String(255))
+
+ is_admin = Column(Boolean)
+
+
+class Project(BASE, NovaBase):
+ """Represents a project"""
+ __tablename__ = 'projects'
+ id = Column(String(255), primary_key=True)
+ name = Column(String(255))
+ description = Column(String(255))
+
+ project_manager = Column(String(255), ForeignKey(User.id))
+
+ members = relationship(User,
+ secondary='user_project_association',
+ backref='projects')
+
+
+class UserProjectRoleAssociation(BASE, NovaBase):
+ __tablename__ = 'user_project_role_association'
+ user_id = Column(String(255), primary_key=True)
+ user = relationship(User,
+ primaryjoin=user_id==User.id,
+ foreign_keys=[User.id],
+ uselist=False)
+
+ project_id = Column(String(255), primary_key=True)
+ project = relationship(Project,
+ primaryjoin=project_id==Project.id,
+ foreign_keys=[Project.id],
+ uselist=False)
+
+ role = Column(String(255), primary_key=True)
+ ForeignKeyConstraint(['user_id',
+ 'project_id'],
+ ['user_project_association.user_id',
+ 'user_project_association.project_id'])
+
+
+class UserRoleAssociation(BASE, NovaBase):
+ __tablename__ = 'user_role_association'
+ user_id = Column(String(255), ForeignKey('users.id'), primary_key=True)
+ user = relationship(User, backref='roles')
+ role = Column(String(255), primary_key=True)
+
+
+class UserProjectAssociation(BASE, NovaBase):
+ __tablename__ = 'user_project_association'
+ user_id = Column(String(255), ForeignKey(User.id), primary_key=True)
+ project_id = Column(String(255), ForeignKey(Project.id), primary_key=True)
+
+
+
+class FloatingIp(BASE, NovaBase):
+ """Represents a floating ip that dynamically forwards to a fixed ip"""
+ __tablename__ = 'floating_ips'
+ id = Column(Integer, primary_key=True)
+ address = Column(String(255))
+ fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True)
+ fixed_ip = relationship(FixedIp,
+ backref=backref('floating_ips'),
+ foreign_keys=fixed_ip_id,
+ primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,'
+ 'FloatingIp.deleted==False)')
+ project_id = Column(String(255))
+ host = Column(String(255)) # , ForeignKey('hosts.id'))
+
+
+def register_models():
+ """Register Models and create metadata"""
+ from sqlalchemy import create_engine
+ models = (Service, Instance, Volume, ExportDevice, FixedIp,
+ FloatingIp, Network, SecurityGroup,
+ SecurityGroupIngressRule, SecurityGroupInstanceAssociation,
+ AuthToken, User, Project) # , Image, Host
+ engine = create_engine(FLAGS.sql_connection, echo=False)
+ for model in models:
+ model.metadata.create_all(engine)
diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py
new file mode 100644
index 000000000..826754f6a
--- /dev/null
+++ b/nova/db/sqlalchemy/session.py
@@ -0,0 +1,43 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Session Handling for SQLAlchemy backend
+"""
+
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from nova import flags
+
+FLAGS = flags.FLAGS
+
+_ENGINE = None
+_MAKER = None
+
+def get_session(autocommit=True, expire_on_commit=False):
+ """Helper method to grab session"""
+ global _ENGINE
+ global _MAKER
+ if not _MAKER:
+ if not _ENGINE:
+ _ENGINE = create_engine(FLAGS.sql_connection, echo=False)
+ _MAKER = (sessionmaker(bind=_ENGINE,
+ autocommit=autocommit,
+ expire_on_commit=expire_on_commit))
+ session = _MAKER()
+ return session
diff --git a/nova/endpoint/api.py b/nova/endpoint/api.py
deleted file mode 100755
index 40be00bb7..000000000
--- a/nova/endpoint/api.py
+++ /dev/null
@@ -1,344 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Tornado REST API Request Handlers for Nova functions
-Most calls are proxied into the responsible controller.
-"""
-
-import logging
-import multiprocessing
-import random
-import re
-import urllib
-# TODO(termie): replace minidom with etree
-from xml.dom import minidom
-
-import tornado.web
-from twisted.internet import defer
-
-from nova import crypto
-from nova import exception
-from nova import flags
-from nova import utils
-from nova.auth import manager
-import nova.cloudpipe.api
-from nova.endpoint import cloud
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer('cc_port', 8773, 'cloud controller port')
-
-
-_log = logging.getLogger("api")
-_log.setLevel(logging.DEBUG)
-
-
-_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
-
-
-def _camelcase_to_underscore(str):
- return _c2u.sub(r'_\1', str).lower().strip('_')
-
-
-def _underscore_to_camelcase(str):
- return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
-
-
-def _underscore_to_xmlcase(str):
- res = _underscore_to_camelcase(str)
- return res[:1].lower() + res[1:]
-
-
-class APIRequestContext(object):
- def __init__(self, handler, user, project):
- self.handler = handler
- self.user = user
- self.project = project
- self.request_id = ''.join(
- [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
- for x in xrange(20)]
- )
-
-
-class APIRequest(object):
- def __init__(self, controller, action):
- self.controller = controller
- self.action = action
-
- def send(self, context, **kwargs):
-
- try:
- method = getattr(self.controller,
- _camelcase_to_underscore(self.action))
- except AttributeError:
- _error = ('Unsupported API request: controller = %s,'
- 'action = %s') % (self.controller, self.action)
- _log.warning(_error)
- # TODO: Raise custom exception, trap in apiserver,
- # and reraise as 400 error.
- raise Exception(_error)
-
- args = {}
- for key, value in kwargs.items():
- parts = key.split(".")
- key = _camelcase_to_underscore(parts[0])
- if len(parts) > 1:
- d = args.get(key, {})
- d[parts[1]] = value[0]
- value = d
- else:
- value = value[0]
- args[key] = value
-
- for key in args.keys():
- if isinstance(args[key], dict):
- if args[key] != {} and args[key].keys()[0].isdigit():
- s = args[key].items()
- s.sort()
- args[key] = [v for k, v in s]
-
- d = defer.maybeDeferred(method, context, **args)
- d.addCallback(self._render_response, context.request_id)
- return d
-
- def _render_response(self, response_data, request_id):
- xml = minidom.Document()
-
- response_el = xml.createElement(self.action + 'Response')
- response_el.setAttribute('xmlns',
- 'http://ec2.amazonaws.com/doc/2009-11-30/')
- request_id_el = xml.createElement('requestId')
- request_id_el.appendChild(xml.createTextNode(request_id))
- response_el.appendChild(request_id_el)
- if(response_data == True):
- self._render_dict(xml, response_el, {'return': 'true'})
- else:
- self._render_dict(xml, response_el, response_data)
-
- xml.appendChild(response_el)
-
- response = xml.toxml()
- xml.unlink()
- _log.debug(response)
- return response
-
- def _render_dict(self, xml, el, data):
- try:
- for key in data.keys():
- val = data[key]
- el.appendChild(self._render_data(xml, key, val))
- except:
- _log.debug(data)
- raise
-
- def _render_data(self, xml, el_name, data):
- el_name = _underscore_to_xmlcase(el_name)
- data_el = xml.createElement(el_name)
-
- if isinstance(data, list):
- for item in data:
- data_el.appendChild(self._render_data(xml, 'item', item))
- elif isinstance(data, dict):
- self._render_dict(xml, data_el, data)
- elif hasattr(data, '__dict__'):
- self._render_dict(xml, data_el, data.__dict__)
- elif isinstance(data, bool):
- data_el.appendChild(xml.createTextNode(str(data).lower()))
- elif data != None:
- data_el.appendChild(xml.createTextNode(str(data)))
-
- return data_el
-
-
-class RootRequestHandler(tornado.web.RequestHandler):
- def get(self):
- # available api versions
- versions = [
- '1.0',
- '2007-01-19',
- '2007-03-01',
- '2007-08-29',
- '2007-10-10',
- '2007-12-15',
- '2008-02-01',
- '2008-09-01',
- '2009-04-04',
- ]
- for version in versions:
- self.write('%s\n' % version)
- self.finish()
-
-
-class MetadataRequestHandler(tornado.web.RequestHandler):
- def print_data(self, data):
- if isinstance(data, dict):
- output = ''
- for key in data:
- if key == '_name':
- continue
- output += key
- if isinstance(data[key], dict):
- if '_name' in data[key]:
- output += '=' + str(data[key]['_name'])
- else:
- output += '/'
- output += '\n'
- self.write(output[:-1]) # cut off last \n
- elif isinstance(data, list):
- self.write('\n'.join(data))
- else:
- self.write(str(data))
-
- def lookup(self, path, data):
- items = path.split('/')
- for item in items:
- if item:
- if not isinstance(data, dict):
- return data
- if not item in data:
- return None
- data = data[item]
- return data
-
- def get(self, path):
- cc = self.application.controllers['Cloud']
- meta_data = cc.get_metadata(self.request.remote_ip)
- if meta_data is None:
- _log.error('Failed to get metadata for ip: %s' %
- self.request.remote_ip)
- raise tornado.web.HTTPError(404)
- data = self.lookup(path, meta_data)
- if data is None:
- raise tornado.web.HTTPError(404)
- self.print_data(data)
- self.finish()
-
-
-class APIRequestHandler(tornado.web.RequestHandler):
- def get(self, controller_name):
- self.execute(controller_name)
-
- @tornado.web.asynchronous
- def execute(self, controller_name):
- # Obtain the appropriate controller for this request.
- try:
- controller = self.application.controllers[controller_name]
- except KeyError:
- self._error('unhandled', 'no controller named %s' % controller_name)
- return
-
- args = self.request.arguments
-
- # Read request signature.
- try:
- signature = args.pop('Signature')[0]
- except:
- raise tornado.web.HTTPError(400)
-
- # Make a copy of args for authentication and signature verification.
- auth_params = {}
- for key, value in args.items():
- auth_params[key] = value[0]
-
- # Get requested action and remove authentication args for final request.
- try:
- action = args.pop('Action')[0]
- access = args.pop('AWSAccessKeyId')[0]
- args.pop('SignatureMethod')
- args.pop('SignatureVersion')
- args.pop('Version')
- args.pop('Timestamp')
- except:
- raise tornado.web.HTTPError(400)
-
- # Authenticate the request.
- try:
- (user, project) = manager.AuthManager().authenticate(
- access,
- signature,
- auth_params,
- self.request.method,
- self.request.host,
- self.request.path
- )
-
- except exception.Error, ex:
- logging.debug("Authentication Failure: %s" % ex)
- raise tornado.web.HTTPError(403)
-
- _log.debug('action: %s' % action)
-
- for key, value in args.items():
- _log.debug('arg: %s\t\tval: %s' % (key, value))
-
- request = APIRequest(controller, action)
- context = APIRequestContext(self, user, project)
- d = request.send(context, **args)
- # d.addCallback(utils.debug)
-
- # TODO: Wrap response in AWS XML format
- d.addCallbacks(self._write_callback, self._error_callback)
-
- def _write_callback(self, data):
- self.set_header('Content-Type', 'text/xml')
- self.write(data)
- self.finish()
-
- def _error_callback(self, failure):
- try:
- failure.raiseException()
- except exception.ApiError as ex:
- self._error(type(ex).__name__ + "." + ex.code, ex.message)
- # TODO(vish): do something more useful with unknown exceptions
- except Exception as ex:
- self._error(type(ex).__name__, str(ex))
- raise
-
- def post(self, controller_name):
- self.execute(controller_name)
-
- def _error(self, code, message):
- self._status_code = 400
- self.set_header('Content-Type', 'text/xml')
- self.write('<?xml version="1.0"?>\n')
- self.write('<Response><Errors><Error><Code>%s</Code>'
- '<Message>%s</Message></Error></Errors>'
- '<RequestID>?</RequestID></Response>' % (code, message))
- self.finish()
-
-
-class APIServerApplication(tornado.web.Application):
- def __init__(self, controllers):
- tornado.web.Application.__init__(self, [
- (r'/', RootRequestHandler),
- (r'/cloudpipe/(.*)', nova.cloudpipe.api.CloudPipeRequestHandler),
- (r'/cloudpipe', nova.cloudpipe.api.CloudPipeRequestHandler),
- (r'/services/([A-Za-z0-9]+)/', APIRequestHandler),
- (r'/latest/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2009-04-04/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2008-09-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2008-02-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-12-15/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-10-10/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-08-29/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-03-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-01-19/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/1.0/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- ], pool=multiprocessing.Pool(4))
- self.controllers = controllers
diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py
deleted file mode 100644
index a28b888f3..000000000
--- a/nova/endpoint/cloud.py
+++ /dev/null
@@ -1,745 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Cloud Controller: Implementation of EC2 REST API calls, which are
-dispatched to other nodes via AMQP RPC. State is via distributed
-datastore.
-"""
-
-import base64
-import logging
-import os
-import time
-
-from twisted.internet import defer
-
-from nova import datastore
-from nova import exception
-from nova import flags
-from nova import rpc
-from nova import utils
-from nova.auth import rbac
-from nova.auth import manager
-from nova.compute import model
-from nova.compute.instance_types import INSTANCE_TYPES
-from nova.endpoint import images
-from nova.network import service as network_service
-from nova.network import model as network_model
-from nova.volume import service
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('cloud_topic', 'cloud', 'the topic clouds listen on')
-
-
-def _gen_key(user_id, key_name):
- """ Tuck this into AuthManager """
- try:
- mgr = manager.AuthManager()
- private_key, fingerprint = mgr.generate_key_pair(user_id, key_name)
- except Exception as ex:
- return {'exception': ex}
- return {'private_key': private_key, 'fingerprint': fingerprint}
-
-
-class CloudController(object):
- """ CloudController provides the critical dispatch between
- inbound API calls through the endpoint and messages
- sent to the other nodes.
-"""
- def __init__(self):
- self.instdir = model.InstanceDirectory()
- self.setup()
-
- @property
- def instances(self):
- """ All instances in the system, as dicts """
- return self.instdir.all
-
- @property
- def volumes(self):
- """ returns a list of all volumes """
- for volume_id in datastore.Redis.instance().smembers("volumes"):
- volume = service.get_volume(volume_id)
- yield volume
-
- def __str__(self):
- return 'CloudController'
-
- def setup(self):
- """ Ensure the keychains and folders exist. """
- # Create keys folder, if it doesn't exist
- if not os.path.exists(FLAGS.keys_path):
- os.makedirs(FLAGS.keys_path)
- # Gen root CA, if we don't have one
- root_ca_path = os.path.join(FLAGS.ca_path, FLAGS.ca_file)
- if not os.path.exists(root_ca_path):
- start = os.getcwd()
- os.chdir(FLAGS.ca_path)
- utils.runthis("Generating root CA: %s", "sh genrootca.sh")
- os.chdir(start)
- # TODO: Do this with M2Crypto instead
-
- def get_instance_by_ip(self, ip):
- return self.instdir.by_ip(ip)
-
- def _get_mpi_data(self, project_id):
- result = {}
- for instance in self.instdir.all:
- if instance['project_id'] == project_id:
- line = '%s slots=%d' % (instance['private_dns_name'],
- INSTANCE_TYPES[instance['instance_type']]['vcpus'])
- if instance['key_name'] in result:
- result[instance['key_name']].append(line)
- else:
- result[instance['key_name']] = [line]
- return result
-
- def get_metadata(self, ipaddress):
- i = self.get_instance_by_ip(ipaddress)
- if i is None:
- return None
- mpi = self._get_mpi_data(i['project_id'])
- if i['key_name']:
- keys = {
- '0': {
- '_name': i['key_name'],
- 'openssh-key': i['key_data']
- }
- }
- else:
- keys = ''
-
- address_record = network_model.FixedIp(i['private_dns_name'])
- if address_record:
- hostname = address_record['hostname']
- else:
- hostname = 'ip-%s' % i['private_dns_name'].replace('.', '-')
- data = {
- 'user-data': base64.b64decode(i['user_data']),
- 'meta-data': {
- 'ami-id': i['image_id'],
- 'ami-launch-index': i['ami_launch_index'],
- 'ami-manifest-path': 'FIXME', # image property
- 'block-device-mapping': { # TODO: replace with real data
- 'ami': 'sda1',
- 'ephemeral0': 'sda2',
- 'root': '/dev/sda1',
- 'swap': 'sda3'
- },
- 'hostname': hostname,
- 'instance-action': 'none',
- 'instance-id': i['instance_id'],
- 'instance-type': i.get('instance_type', ''),
- 'local-hostname': hostname,
- 'local-ipv4': i['private_dns_name'], # TODO: switch to IP
- 'kernel-id': i.get('kernel_id', ''),
- 'placement': {
- 'availaibility-zone': i.get('availability_zone', 'nova'),
- },
- 'public-hostname': hostname,
- 'public-ipv4': i.get('dns_name', ''), # TODO: switch to IP
- 'public-keys': keys,
- 'ramdisk-id': i.get('ramdisk_id', ''),
- 'reservation-id': i['reservation_id'],
- 'security-groups': i.get('groups', ''),
- 'mpi': mpi
- }
- }
- if False: # TODO: store ancestor ids
- data['ancestor-ami-ids'] = []
- if i.get('product_codes', None):
- data['product-codes'] = i['product_codes']
- return data
-
- @rbac.allow('all')
- def describe_availability_zones(self, context, **kwargs):
- return {'availabilityZoneInfo': [{'zoneName': 'nova',
- 'zoneState': 'available'}]}
-
- @rbac.allow('all')
- def describe_regions(self, context, region_name=None, **kwargs):
- # TODO(vish): region_name is an array. Support filtering
- return {'regionInfo': [{'regionName': 'nova',
- 'regionUrl': FLAGS.ec2_url}]}
-
- @rbac.allow('all')
- def describe_snapshots(self,
- context,
- snapshot_id=None,
- owner=None,
- restorable_by=None,
- **kwargs):
- return {'snapshotSet': [{'snapshotId': 'fixme',
- 'volumeId': 'fixme',
- 'status': 'fixme',
- 'startTime': 'fixme',
- 'progress': 'fixme',
- 'ownerId': 'fixme',
- 'volumeSize': 0,
- 'description': 'fixme'}]}
-
- @rbac.allow('all')
- def describe_key_pairs(self, context, key_name=None, **kwargs):
- key_pairs = context.user.get_key_pairs()
- if not key_name is None:
- key_pairs = [x for x in key_pairs if x.name in key_name]
-
- result = []
- for key_pair in key_pairs:
- # filter out the vpn keys
- suffix = FLAGS.vpn_key_suffix
- if context.user.is_admin() or not key_pair.name.endswith(suffix):
- result.append({
- 'keyName': key_pair.name,
- 'keyFingerprint': key_pair.fingerprint,
- })
-
- return {'keypairsSet': result}
-
- @rbac.allow('all')
- def create_key_pair(self, context, key_name, **kwargs):
- dcall = defer.Deferred()
- pool = context.handler.application.settings.get('pool')
- def _complete(kwargs):
- if 'exception' in kwargs:
- dcall.errback(kwargs['exception'])
- return
- dcall.callback({'keyName': key_name,
- 'keyFingerprint': kwargs['fingerprint'],
- 'keyMaterial': kwargs['private_key']})
- pool.apply_async(_gen_key, [context.user.id, key_name],
- callback=_complete)
- return dcall
-
- @rbac.allow('all')
- def delete_key_pair(self, context, key_name, **kwargs):
- context.user.delete_key_pair(key_name)
- # aws returns true even if the key doens't exist
- return True
-
- @rbac.allow('all')
- def describe_security_groups(self, context, group_names, **kwargs):
- groups = {'securityGroupSet': []}
-
- # Stubbed for now to unblock other things.
- return groups
-
- @rbac.allow('netadmin')
- def create_security_group(self, context, group_name, **kwargs):
- return True
-
- @rbac.allow('netadmin')
- def delete_security_group(self, context, group_name, **kwargs):
- return True
-
- @rbac.allow('projectmanager', 'sysadmin')
- def get_console_output(self, context, instance_id, **kwargs):
- # instance_id is passed in as a list of instances
- instance = self._get_instance(context, instance_id[0])
- return rpc.call('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
- {"method": "get_console_output",
- "args": {"instance_id": instance_id[0]}})
-
- def _get_user_id(self, context):
- if context and context.user:
- return context.user.id
- else:
- return None
-
- @rbac.allow('projectmanager', 'sysadmin')
- def describe_volumes(self, context, **kwargs):
- volumes = []
- for volume in self.volumes:
- if context.user.is_admin() or volume['project_id'] == context.project.id:
- v = self.format_volume(context, volume)
- volumes.append(v)
- return defer.succeed({'volumeSet': volumes})
-
- def format_volume(self, context, volume):
- v = {}
- v['volumeId'] = volume['volume_id']
- v['status'] = volume['status']
- v['size'] = volume['size']
- v['availabilityZone'] = volume['availability_zone']
- v['createTime'] = volume['create_time']
- if context.user.is_admin():
- v['status'] = '%s (%s, %s, %s, %s)' % (
- volume.get('status', None),
- volume.get('user_id', None),
- volume.get('node_name', None),
- volume.get('instance_id', ''),
- volume.get('mountpoint', ''))
- if volume['attach_status'] == 'attached':
- v['attachmentSet'] = [{'attachTime': volume['attach_time'],
- 'deleteOnTermination': volume['delete_on_termination'],
- 'device': volume['mountpoint'],
- 'instanceId': volume['instance_id'],
- 'status': 'attached',
- 'volume_id': volume['volume_id']}]
- else:
- v['attachmentSet'] = [{}]
- return v
-
- @rbac.allow('projectmanager', 'sysadmin')
- @defer.inlineCallbacks
- def create_volume(self, context, size, **kwargs):
- # TODO(vish): refactor this to create the volume object here and tell service to create it
- result = yield rpc.call(FLAGS.volume_topic, {"method": "create_volume",
- "args": {"size": size,
- "user_id": context.user.id,
- "project_id": context.project.id}})
- # NOTE(vish): rpc returned value is in the result key in the dictionary
- volume = self._get_volume(context, result)
- defer.returnValue({'volumeSet': [self.format_volume(context, volume)]})
-
- def _get_address(self, context, public_ip):
- # FIXME(vish) this should move into network.py
- address = network_model.ElasticIp.lookup(public_ip)
- if address and (context.user.is_admin() or address['project_id'] == context.project.id):
- return address
- raise exception.NotFound("Address at ip %s not found" % public_ip)
-
- def _get_image(self, context, image_id):
- """passes in context because
- objectstore does its own authorization"""
- result = images.list(context, [image_id])
- if not result:
- raise exception.NotFound('Image %s could not be found' % image_id)
- image = result[0]
- return image
-
- def _get_instance(self, context, instance_id):
- for instance in self.instdir.all:
- if instance['instance_id'] == instance_id:
- if context.user.is_admin() or instance['project_id'] == context.project.id:
- return instance
- raise exception.NotFound('Instance %s could not be found' % instance_id)
-
- def _get_volume(self, context, volume_id):
- volume = service.get_volume(volume_id)
- if context.user.is_admin() or volume['project_id'] == context.project.id:
- return volume
- raise exception.NotFound('Volume %s could not be found' % volume_id)
-
- @rbac.allow('projectmanager', 'sysadmin')
- def attach_volume(self, context, volume_id, instance_id, device, **kwargs):
- volume = self._get_volume(context, volume_id)
- if volume['status'] == "attached":
- raise exception.ApiError("Volume is already attached")
- # TODO(vish): looping through all volumes is slow. We should probably maintain an index
- for vol in self.volumes:
- if vol['instance_id'] == instance_id and vol['mountpoint'] == device:
- raise exception.ApiError("Volume %s is already attached to %s" % (vol['volume_id'], vol['mountpoint']))
- volume.start_attach(instance_id, device)
- instance = self._get_instance(context, instance_id)
- compute_node = instance['node_name']
- rpc.cast('%s.%s' % (FLAGS.compute_topic, compute_node),
- {"method": "attach_volume",
- "args": {"volume_id": volume_id,
- "instance_id": instance_id,
- "mountpoint": device}})
- return defer.succeed({'attachTime': volume['attach_time'],
- 'device': volume['mountpoint'],
- 'instanceId': instance_id,
- 'requestId': context.request_id,
- 'status': volume['attach_status'],
- 'volumeId': volume_id})
-
- @rbac.allow('projectmanager', 'sysadmin')
- def detach_volume(self, context, volume_id, **kwargs):
- volume = self._get_volume(context, volume_id)
- instance_id = volume.get('instance_id', None)
- if not instance_id:
- raise exception.Error("Volume isn't attached to anything!")
- if volume['status'] == "available":
- raise exception.Error("Volume is already detached")
- try:
- volume.start_detach()
- instance = self._get_instance(context, instance_id)
- rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
- {"method": "detach_volume",
- "args": {"instance_id": instance_id,
- "volume_id": volume_id}})
- except exception.NotFound:
- # If the instance doesn't exist anymore,
- # then we need to call detach blind
- volume.finish_detach()
- return defer.succeed({'attachTime': volume['attach_time'],
- 'device': volume['mountpoint'],
- 'instanceId': instance_id,
- 'requestId': context.request_id,
- 'status': volume['attach_status'],
- 'volumeId': volume_id})
-
- def _convert_to_set(self, lst, label):
- if lst == None or lst == []:
- return None
- if not isinstance(lst, list):
- lst = [lst]
- return [{label: x} for x in lst]
-
- @rbac.allow('all')
- def describe_instances(self, context, **kwargs):
- return defer.succeed(self._format_describe_instances(context))
-
- def _format_describe_instances(self, context):
- return { 'reservationSet': self._format_instances(context) }
-
- def _format_run_instances(self, context, reservation_id):
- i = self._format_instances(context, reservation_id)
- assert len(i) == 1
- return i[0]
-
- def _format_instances(self, context, reservation_id = None):
- reservations = {}
- if context.user.is_admin():
- instgenerator = self.instdir.all
- else:
- instgenerator = self.instdir.by_project(context.project.id)
- for instance in instgenerator:
- res_id = instance.get('reservation_id', 'Unknown')
- if reservation_id != None and reservation_id != res_id:
- continue
- if not context.user.is_admin():
- if instance['image_id'] == FLAGS.vpn_image_id:
- continue
- i = {}
- i['instance_id'] = instance.get('instance_id', None)
- i['image_id'] = instance.get('image_id', None)
- i['instance_state'] = {
- 'code': instance.get('state', 0),
- 'name': instance.get('state_description', 'pending')
- }
- i['public_dns_name'] = network_model.get_public_ip_for_instance(
- i['instance_id'])
- i['private_dns_name'] = instance.get('private_dns_name', None)
- if not i['public_dns_name']:
- i['public_dns_name'] = i['private_dns_name']
- i['dns_name'] = instance.get('dns_name', None)
- i['key_name'] = instance.get('key_name', None)
- if context.user.is_admin():
- i['key_name'] = '%s (%s, %s)' % (i['key_name'],
- instance.get('project_id', None),
- instance.get('node_name', ''))
- i['product_codes_set'] = self._convert_to_set(
- instance.get('product_codes', None), 'product_code')
- i['instance_type'] = instance.get('instance_type', None)
- i['launch_time'] = instance.get('launch_time', None)
- i['ami_launch_index'] = instance.get('ami_launch_index',
- None)
- if not reservations.has_key(res_id):
- r = {}
- r['reservation_id'] = res_id
- r['owner_id'] = instance.get('project_id', None)
- r['group_set'] = self._convert_to_set(
- instance.get('groups', None), 'group_id')
- r['instances_set'] = []
- reservations[res_id] = r
- reservations[res_id]['instances_set'].append(i)
-
- return list(reservations.values())
-
- @rbac.allow('all')
- def describe_addresses(self, context, **kwargs):
- return self.format_addresses(context)
-
- def format_addresses(self, context):
- addresses = []
- for address in network_model.ElasticIp.all():
- # TODO(vish): implement a by_project iterator for addresses
- if (context.user.is_admin() or
- address['project_id'] == context.project.id):
- address_rv = {
- 'public_ip': address['address'],
- 'instance_id': address.get('instance_id', 'free')
- }
- if context.user.is_admin():
- address_rv['instance_id'] = "%s (%s, %s)" % (
- address['instance_id'],
- address['user_id'],
- address['project_id'],
- )
- addresses.append(address_rv)
- return {'addressesSet': addresses}
-
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
- def allocate_address(self, context, **kwargs):
- network_topic = yield self._get_network_topic(context)
- public_ip = yield rpc.call(network_topic,
- {"method": "allocate_elastic_ip",
- "args": {"user_id": context.user.id,
- "project_id": context.project.id}})
- defer.returnValue({'addressSet': [{'publicIp': public_ip}]})
-
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
- def release_address(self, context, public_ip, **kwargs):
- # NOTE(vish): Should we make sure this works?
- network_topic = yield self._get_network_topic(context)
- rpc.cast(network_topic,
- {"method": "deallocate_elastic_ip",
- "args": {"elastic_ip": public_ip}})
- defer.returnValue({'releaseResponse': ["Address released."]})
-
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
- def associate_address(self, context, instance_id, public_ip, **kwargs):
- instance = self._get_instance(context, instance_id)
- address = self._get_address(context, public_ip)
- network_topic = yield self._get_network_topic(context)
- rpc.cast(network_topic,
- {"method": "associate_elastic_ip",
- "args": {"elastic_ip": address['address'],
- "fixed_ip": instance['private_dns_name'],
- "instance_id": instance['instance_id']}})
- defer.returnValue({'associateResponse': ["Address associated."]})
-
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
- def disassociate_address(self, context, public_ip, **kwargs):
- address = self._get_address(context, public_ip)
- network_topic = yield self._get_network_topic(context)
- rpc.cast(network_topic,
- {"method": "disassociate_elastic_ip",
- "args": {"elastic_ip": address['address']}})
- defer.returnValue({'disassociateResponse': ["Address disassociated."]})
-
- @defer.inlineCallbacks
- def _get_network_topic(self, context):
- """Retrieves the network host for a project"""
- host = network_service.get_host_for_project(context.project.id)
- if not host:
- host = yield rpc.call(FLAGS.network_topic,
- {"method": "set_network_host",
- "args": {"user_id": context.user.id,
- "project_id": context.project.id}})
- defer.returnValue('%s.%s' %(FLAGS.network_topic, host))
-
- @rbac.allow('projectmanager', 'sysadmin')
- @defer.inlineCallbacks
- def run_instances(self, context, **kwargs):
- # make sure user can access the image
- # vpn image is private so it doesn't show up on lists
- if kwargs['image_id'] != FLAGS.vpn_image_id:
- image = self._get_image(context, kwargs['image_id'])
-
- # FIXME(ja): if image is cloudpipe, this breaks
-
- # get defaults from imagestore
- image_id = image['imageId']
- kernel_id = image.get('kernelId', FLAGS.default_kernel)
- ramdisk_id = image.get('ramdiskId', FLAGS.default_ramdisk)
-
- # API parameters overrides of defaults
- kernel_id = kwargs.get('kernel_id', kernel_id)
- ramdisk_id = kwargs.get('ramdisk_id', ramdisk_id)
-
- if kernel_id == str(FLAGS.null_kernel):
- kernel_id = None
- ramdisk_id = None
-
- # make sure we have access to kernel and ramdisk
- if kernel_id:
- self._get_image(context, kernel_id)
- if ramdisk_id:
- self._get_image(context, ramdisk_id)
-
- logging.debug("Going to run instances...")
- reservation_id = utils.generate_uid('r')
- launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
- key_data = None
- if kwargs.has_key('key_name'):
- key_pair = context.user.get_key_pair(kwargs['key_name'])
- if not key_pair:
- raise exception.ApiError('Key Pair %s not found' %
- kwargs['key_name'])
- key_data = key_pair.public_key
- network_topic = yield self._get_network_topic(context)
- # TODO: Get the real security group of launch in here
- security_group = "default"
- for num in range(int(kwargs['max_count'])):
- is_vpn = False
- if image_id == FLAGS.vpn_image_id:
- is_vpn = True
- inst = self.instdir.new()
- allocate_data = yield rpc.call(network_topic,
- {"method": "allocate_fixed_ip",
- "args": {"user_id": context.user.id,
- "project_id": context.project.id,
- "security_group": security_group,
- "is_vpn": is_vpn,
- "hostname": inst.instance_id}})
- inst['image_id'] = image_id
- inst['kernel_id'] = kernel_id or ''
- inst['ramdisk_id'] = ramdisk_id or ''
- inst['user_data'] = kwargs.get('user_data', '')
- inst['instance_type'] = kwargs.get('instance_type', 'm1.small')
- inst['reservation_id'] = reservation_id
- inst['launch_time'] = launch_time
- inst['key_data'] = key_data or ''
- inst['key_name'] = kwargs.get('key_name', '')
- inst['user_id'] = context.user.id
- inst['project_id'] = context.project.id
- inst['ami_launch_index'] = num
- inst['security_group'] = security_group
- inst['hostname'] = inst.instance_id
- for (key, value) in allocate_data.iteritems():
- inst[key] = value
-
- inst.save()
- rpc.cast(FLAGS.compute_topic,
- {"method": "run_instance",
- "args": {"instance_id": inst.instance_id}})
- logging.debug("Casting to node for %s's instance with IP of %s" %
- (context.user.name, inst['private_dns_name']))
- # TODO: Make Network figure out the network name from ip.
- defer.returnValue(self._format_run_instances(context, reservation_id))
-
- @rbac.allow('projectmanager', 'sysadmin')
- @defer.inlineCallbacks
- def terminate_instances(self, context, instance_id, **kwargs):
- logging.debug("Going to start terminating instances")
- network_topic = yield self._get_network_topic(context)
- for i in instance_id:
- logging.debug("Going to try and terminate %s" % i)
- try:
- instance = self._get_instance(context, i)
- except exception.NotFound:
- logging.warning("Instance %s was not found during terminate"
- % i)
- continue
- elastic_ip = network_model.get_public_ip_for_instance(i)
- if elastic_ip:
- logging.debug("Disassociating address %s" % elastic_ip)
- # NOTE(vish): Right now we don't really care if the ip is
- # disassociated. We may need to worry about
- # checking this later. Perhaps in the scheduler?
- rpc.cast(network_topic,
- {"method": "disassociate_elastic_ip",
- "args": {"elastic_ip": elastic_ip}})
-
- fixed_ip = instance.get('private_dns_name', None)
- if fixed_ip:
- logging.debug("Deallocating address %s" % fixed_ip)
- # NOTE(vish): Right now we don't really care if the ip is
- # actually removed. We may need to worry about
- # checking this later. Perhaps in the scheduler?
- rpc.cast(network_topic,
- {"method": "deallocate_fixed_ip",
- "args": {"fixed_ip": fixed_ip}})
-
- if instance.get('node_name', 'unassigned') != 'unassigned':
- # NOTE(joshua?): It's also internal default
- rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
- {"method": "terminate_instance",
- "args": {"instance_id": i}})
- else:
- instance.destroy()
- defer.returnValue(True)
-
- @rbac.allow('projectmanager', 'sysadmin')
- def reboot_instances(self, context, instance_id, **kwargs):
- """instance_id is a list of instance ids"""
- for i in instance_id:
- instance = self._get_instance(context, i)
- rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
- {"method": "reboot_instance",
- "args": {"instance_id": i}})
- return defer.succeed(True)
-
- @rbac.allow('projectmanager', 'sysadmin')
- def delete_volume(self, context, volume_id, **kwargs):
- # TODO: return error if not authorized
- volume = self._get_volume(context, volume_id)
- volume_node = volume['node_name']
- rpc.cast('%s.%s' % (FLAGS.volume_topic, volume_node),
- {"method": "delete_volume",
- "args": {"volume_id": volume_id}})
- return defer.succeed(True)
-
- @rbac.allow('all')
- def describe_images(self, context, image_id=None, **kwargs):
- # The objectstore does its own authorization for describe
- imageSet = images.list(context, image_id)
- return defer.succeed({'imagesSet': imageSet})
-
- @rbac.allow('projectmanager', 'sysadmin')
- def deregister_image(self, context, image_id, **kwargs):
- # FIXME: should the objectstore be doing these authorization checks?
- images.deregister(context, image_id)
- return defer.succeed({'imageId': image_id})
-
- @rbac.allow('projectmanager', 'sysadmin')
- def register_image(self, context, image_location=None, **kwargs):
- # FIXME: should the objectstore be doing these authorization checks?
- if image_location is None and kwargs.has_key('name'):
- image_location = kwargs['name']
- image_id = images.register(context, image_location)
- logging.debug("Registered %s as %s" % (image_location, image_id))
-
- return defer.succeed({'imageId': image_id})
-
- @rbac.allow('all')
- def describe_image_attribute(self, context, image_id, attribute, **kwargs):
- if attribute != 'launchPermission':
- raise exception.ApiError('attribute not supported: %s' % attribute)
- try:
- image = images.list(context, image_id)[0]
- except IndexError:
- raise exception.ApiError('invalid id: %s' % image_id)
- result = {'image_id': image_id, 'launchPermission': []}
- if image['isPublic']:
- result['launchPermission'].append({'group': 'all'})
- return defer.succeed(result)
-
- @rbac.allow('projectmanager', 'sysadmin')
- def modify_image_attribute(self, context, image_id, attribute, operation_type, **kwargs):
- # TODO(devcamcar): Support users and groups other than 'all'.
- if attribute != 'launchPermission':
- raise exception.ApiError('attribute not supported: %s' % attribute)
- if not 'user_group' in kwargs:
- raise exception.ApiError('user or group not specified')
- if len(kwargs['user_group']) != 1 and kwargs['user_group'][0] != 'all':
- raise exception.ApiError('only group "all" is supported')
- if not operation_type in ['add', 'remove']:
- raise exception.ApiError('operation_type must be add or remove')
- result = images.modify(context, image_id, operation_type)
- return defer.succeed(result)
-
- def update_state(self, topic, value):
- """ accepts status reports from the queue and consolidates them """
- # TODO(jmc): if an instance has disappeared from
- # the node, call instance_death
- if topic == "instances":
- return defer.succeed(True)
- aggregate_state = getattr(self, topic)
- node_name = value.keys()[0]
- items = value[node_name]
-
- logging.debug("Updating %s state for %s" % (topic, node_name))
-
- for item_id in items.keys():
- if (aggregate_state.has_key('pending') and
- aggregate_state['pending'].has_key(item_id)):
- del aggregate_state['pending'][item_id]
- aggregate_state[node_name] = items
-
- return defer.succeed(True)
diff --git a/nova/exception.py b/nova/exception.py
index 29bcb17f8..f157fab2d 100644
--- a/nova/exception.py
+++ b/nova/exception.py
@@ -26,6 +26,18 @@ import sys
import traceback
+class ProcessExecutionError(IOError):
+ def __init__(self, stdout=None, stderr=None, exit_code=None, cmd=None,
+ description=None):
+ if description is None:
+ description = "Unexpected error while running command."
+ if exit_code is None:
+ exit_code = '-'
+ message = "%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r" % (
+ description, cmd, exit_code, stdout, stderr)
+ IOError.__init__(self, message)
+
+
class Error(Exception):
def __init__(self, message=None):
super(Error, self).__init__(message)
@@ -57,6 +69,9 @@ class NotEmpty(Error):
class Invalid(Error):
pass
+class InvalidInputException(Error):
+ pass
+
def wrap_exception(f):
def _wrap(*args, **kw):
diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py
index 068025249..df5e61e6e 100644
--- a/nova/fakerabbit.py
+++ b/nova/fakerabbit.py
@@ -22,6 +22,7 @@ import logging
import Queue as queue
from carrot.backends import base
+from eventlet import greenthread
class Message(base.BaseMessage):
@@ -38,6 +39,7 @@ class Exchange(object):
def publish(self, message, routing_key=None):
logging.debug('(%s) publish (key: %s) %s',
self.name, routing_key, message)
+ routing_key = routing_key.split('.')[0]
if routing_key in self._routes:
for f in self._routes[routing_key]:
logging.debug('Publishing to route %s', f)
@@ -94,6 +96,18 @@ class Backend(object):
self._exchanges[exchange].bind(self._queues[queue].push,
routing_key)
+ def declare_consumer(self, queue, callback, *args, **kwargs):
+ self.current_queue = queue
+ self.current_callback = callback
+
+ def consume(self, *args, **kwargs):
+ while True:
+ item = self.get(self.current_queue)
+ if item:
+ self.current_callback(item)
+ raise StopIteration()
+ greenthread.sleep(0)
+
def get(self, queue, no_ack=False):
if not queue in self._queues or not self._queues[queue].size():
return None
@@ -102,6 +116,7 @@ class Backend(object):
message = Message(backend=self, body=message_data,
content_type=content_type,
content_encoding=content_encoding)
+ message.result = True
logging.debug('Getting from %s: %s', queue, message)
return message
diff --git a/nova/flags.py b/nova/flags.py
index 0815a338c..2b96a15f7 100644
--- a/nova/flags.py
+++ b/nova/flags.py
@@ -22,6 +22,7 @@ where they're used.
"""
import getopt
+import os
import socket
import sys
@@ -34,7 +35,7 @@ class FlagValues(gflags.FlagValues):
Unknown flags will be ignored when parsing the command line, but the
command line will be kept so that it can be replayed if new flags are
defined after the initial parsing.
-
+
"""
def __init__(self):
@@ -50,7 +51,7 @@ class FlagValues(gflags.FlagValues):
# leftover args at the end
sneaky_unparsed_args = {"value": None}
original_argv = list(argv)
-
+
if self.IsGnuGetOpt():
orig_getopt = getattr(getopt, 'gnu_getopt')
orig_name = 'gnu_getopt'
@@ -74,14 +75,14 @@ class FlagValues(gflags.FlagValues):
unparsed_args = sneaky_unparsed_args['value']
if unparsed_args:
if self.IsGnuGetOpt():
- args = argv[:1] + unparsed
+ args = argv[:1] + unparsed_args
else:
args = argv[:1] + original_argv[-len(unparsed_args):]
else:
args = argv[:1]
finally:
setattr(getopt, orig_name, orig_getopt)
-
+
# Store the arguments for later, we'll need them for new flags
# added at runtime
self.__dict__['__stored_argv'] = original_argv
@@ -92,7 +93,7 @@ class FlagValues(gflags.FlagValues):
def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name)
-
+
def IsDirty(self, name):
return name in self.__dict__['__dirty']
@@ -113,12 +114,12 @@ class FlagValues(gflags.FlagValues):
for k in self.__dict__['__dirty']:
setattr(self, k, getattr(new_flags, k))
self.ClearDirty()
-
+
def __setitem__(self, name, flag):
gflags.FlagValues.__setitem__(self, name, flag)
if self.WasAlreadyParsed():
self.SetDirty(name)
-
+
def __getitem__(self, name):
if self.IsDirty(name):
self.ParseNewFlags()
@@ -141,6 +142,7 @@ def _wrapper(func):
return _wrapped
+DEFINE = _wrapper(gflags.DEFINE)
DEFINE_string = _wrapper(gflags.DEFINE_string)
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
@@ -165,11 +167,14 @@ def DECLARE(name, module_string, flag_values=FLAGS):
# Define any app-specific flags in their own files, docs at:
# http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#39
+DEFINE_list('region_list',
+ [],
+ 'list of region=url pairs separated by commas')
DEFINE_string('connection_type', 'libvirt', 'libvirt, xenapi or fake')
DEFINE_integer('s3_port', 3333, 's3 port')
DEFINE_string('s3_host', '127.0.0.1', 's3 host')
-#DEFINE_string('cloud_topic', 'cloud', 'the topic clouds listen on')
DEFINE_string('compute_topic', 'compute', 'the topic compute nodes listen on')
+DEFINE_string('scheduler_topic', 'scheduler', 'the topic scheduler nodes listen on')
DEFINE_string('volume_topic', 'volume', 'the topic volume nodes listen on')
DEFINE_string('network_topic', 'network', 'the topic network nodes listen on')
@@ -183,6 +188,8 @@ DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')
DEFINE_string('control_exchange', 'nova', 'the main exchange to connect to')
+DEFINE_string('cc_host', '127.0.0.1', 'ip of api server')
+DEFINE_integer('cc_port', 8773, 'cloud controller port')
DEFINE_string('ec2_url', 'http://127.0.0.1:8773/services/Cloud',
'Url to ec2 api server')
@@ -204,9 +211,26 @@ DEFINE_string('vpn_key_suffix',
DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger')
+DEFINE_string('sql_connection',
+ 'sqlite:///%s/nova.sqlite' % os.path.abspath("./"),
+ 'connection string for sql database')
+
+DEFINE_string('compute_manager', 'nova.compute.manager.ComputeManager',
+ 'Manager for compute')
+DEFINE_string('network_manager', 'nova.network.manager.VlanManager',
+ 'Manager for network')
+DEFINE_string('volume_manager', 'nova.volume.manager.AOEManager',
+ 'Manager for volume')
+DEFINE_string('scheduler_manager', 'nova.scheduler.manager.SchedulerManager',
+ 'Manager for scheduler')
+
+# The service to use for image search and retrieval
+DEFINE_string('image_service', 'nova.image.service.LocalImageService',
+ 'The service to use for retrieving and searching for images.')
+
+DEFINE_string('host', socket.gethostname(),
+ 'name of this node')
+
# UNUSED
DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node')
-DEFINE_string('node_name', socket.gethostname(),
- 'name of this node')
-
diff --git a/nova/endpoint/__init__.py b/nova/image/__init__.py
index e69de29bb..e69de29bb 100644
--- a/nova/endpoint/__init__.py
+++ b/nova/image/__init__.py
diff --git a/nova/image/service.py b/nova/image/service.py
new file mode 100644
index 000000000..5276e1312
--- /dev/null
+++ b/nova/image/service.py
@@ -0,0 +1,285 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import cPickle as pickle
+import httplib
+import json
+import logging
+import os.path
+import random
+import string
+import urlparse
+
+import webob.exc
+
+from nova import utils
+from nova import flags
+from nova import exception
+
+
+FLAGS = flags.FLAGS
+
+
+flags.DEFINE_string('glance_teller_address', 'http://127.0.0.1',
+ 'IP address or URL where Glance\'s Teller service resides')
+flags.DEFINE_string('glance_teller_port', '9191',
+ 'Port for Glance\'s Teller service')
+flags.DEFINE_string('glance_parallax_address', 'http://127.0.0.1',
+ 'IP address or URL where Glance\'s Parallax service resides')
+flags.DEFINE_string('glance_parallax_port', '9292',
+ 'Port for Glance\'s Parallax service')
+
+
+class BaseImageService(object):
+
+ """Base class for providing image search and retrieval services"""
+
+ def index(self):
+ """
+ Return a dict from opaque image id to image data.
+ """
+ raise NotImplementedError
+
+ def show(self, id):
+ """
+ Returns a dict containing image data for the given opaque image id.
+
+ :raises NotFound if the image does not exist
+ """
+ raise NotImplementedError
+
+ def create(self, data):
+ """
+ Store the image data and return the new image id.
+
+ :raises AlreadyExists if the image already exist.
+
+ """
+ raise NotImplementedError
+
+ def update(self, image_id, data):
+ """Replace the contents of the given image with the new data.
+
+ :raises NotFound if the image does not exist.
+
+ """
+ raise NotImplementedError
+
+ def delete(self, image_id):
+ """
+ Delete the given image.
+
+ :raises NotFound if the image does not exist.
+
+ """
+ raise NotImplementedError
+
+
+class TellerClient(object):
+
+ def __init__(self):
+ self.address = FLAGS.glance_teller_address
+ self.port = FLAGS.glance_teller_port
+ url = urlparse.urlparse(self.address)
+ self.netloc = url.netloc
+ self.connection_type = {'http': httplib.HTTPConnection,
+ 'https': httplib.HTTPSConnection}[url.scheme]
+
+
+class ParallaxClient(object):
+
+ def __init__(self):
+ self.address = FLAGS.glance_parallax_address
+ self.port = FLAGS.glance_parallax_port
+ url = urlparse.urlparse(self.address)
+ self.netloc = url.netloc
+ self.connection_type = {'http': httplib.HTTPConnection,
+ 'https': httplib.HTTPSConnection}[url.scheme]
+
+ def get_images(self):
+ """
+ Returns a list of image data mappings from Parallax
+ """
+ try:
+ c = self.connection_type(self.netloc, self.port)
+ c.request("GET", "images")
+ res = c.getresponse()
+ if res.status == 200:
+ # Parallax returns a JSONified dict(images=image_list)
+ data = json.loads(res.read())['images']
+ return data
+ else:
+ logging.warn("Parallax returned HTTP error %d from "
+ "request for /images", res.status_int)
+ return []
+ finally:
+ c.close()
+
+ def get_image_metadata(self, image_id):
+ """
+ Returns a mapping of image metadata from Parallax
+ """
+ try:
+ c = self.connection_type(self.netloc, self.port)
+ c.request("GET", "images/%s" % image_id)
+ res = c.getresponse()
+ if res.status == 200:
+ # Parallax returns a JSONified dict(image=image_info)
+ data = json.loads(res.read())['image']
+ return data
+ else:
+ # TODO(jaypipes): log the error?
+ return None
+ finally:
+ c.close()
+
+ def add_image_metadata(self, image_metadata):
+ """
+ Tells parallax about an image's metadata
+ """
+ pass
+
+ def update_image_metadata(self, image_id, image_metadata):
+ """
+ Updates Parallax's information about an image
+ """
+ pass
+
+ def delete_image_metadata(self, image_id):
+ """
+ Deletes Parallax's information about an image
+ """
+ pass
+
+
+class GlanceImageService(BaseImageService):
+
+ """Provides storage and retrieval of disk image objects within Glance."""
+
+ def __init__(self):
+ self.teller = TellerClient()
+ self.parallax = ParallaxClient()
+
+ def index(self):
+ """
+ Calls out to Parallax for a list of images available
+ """
+ images = self.parallax.get_images()
+ return images
+
+ def show(self, id):
+ """
+ Returns a dict containing image data for the given opaque image id.
+ """
+ image = self.parallax.get_image_metadata(id)
+ if image:
+ return image
+ raise exception.NotFound
+
+ def create(self, data):
+ """
+ Store the image data and return the new image id.
+
+ :raises AlreadyExists if the image already exist.
+
+ """
+ return self.parallax.add_image_metadata(data)
+
+ def update(self, image_id, data):
+ """Replace the contents of the given image with the new data.
+
+ :raises NotFound if the image does not exist.
+
+ """
+ self.parallax.update_image_metadata(image_id, data)
+
+ def delete(self, image_id):
+ """
+ Delete the given image.
+
+ :raises NotFound if the image does not exist.
+
+ """
+ self.parallax.delete_image_metadata(image_id)
+
+ def delete_all(self):
+ """
+ Clears out all images
+ """
+ pass
+
+
+class LocalImageService(BaseImageService):
+
+ """Image service storing images to local disk.
+
+ It assumes that image_ids are integers."""
+
+ def __init__(self):
+ self._path = "/tmp/nova/images"
+ try:
+ os.makedirs(self._path)
+ except OSError: # exists
+ pass
+
+ def _path_to(self, image_id):
+ return os.path.join(self._path, str(image_id))
+
+ def _ids(self):
+ """The list of all image ids."""
+ return [int(i) for i in os.listdir(self._path)]
+
+ def index(self):
+ return [ self.show(id) for id in self._ids() ]
+
+ def show(self, id):
+ try:
+ return pickle.load(open(self._path_to(id)))
+ except IOError:
+ raise exception.NotFound
+
+ def create(self, data):
+ """
+ Store the image data and return the new image id.
+ """
+ id = random.randint(0, 2**32-1)
+ data['id'] = id
+ self.update(id, data)
+ return id
+
+ def update(self, image_id, data):
+ """Replace the contents of the given image with the new data."""
+ try:
+ pickle.dump(data, open(self._path_to(image_id), 'w'))
+ except IOError:
+ raise exception.NotFound
+
+ def delete(self, image_id):
+ """
+ Delete the given image. Raises OSError if the image does not exist.
+ """
+ try:
+ os.unlink(self._path_to(image_id))
+ except IOError:
+ raise exception.NotFound
+
+ def delete_all(self):
+ """
+ Clears out all images in local directory
+ """
+ for id in self._ids():
+ os.unlink(self._path_to(id))
diff --git a/nova/manager.py b/nova/manager.py
new file mode 100644
index 000000000..56ba7d3f6
--- /dev/null
+++ b/nova/manager.py
@@ -0,0 +1,52 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Base class for managers of different parts of the system
+"""
+
+from nova import utils
+from nova import flags
+
+from twisted.internet import defer
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('db_driver', 'nova.db.api',
+ 'driver to use for volume creation')
+
+
+class Manager(object):
+ """DB driver is injected in the init method"""
+ def __init__(self, host=None, db_driver=None):
+ if not host:
+ host = FLAGS.host
+ self.host = host
+ if not db_driver:
+ db_driver = FLAGS.db_driver
+ self.db = utils.import_object(db_driver) # pylint: disable-msg=C0103
+
+ @defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield
+
+ def init_host(self):
+ """Do any initialization that needs to be run if this is a standalone service.
+
+ Child classes should override this method.
+ """
+ pass
diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py
index 9e5aabd97..c0be0e8cc 100644
--- a/nova/network/linux_net.py
+++ b/nova/network/linux_net.py
@@ -23,111 +23,148 @@ import signal
# TODO(ja): does the definition of network_path belong here?
+from nova import db
from nova import flags
from nova import utils
+def _bin_file(script):
+ """Return the absolute path to scipt in the bin directory"""
+ return os.path.abspath(os.path.join(__file__, "../../../bin", script))
+
+
FLAGS = flags.FLAGS
flags.DEFINE_string('dhcpbridge_flagfile',
'/etc/nova/nova-dhcpbridge.conf',
'location of flagfile for dhcpbridge')
-
-def execute(cmd, addl_env=None):
- """Wrapper around utils.execute for fake_network"""
- if FLAGS.fake_network:
- logging.debug("FAKE NET: %s", cmd)
- return "fake", 0
+flags.DEFINE_string('networks_path', utils.abspath('../networks'),
+ 'Location to keep network config files')
+flags.DEFINE_string('public_interface', 'vlan1',
+ 'Interface for public IP addresses')
+flags.DEFINE_string('bridge_dev', 'eth0',
+ 'network device for bridges')
+flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'),
+ 'location of nova-dhcpbridge')
+flags.DEFINE_string('routing_source_ip', '127.0.0.1',
+ 'Public IP of network host')
+flags.DEFINE_bool('use_nova_chains', False,
+ 'use the nova_ routing chains instead of default')
+
+DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)]
+
+def init_host():
+ """Basic networking setup goes here"""
+ # NOTE(devcamcar): Cloud public DNAT entries, CloudPipe port
+ # forwarding entries and a default DNAT entry.
+ _confirm_rule("PREROUTING", "-t nat -s 0.0.0.0/0 "
+ "-d 169.254.169.254/32 -p tcp -m tcp --dport 80 -j DNAT "
+ "--to-destination %s:%s" % (FLAGS.cc_host, FLAGS.cc_port))
+
+ # NOTE(devcamcar): Cloud public SNAT entries and the default
+ # SNAT rule for outbound traffic.
+ _confirm_rule("POSTROUTING", "-t nat -s %s "
+ "-j SNAT --to-source %s"
+ % (FLAGS.fixed_range, FLAGS.routing_source_ip))
+
+ _confirm_rule("POSTROUTING", "-t nat -s %s -j MASQUERADE" %
+ FLAGS.fixed_range)
+ _confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" %
+ {'range': FLAGS.fixed_range})
+
+def bind_floating_ip(floating_ip):
+ """Bind ip to public interface"""
+ _execute("sudo ip addr add %s dev %s" % (floating_ip,
+ FLAGS.public_interface))
+
+
+def unbind_floating_ip(floating_ip):
+ """Unbind a public ip from public interface"""
+ _execute("sudo ip addr del %s dev %s" % (floating_ip,
+ FLAGS.public_interface))
+
+
+def ensure_vlan_forward(public_ip, port, private_ip):
+ """Sets up forwarding rules for vlan"""
+ _confirm_rule("FORWARD", "-d %s -p udp --dport 1194 -j ACCEPT" %
+ private_ip)
+ _confirm_rule("PREROUTING",
+ "-t nat -d %s -p udp --dport %s -j DNAT --to %s:1194"
+ % (public_ip, port, private_ip))
+
+
+def ensure_floating_forward(floating_ip, fixed_ip):
+ """Ensure floating ip forwarding rule"""
+ _confirm_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s"
+ % (floating_ip, fixed_ip))
+ _confirm_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s"
+ % (fixed_ip, floating_ip))
+ # TODO(joshua): Get these from the secgroup datastore entries
+ _confirm_rule("FORWARD", "-d %s -p icmp -j ACCEPT"
+ % (fixed_ip))
+ for (protocol, port) in DEFAULT_PORTS:
+ _confirm_rule("FORWARD","-d %s -p %s --dport %s -j ACCEPT"
+ % (fixed_ip, protocol, port))
+
+
+def remove_floating_forward(floating_ip, fixed_ip):
+ """Remove forwarding for floating ip"""
+ _remove_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s"
+ % (floating_ip, fixed_ip))
+ _remove_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s"
+ % (fixed_ip, floating_ip))
+ _remove_rule("FORWARD", "-d %s -p icmp -j ACCEPT"
+ % (fixed_ip))
+ for (protocol, port) in DEFAULT_PORTS:
+ _remove_rule("FORWARD", "-d %s -p %s --dport %s -j ACCEPT"
+ % (fixed_ip, protocol, port))
+
+
+def ensure_vlan_bridge(vlan_num, bridge, net_attrs=None):
+ """Create a vlan and bridge unless they already exist"""
+ interface = ensure_vlan(vlan_num)
+ ensure_bridge(bridge, interface, net_attrs)
+
+
+def ensure_vlan(vlan_num):
+ """Create a vlan unless it already exists"""
+ interface = "vlan%s" % vlan_num
+ if not _device_exists(interface):
+ logging.debug("Starting VLAN inteface %s", interface)
+ _execute("sudo vconfig set_name_type VLAN_PLUS_VID_NO_PAD")
+ _execute("sudo vconfig add %s %s" % (FLAGS.bridge_dev, vlan_num))
+ _execute("sudo ifconfig %s up" % interface)
+ return interface
+
+
+def ensure_bridge(bridge, interface, net_attrs=None):
+ """Create a bridge unless it already exists"""
+ if not _device_exists(bridge):
+ logging.debug("Starting Bridge inteface for %s", interface)
+ _execute("sudo brctl addbr %s" % bridge)
+ _execute("sudo brctl setfd %s 0" % bridge)
+ # _execute("sudo brctl setageing %s 10" % bridge)
+ _execute("sudo brctl stp %s off" % bridge)
+ _execute("sudo brctl addif %s %s" % (bridge, interface))
+ if net_attrs:
+ _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \
+ (bridge,
+ net_attrs['gateway'],
+ net_attrs['broadcast'],
+ net_attrs['netmask']))
else:
- return utils.execute(cmd, addl_env=addl_env)
-
-
-def runthis(desc, cmd):
- """Wrapper around utils.runthis for fake_network"""
- if FLAGS.fake_network:
- return execute(cmd)
- else:
- return utils.runthis(desc, cmd)
-
-
-def device_exists(device):
- """Check if ethernet device exists"""
- (_out, err) = execute("ifconfig %s" % device)
- return not err
-
-
-def confirm_rule(cmd):
- """Delete and re-add iptables rule"""
- execute("sudo iptables --delete %s" % (cmd))
- execute("sudo iptables -I %s" % (cmd))
-
-
-def remove_rule(cmd):
- """Remove iptables rule"""
- execute("sudo iptables --delete %s" % (cmd))
-
-
-def bind_public_ip(public_ip, interface):
- """Bind ip to an interface"""
- runthis("Binding IP to interface: %s",
- "sudo ip addr add %s dev %s" % (public_ip, interface))
-
-
-def unbind_public_ip(public_ip, interface):
- """Unbind a public ip from an interface"""
- runthis("Binding IP to interface: %s",
- "sudo ip addr del %s dev %s" % (public_ip, interface))
-
-
-def vlan_create(net):
- """Create a vlan on on a bridge device unless vlan already exists"""
- if not device_exists("vlan%s" % net['vlan']):
- logging.debug("Starting VLAN inteface for %s network", (net['vlan']))
- execute("sudo vconfig set_name_type VLAN_PLUS_VID_NO_PAD")
- execute("sudo vconfig add %s %s" % (FLAGS.bridge_dev, net['vlan']))
- execute("sudo ifconfig vlan%s up" % (net['vlan']))
-
-
-def bridge_create(net):
- """Create a bridge on a vlan unless it already exists"""
- if not device_exists(net['bridge_name']):
- logging.debug("Starting Bridge inteface for %s network", (net['vlan']))
- execute("sudo brctl addbr %s" % (net['bridge_name']))
- execute("sudo brctl setfd %s 0" % (net.bridge_name))
- # execute("sudo brctl setageing %s 10" % (net.bridge_name))
- execute("sudo brctl stp %s off" % (net['bridge_name']))
- execute("sudo brctl addif %s vlan%s" % (net['bridge_name'],
- net['vlan']))
- if net.bridge_gets_ip:
- execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \
- (net['bridge_name'], net.gateway, net.broadcast, net.netmask))
- confirm_rule("FORWARD --in-interface %s -j ACCEPT" %
- (net['bridge_name']))
- else:
- execute("sudo ifconfig %s up" % net['bridge_name'])
-
-
-def _dnsmasq_cmd(net):
- """Builds dnsmasq command"""
- cmd = ['sudo -E dnsmasq',
- ' --strict-order',
- ' --bind-interfaces',
- ' --conf-file=',
- ' --pid-file=%s' % dhcp_file(net['vlan'], 'pid'),
- ' --listen-address=%s' % net.dhcp_listen_address,
- ' --except-interface=lo',
- ' --dhcp-range=%s,static,120s' % net.dhcp_range_start,
- ' --dhcp-hostsfile=%s' % dhcp_file(net['vlan'], 'conf'),
- ' --dhcp-script=%s' % bin_file('nova-dhcpbridge'),
- ' --leasefile-ro']
- return ''.join(cmd)
+ _execute("sudo ifconfig %s up" % bridge)
+ _confirm_rule("FORWARD", "--in-interface %s -j ACCEPT" % bridge)
+ _confirm_rule("FORWARD", "--out-interface %s -j ACCEPT" % bridge)
-def host_dhcp(address):
- """Return a host string for an address object"""
- return "%s,%s.novalocal,%s" % (address['mac'],
- address['hostname'],
- address.address)
+def get_dhcp_hosts(context, network_id):
+ """Get a string containing a network's hosts config in dnsmasq format"""
+ hosts = []
+ for fixed_ip_ref in db.network_get_associated_fixed_ips(context,
+ network_id):
+ hosts.append(_host_dhcp(fixed_ip_ref))
+ return '\n'.join(hosts)
# TODO(ja): if the system has restarted or pid numbers have wrapped
@@ -135,65 +172,124 @@ def host_dhcp(address):
# dnsmasq. As well, sending a HUP only reloads the hostfile,
# so any configuration options (like dchp-range, vlan, ...)
# aren't reloaded
-def start_dnsmasq(network):
+def update_dhcp(context, network_id):
"""(Re)starts a dnsmasq server for a given network
if a dnsmasq instance is already running then send a HUP
signal causing it to reload, otherwise spawn a new instance
"""
- with open(dhcp_file(network['vlan'], 'conf'), 'w') as f:
- for address in network.assigned_objs:
- f.write("%s\n" % host_dhcp(address))
+ network_ref = db.network_get(context, network_id)
+
+ conffile = _dhcp_file(network_ref['bridge'], 'conf')
+ with open(conffile, 'w') as f:
+ f.write(get_dhcp_hosts(context, network_id))
+
+ # Make sure dnsmasq can actually read it (it setuid()s to "nobody")
+ os.chmod(conffile, 0644)
- pid = dnsmasq_pid_for(network)
+ pid = _dnsmasq_pid_for(network_ref['bridge'])
# if dnsmasq is already running, then tell it to reload
if pid:
# TODO(ja): use "/proc/%d/cmdline" % (pid) to determine if pid refers
# correct dnsmasq process
try:
- os.kill(pid, signal.SIGHUP)
+ _execute('sudo kill -HUP %d' % pid)
return
except Exception as exc: # pylint: disable-msg=W0703
logging.debug("Hupping dnsmasq threw %s", exc)
# FLAGFILE and DNSMASQ_INTERFACE in env
env = {'FLAGFILE': FLAGS.dhcpbridge_flagfile,
- 'DNSMASQ_INTERFACE': network['bridge_name']}
- execute(_dnsmasq_cmd(network), addl_env=env)
+ 'DNSMASQ_INTERFACE': network_ref['bridge']}
+ command = _dnsmasq_cmd(network_ref)
+ _execute(command, addl_env=env)
+
+
+def _host_dhcp(fixed_ip_ref):
+ """Return a host string for an address"""
+ instance_ref = fixed_ip_ref['instance']
+ return "%s,%s.novalocal,%s" % (instance_ref['mac_address'],
+ instance_ref['hostname'],
+ fixed_ip_ref['address'])
+
+
+def _execute(cmd, *args, **kwargs):
+ """Wrapper around utils._execute for fake_network"""
+ if FLAGS.fake_network:
+ logging.debug("FAKE NET: %s", cmd)
+ return "fake", 0
+ else:
+ return utils.execute(cmd, *args, **kwargs)
+
+
+def _device_exists(device):
+ """Check if ethernet device exists"""
+ (_out, err) = _execute("ifconfig %s" % device, check_exit_code=False)
+ return not err
-def stop_dnsmasq(network):
+def _confirm_rule(chain, cmd):
+ """Delete and re-add iptables rule"""
+ if FLAGS.use_nova_chains:
+ chain = "nova_%s" % chain.lower()
+ _execute("sudo iptables --delete %s %s" % (chain, cmd), check_exit_code=False)
+ _execute("sudo iptables -I %s %s" % (chain, cmd))
+
+
+def _remove_rule(chain, cmd):
+ """Remove iptables rule"""
+ if FLAGS.use_nova_chains:
+ chain = "%S" % chain.lower()
+ _execute("sudo iptables --delete %s %s" % (chain, cmd))
+
+
+def _dnsmasq_cmd(net):
+ """Builds dnsmasq command"""
+ cmd = ['sudo -E dnsmasq',
+ ' --strict-order',
+ ' --bind-interfaces',
+ ' --conf-file=',
+ ' --pid-file=%s' % _dhcp_file(net['bridge'], 'pid'),
+ ' --listen-address=%s' % net['gateway'],
+ ' --except-interface=lo',
+ ' --dhcp-range=%s,static,120s' % net['dhcp_start'],
+ ' --dhcp-hostsfile=%s' % _dhcp_file(net['bridge'], 'conf'),
+ ' --dhcp-script=%s' % FLAGS.dhcpbridge,
+ ' --leasefile-ro']
+ return ''.join(cmd)
+
+
+def _stop_dnsmasq(network):
"""Stops the dnsmasq instance for a given network"""
- pid = dnsmasq_pid_for(network)
+ pid = _dnsmasq_pid_for(network)
if pid:
try:
- os.kill(pid, signal.SIGTERM)
+ _execute('sudo kill -TERM %d' % pid)
except Exception as exc: # pylint: disable-msg=W0703
logging.debug("Killing dnsmasq threw %s", exc)
-def dhcp_file(vlan, kind):
- """Return path to a pid, leases or conf file for a vlan"""
+def _dhcp_file(bridge, kind):
+ """Return path to a pid, leases or conf file for a bridge"""
- return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind))
-
-
-def bin_file(script):
- """Return the absolute path to scipt in the bin directory"""
- return os.path.abspath(os.path.join(__file__, "../../../bin", script))
+ if not os.path.exists(FLAGS.networks_path):
+ os.makedirs(FLAGS.networks_path)
+ return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path,
+ bridge,
+ kind))
-def dnsmasq_pid_for(network):
- """Returns he pid for prior dnsmasq instance for a vlan
+def _dnsmasq_pid_for(bridge):
+ """Returns the pid for prior dnsmasq instance for a bridge
Returns None if no pid file exists
If machine has rebooted pid might be incorrect (caller should check)
"""
- pid_file = dhcp_file(network['vlan'], 'pid')
+ pid_file = _dhcp_file(bridge, 'pid')
if os.path.exists(pid_file):
with open(pid_file, 'r') as f:
diff --git a/nova/network/manager.py b/nova/network/manager.py
new file mode 100644
index 000000000..2ea1c1aa0
--- /dev/null
+++ b/nova/network/manager.py
@@ -0,0 +1,428 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Network Hosts are responsible for allocating ips and setting up network
+"""
+
+import datetime
+import logging
+import math
+
+import IPy
+from twisted.internet import defer
+
+from nova import db
+from nova import exception
+from nova import flags
+from nova import manager
+from nova import utils
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('flat_network_bridge', 'br100',
+ 'Bridge for simple network instances')
+flags.DEFINE_string('flat_network_dns', '8.8.4.4',
+ 'Dns for simple network')
+flags.DEFINE_string('flat_network_dhcp_start', '192.168.0.2',
+ 'Dhcp start for FlatDhcp')
+flags.DEFINE_integer('vlan_start', 100, 'First VLAN for private networks')
+flags.DEFINE_integer('num_networks', 1000, 'Number of networks to support')
+flags.DEFINE_string('vpn_ip', utils.get_my_ip(),
+ 'Public IP for the cloudpipe VPN servers')
+flags.DEFINE_integer('vpn_start', 1000, 'First Vpn port for private networks')
+flags.DEFINE_integer('network_size', 256,
+ 'Number of addresses in each private subnet')
+flags.DEFINE_string('floating_range', '4.4.4.0/24', 'Floating IP address block')
+flags.DEFINE_string('fixed_range', '10.0.0.0/8', 'Fixed IP address block')
+flags.DEFINE_integer('cnt_vpn_clients', 5,
+ 'Number of addresses reserved for vpn clients')
+flags.DEFINE_string('network_driver', 'nova.network.linux_net',
+ 'Driver to use for network creation')
+flags.DEFINE_bool('update_dhcp_on_disassociate', False,
+ 'Whether to update dhcp when fixed_ip is disassociated')
+flags.DEFINE_integer('fixed_ip_disassociate_timeout', 600,
+ 'Seconds after which a deallocated ip is disassociated')
+
+
+class AddressAlreadyAllocated(exception.Error):
+ """Address was already allocated"""
+ pass
+
+
+class NetworkManager(manager.Manager):
+ """Implements common network manager functionality
+
+ This class must be subclassed.
+ """
+ def __init__(self, network_driver=None, *args, **kwargs):
+ if not network_driver:
+ network_driver = FLAGS.network_driver
+ self.driver = utils.import_object(network_driver)
+ super(NetworkManager, self).__init__(*args, **kwargs)
+
+ def init_host(self):
+ # Set up networking for the projects for which we're already
+ # the designated network host.
+ for network in self.db.host_get_networks(None, self.host):
+ self._on_set_network_host(None, network['id'])
+
+ def set_network_host(self, context, network_id):
+ """Safely sets the host of the network"""
+ logging.debug("setting network host")
+ host = self.db.network_set_host(None,
+ network_id,
+ self.host)
+ self._on_set_network_host(context, network_id)
+ return host
+
+ def allocate_fixed_ip(self, context, instance_id, *args, **kwargs):
+ """Gets a fixed ip from the pool"""
+ raise NotImplementedError()
+
+ def deallocate_fixed_ip(self, context, address, *args, **kwargs):
+ """Returns a fixed ip to the pool"""
+ raise NotImplementedError()
+
+ def setup_fixed_ip(self, context, address):
+ """Sets up rules for fixed ip"""
+ raise NotImplementedError()
+
+ def _on_set_network_host(self, context, network_id):
+ """Called when this host becomes the host for a network"""
+ raise NotImplementedError()
+
+ def setup_compute_network(self, context, instance_id):
+ """Sets up matching network for compute hosts"""
+ raise NotImplementedError()
+
+ def allocate_floating_ip(self, context, project_id):
+ """Gets an floating ip from the pool"""
+ # TODO(vish): add floating ips through manage command
+ return self.db.floating_ip_allocate_address(context,
+ self.host,
+ project_id)
+
+ def associate_floating_ip(self, context, floating_address, fixed_address):
+ """Associates an floating ip to a fixed ip"""
+ self.db.floating_ip_fixed_ip_associate(context,
+ floating_address,
+ fixed_address)
+ self.driver.bind_floating_ip(floating_address)
+ self.driver.ensure_floating_forward(floating_address, fixed_address)
+
+ def disassociate_floating_ip(self, context, floating_address):
+ """Disassociates a floating ip"""
+ fixed_address = self.db.floating_ip_disassociate(context,
+ floating_address)
+ self.driver.unbind_floating_ip(floating_address)
+ self.driver.remove_floating_forward(floating_address, fixed_address)
+
+ def deallocate_floating_ip(self, context, floating_address):
+ """Returns an floating ip to the pool"""
+ self.db.floating_ip_deallocate(context, floating_address)
+
+ def lease_fixed_ip(self, context, mac, address):
+ """Called by dhcp-bridge when ip is leased"""
+ logging.debug("Leasing IP %s", address)
+ fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
+ instance_ref = fixed_ip_ref['instance']
+ if not instance_ref:
+ raise exception.Error("IP %s leased that isn't associated" %
+ address)
+ if instance_ref['mac_address'] != mac:
+ raise exception.Error("IP %s leased to bad mac %s vs %s" %
+ (address, instance_ref['mac_address'], mac))
+ self.db.fixed_ip_update(context,
+ fixed_ip_ref['address'],
+ {'leased': True})
+ if not fixed_ip_ref['allocated']:
+ logging.warn("IP %s leased that was already deallocated", address)
+
+ def release_fixed_ip(self, context, mac, address):
+ """Called by dhcp-bridge when ip is released"""
+ logging.debug("Releasing IP %s", address)
+ fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
+ instance_ref = fixed_ip_ref['instance']
+ if not instance_ref:
+ raise exception.Error("IP %s released that isn't associated" %
+ address)
+ if instance_ref['mac_address'] != mac:
+ raise exception.Error("IP %s released from bad mac %s vs %s" %
+ (address, instance_ref['mac_address'], mac))
+ if not fixed_ip_ref['leased']:
+ logging.warn("IP %s released that was not leased", address)
+ self.db.fixed_ip_update(context,
+ fixed_ip_ref['str_id'],
+ {'leased': False})
+ if not fixed_ip_ref['allocated']:
+ self.db.fixed_ip_disassociate(context, address)
+ # NOTE(vish): dhcp server isn't updated until next setup, this
+ # means there will stale entries in the conf file
+ # the code below will update the file if necessary
+ if FLAGS.update_dhcp_on_disassociate:
+ network_ref = self.db.fixed_ip_get_network(context, address)
+ self.driver.update_dhcp(context, network_ref['id'])
+
+ def get_network(self, context):
+ """Get the network for the current context"""
+ raise NotImplementedError()
+
+ def create_networks(self, context, num_networks, network_size,
+ *args, **kwargs):
+ """Create networks based on parameters"""
+ raise NotImplementedError()
+
+ @property
+ def _bottom_reserved_ips(self): # pylint: disable-msg=R0201
+ """Number of reserved ips at the bottom of the range"""
+ return 2 # network, gateway
+
+ @property
+ def _top_reserved_ips(self): # pylint: disable-msg=R0201
+ """Number of reserved ips at the top of the range"""
+ return 1 # broadcast
+
+ def _create_fixed_ips(self, context, network_id):
+ """Create all fixed ips for network"""
+ network_ref = self.db.network_get(context, network_id)
+ # NOTE(vish): Should these be properties of the network as opposed
+ # to properties of the manager class?
+ bottom_reserved = self._bottom_reserved_ips
+ top_reserved = self._top_reserved_ips
+ project_net = IPy.IP(network_ref['cidr'])
+ num_ips = len(project_net)
+ for index in range(num_ips):
+ address = str(project_net[index])
+ if index < bottom_reserved or num_ips - index < top_reserved:
+ reserved = True
+ else:
+ reserved = False
+ self.db.fixed_ip_create(context, {'network_id': network_id,
+ 'address': address,
+ 'reserved': reserved})
+
+
+class FlatManager(NetworkManager):
+ """Basic network where no vlans are used"""
+
+ def allocate_fixed_ip(self, context, instance_id, *args, **kwargs):
+ """Gets a fixed ip from the pool"""
+ # TODO(vish): when this is called by compute, we can associate compute
+ # with a network, or a cluster of computes with a network
+ # and use that network here with a method like
+ # network_get_by_compute_host
+ network_ref = self.db.network_get_by_bridge(None,
+ FLAGS.flat_network_bridge)
+ address = self.db.fixed_ip_associate_pool(None,
+ network_ref['id'],
+ instance_id)
+ self.db.fixed_ip_update(context, address, {'allocated': True})
+ return address
+
+ def deallocate_fixed_ip(self, context, address, *args, **kwargs):
+ """Returns a fixed ip to the pool"""
+ self.db.fixed_ip_update(context, address, {'allocated': False})
+ self.db.fixed_ip_disassociate(None, address)
+
+ def setup_compute_network(self, context, instance_id):
+ """Network is created manually"""
+ pass
+
+ def setup_fixed_ip(self, context, address):
+ """Currently no setup"""
+ pass
+
+ def create_networks(self, context, cidr, num_networks, network_size,
+ *args, **kwargs):
+ """Create networks based on parameters"""
+ fixed_net = IPy.IP(cidr)
+ for index in range(num_networks):
+ start = index * network_size
+ significant_bits = 32 - int(math.log(network_size, 2))
+ cidr = "%s/%s" % (fixed_net[start], significant_bits)
+ project_net = IPy.IP(cidr)
+ net = {}
+ net['cidr'] = cidr
+ net['netmask'] = str(project_net.netmask())
+ net['gateway'] = str(project_net[1])
+ net['broadcast'] = str(project_net.broadcast())
+ net['dhcp_start'] = str(project_net[2])
+ network_ref = self.db.network_create_safe(context, net)
+ if network_ref:
+ self._create_fixed_ips(context, network_ref['id'])
+
+ def get_network(self, context):
+ """Get the network for the current context"""
+ # NOTE(vish): To support mutilple network hosts, This could randomly
+ # select from multiple networks instead of just
+ # returning the one. It could also potentially be done
+ # in the scheduler.
+ return self.db.network_get_by_bridge(context,
+ FLAGS.flat_network_bridge)
+
+ def _on_set_network_host(self, context, network_id):
+ """Called when this host becomes the host for a network"""
+ net = {}
+ net['injected'] = True
+ net['bridge'] = FLAGS.flat_network_bridge
+ net['dns'] = FLAGS.flat_network_dns
+ self.db.network_update(context, network_id, net)
+
+
+
+class FlatDHCPManager(NetworkManager):
+ """Flat networking with dhcp"""
+
+ def setup_fixed_ip(self, context, address):
+ """Setup dhcp for this network"""
+ network_ref = db.fixed_ip_get_by_address(context, address)
+ self.driver.update_dhcp(context, network_ref['id'])
+
+ def deallocate_fixed_ip(self, context, address, *args, **kwargs):
+ """Returns a fixed ip to the pool"""
+ self.db.fixed_ip_update(context, address, {'allocated': False})
+
+ def _on_set_network_host(self, context, network_id):
+ """Called when this host becomes the host for a project"""
+ super(FlatDHCPManager, self)._on_set_network_host(context, network_id)
+ network_ref = self.db.network_get(context, network_id)
+ self.db.network_update(context,
+ network_id,
+ {'dhcp_start': FLAGS.flat_network_dhcp_start})
+ self.driver.ensure_bridge(network_ref['bridge'],
+ FLAGS.bridge_dev,
+ network_ref)
+
+
+class VlanManager(NetworkManager):
+ """Vlan network with dhcp"""
+
+ @defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield super(VlanManager, self).periodic_tasks(context)
+ now = datetime.datetime.utcnow()
+ timeout = FLAGS.fixed_ip_disassociate_timeout
+ time = now - datetime.timedelta(seconds=timeout)
+ num = self.db.fixed_ip_disassociate_all_by_timeout(context,
+ self.host,
+ time)
+ if num:
+ logging.debug("Dissassociated %s stale fixed ip(s)", num)
+
+ def init_host(self):
+ """Do any initialization that needs to be run if this is a
+ standalone service.
+ """
+ super(VlanManager, self).init_host()
+ self.driver.init_host()
+
+ def allocate_fixed_ip(self, context, instance_id, *args, **kwargs):
+ """Gets a fixed ip from the pool"""
+ # TODO(vish): This should probably be getting project_id from
+ # the instance, but it is another trip to the db.
+ # Perhaps this method should take an instance_ref.
+ network_ref = self.db.project_get_network(context, context.project.id)
+ if kwargs.get('vpn', None):
+ address = network_ref['vpn_private_address']
+ self.db.fixed_ip_associate(None, address, instance_id)
+ else:
+ address = self.db.fixed_ip_associate_pool(None,
+ network_ref['id'],
+ instance_id)
+ self.db.fixed_ip_update(context, address, {'allocated': True})
+ return address
+
+ def deallocate_fixed_ip(self, context, address, *args, **kwargs):
+ """Returns a fixed ip to the pool"""
+ self.db.fixed_ip_update(context, address, {'allocated': False})
+
+ def setup_fixed_ip(self, context, address):
+ """Sets forwarding rules and dhcp for fixed ip"""
+ fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
+ network_ref = self.db.fixed_ip_get_network(context, address)
+ if self.db.instance_is_vpn(context, fixed_ip_ref['instance_id']):
+ self.driver.ensure_vlan_forward(network_ref['vpn_public_address'],
+ network_ref['vpn_public_port'],
+ network_ref['vpn_private_address'])
+ self.driver.update_dhcp(context, network_ref['id'])
+
+ def setup_compute_network(self, context, instance_id):
+ """Sets up matching network for compute hosts"""
+ network_ref = db.network_get_by_instance(context, instance_id)
+ self.driver.ensure_vlan_bridge(network_ref['vlan'],
+ network_ref['bridge'])
+
+ def restart_nets(self):
+ """Ensure the network for each user is enabled"""
+ # TODO(vish): Implement this
+ pass
+
+ def create_networks(self, context, cidr, num_networks, network_size,
+ vlan_start, vpn_start):
+ """Create networks based on parameters"""
+ fixed_net = IPy.IP(cidr)
+ for index in range(num_networks):
+ vlan = vlan_start + index
+ start = index * network_size
+ significant_bits = 32 - int(math.log(network_size, 2))
+ cidr = "%s/%s" % (fixed_net[start], significant_bits)
+ project_net = IPy.IP(cidr)
+ net = {}
+ net['cidr'] = cidr
+ net['netmask'] = str(project_net.netmask())
+ net['gateway'] = str(project_net[1])
+ net['broadcast'] = str(project_net.broadcast())
+ net['vpn_private_address'] = str(project_net[2])
+ net['dhcp_start'] = str(project_net[3])
+ net['vlan'] = vlan
+ net['bridge'] = 'br%s' % vlan
+ # NOTE(vish): This makes ports unique accross the cloud, a more
+ # robust solution would be to make them unique per ip
+ net['vpn_public_port'] = vpn_start + index
+ network_ref = self.db.network_create_safe(context, net)
+ if network_ref:
+ self._create_fixed_ips(context, network_ref['id'])
+
+ def get_network(self, context):
+ """Get the network for the current context"""
+ return self.db.project_get_network(None, context.project.id)
+
+ def _on_set_network_host(self, context, network_id):
+ """Called when this host becomes the host for a network"""
+ network_ref = self.db.network_get(context, network_id)
+ net = {}
+ net['vpn_public_address'] = FLAGS.vpn_ip
+ db.network_update(context, network_id, net)
+ self.driver.ensure_vlan_bridge(network_ref['vlan'],
+ network_ref['bridge'],
+ network_ref)
+ self.driver.update_dhcp(context, network_id)
+
+ @property
+ def _bottom_reserved_ips(self):
+ """Number of reserved ips at the bottom of the range"""
+ return super(VlanManager, self)._bottom_reserved_ips + 1 # vpn server
+
+ @property
+ def _top_reserved_ips(self):
+ """Number of reserved ips at the top of the range"""
+ parent_reserved = super(VlanManager, self)._top_reserved_ips
+ return parent_reserved + FLAGS.cnt_vpn_clients
+
diff --git a/nova/network/model.py b/nova/network/model.py
deleted file mode 100644
index 6e4fcc47e..000000000
--- a/nova/network/model.py
+++ /dev/null
@@ -1,634 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Model Classes for network control, including VLANs, DHCP, and IP allocation.
-"""
-
-import logging
-import os
-import time
-
-import IPy
-from nova import datastore
-from nova import exception as nova_exception
-from nova import flags
-from nova import utils
-from nova.auth import manager
-from nova.network import exception
-from nova.network import linux_net
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('networks_path', utils.abspath('../networks'),
- 'Location to keep network config files')
-flags.DEFINE_integer('public_vlan', 1, 'VLAN for public IP addresses')
-flags.DEFINE_string('public_interface', 'vlan1',
- 'Interface for public IP addresses')
-flags.DEFINE_string('bridge_dev', 'eth1',
- 'network device for bridges')
-flags.DEFINE_integer('vlan_start', 100, 'First VLAN for private networks')
-flags.DEFINE_integer('vlan_end', 4093, 'Last VLAN for private networks')
-flags.DEFINE_integer('network_size', 256,
- 'Number of addresses in each private subnet')
-flags.DEFINE_string('public_range', '4.4.4.0/24', 'Public IP address block')
-flags.DEFINE_string('private_range', '10.0.0.0/8', 'Private IP address block')
-flags.DEFINE_integer('cnt_vpn_clients', 5,
- 'Number of addresses reserved for vpn clients')
-flags.DEFINE_integer('cloudpipe_start_port', 12000,
- 'Starting port for mapped CloudPipe external ports')
-
-
-logging.getLogger().setLevel(logging.DEBUG)
-
-
-class Vlan(datastore.BasicModel):
- """Tracks vlans assigned to project it the datastore"""
- def __init__(self, project, vlan): # pylint: disable-msg=W0231
- """
- Since we don't want to try and find a vlan by its identifier,
- but by a project id, we don't call super-init.
- """
- self.project_id = project
- self.vlan_id = vlan
-
- @property
- def identifier(self):
- """Datastore identifier"""
- return "%s:%s" % (self.project_id, self.vlan_id)
-
- @classmethod
- def create(cls, project, vlan):
- """Create a Vlan object"""
- instance = cls(project, vlan)
- instance.save()
- return instance
-
- @classmethod
- @datastore.absorb_connection_error
- def lookup(cls, project):
- """Returns object by project if it exists in datastore or None"""
- set_name = cls._redis_set_name(cls.__name__)
- vlan = datastore.Redis.instance().hget(set_name, project)
- if vlan:
- return cls(project, vlan)
- else:
- return None
-
- @classmethod
- @datastore.absorb_connection_error
- def dict_by_project(cls):
- """A hash of project:vlan"""
- set_name = cls._redis_set_name(cls.__name__)
- return datastore.Redis.instance().hgetall(set_name) or {}
-
- @classmethod
- @datastore.absorb_connection_error
- def dict_by_vlan(cls):
- """A hash of vlan:project"""
- set_name = cls._redis_set_name(cls.__name__)
- retvals = {}
- hashset = datastore.Redis.instance().hgetall(set_name) or {}
- for (key, val) in hashset.iteritems():
- retvals[val] = key
- return retvals
-
- @classmethod
- @datastore.absorb_connection_error
- def all(cls):
- set_name = cls._redis_set_name(cls.__name__)
- elements = datastore.Redis.instance().hgetall(set_name)
- for project in elements:
- yield cls(project, elements[project])
-
- @datastore.absorb_connection_error
- def save(self):
- """
- Vlan saves state into a giant hash named "vlans", with keys of
- project_id and value of vlan number. Therefore, we skip the
- default way of saving into "vlan:ID" and adding to a set of "vlans".
- """
- set_name = self._redis_set_name(self.__class__.__name__)
- datastore.Redis.instance().hset(set_name,
- self.project_id,
- self.vlan_id)
-
- @datastore.absorb_connection_error
- def destroy(self):
- """Removes the object from the datastore"""
- set_name = self._redis_set_name(self.__class__.__name__)
- datastore.Redis.instance().hdel(set_name, self.project_id)
-
- def subnet(self):
- """Returns a string containing the subnet"""
- vlan = int(self.vlan_id)
- network = IPy.IP(FLAGS.private_range)
- start = (vlan - FLAGS.vlan_start) * FLAGS.network_size
- # minus one for the gateway.
- return "%s-%s" % (network[start],
- network[start + FLAGS.network_size - 1])
-
-
-class FixedIp(datastore.BasicModel):
- """Represents a fixed ip in the datastore"""
-
- def __init__(self, address):
- self.address = address
- super(FixedIp, self).__init__()
-
- @property
- def identifier(self):
- return self.address
-
- # NOTE(vish): address states allocated, leased, deallocated
- def default_state(self):
- return {'address': self.address,
- 'state': 'none'}
-
- @classmethod
- # pylint: disable-msg=R0913
- def create(cls, user_id, project_id, address, mac, hostname, network_id):
- """Creates an FixedIp object"""
- addr = cls(address)
- addr['user_id'] = user_id
- addr['project_id'] = project_id
- addr['mac'] = mac
- if hostname is None:
- hostname = "ip-%s" % address.replace('.', '-')
- addr['hostname'] = hostname
- addr['network_id'] = network_id
- addr['state'] = 'allocated'
- addr.save()
- return addr
-
- def save(self):
- is_new = self.is_new_record()
- success = super(FixedIp, self).save()
- if success and is_new:
- self.associate_with("network", self['network_id'])
-
- def destroy(self):
- self.unassociate_with("network", self['network_id'])
- super(FixedIp, self).destroy()
-
-
-class ElasticIp(FixedIp):
- """Represents an elastic ip in the datastore"""
- override_type = "address"
-
- def default_state(self):
- return {'address': self.address,
- 'instance_id': 'available',
- 'private_ip': 'available'}
-
-
-# CLEANUP:
-# TODO(ja): does vlanpool "keeper" need to know the min/max -
-# shouldn't FLAGS always win?
-class BaseNetwork(datastore.BasicModel):
- """Implements basic logic for allocating ips in a network"""
- override_type = 'network'
- address_class = FixedIp
-
- @property
- def identifier(self):
- """Datastore identifier"""
- return self.network_id
-
- def default_state(self):
- """Default values for new objects"""
- return {'network_id': self.network_id, 'network_str': self.network_str}
-
- @classmethod
- # pylint: disable-msg=R0913
- def create(cls, user_id, project_id, security_group, vlan, network_str):
- """Create a BaseNetwork object"""
- network_id = "%s:%s" % (project_id, security_group)
- net = cls(network_id, network_str)
- net['user_id'] = user_id
- net['project_id'] = project_id
- net["vlan"] = vlan
- net["bridge_name"] = "br%s" % vlan
- net.save()
- return net
-
- def __init__(self, network_id, network_str=None):
- self.network_id = network_id
- self.network_str = network_str
- super(BaseNetwork, self).__init__()
- self.save()
-
- @property
- def network(self):
- """Returns a string representing the network"""
- return IPy.IP(self['network_str'])
-
- @property
- def netmask(self):
- """Returns the netmask of this network"""
- return self.network.netmask()
-
- @property
- def gateway(self):
- """Returns the network gateway address"""
- return self.network[1]
-
- @property
- def broadcast(self):
- """Returns the network broadcast address"""
- return self.network.broadcast()
-
- @property
- def bridge_name(self):
- """Returns the bridge associated with this network"""
- return "br%s" % (self["vlan"])
-
- @property
- def user(self):
- """Returns the user associated with this network"""
- return manager.AuthManager().get_user(self['user_id'])
-
- @property
- def project(self):
- """Returns the project associated with this network"""
- return manager.AuthManager().get_project(self['project_id'])
-
- # pylint: disable-msg=R0913
- def _add_host(self, user_id, project_id, ip_address, mac, hostname):
- """Add a host to the datastore"""
- self.address_class.create(user_id, project_id, ip_address,
- mac, hostname, self.identifier)
-
- def _rem_host(self, ip_address):
- """Remove a host from the datastore"""
- self.address_class(ip_address).destroy()
-
- @property
- def assigned(self):
- """Returns a list of all assigned addresses"""
- return self.address_class.associated_keys('network', self.identifier)
-
- @property
- def assigned_objs(self):
- """Returns a list of all assigned addresses as objects"""
- return self.address_class.associated_to('network', self.identifier)
-
- def get_address(self, ip_address):
- """Returns a specific ip as an object"""
- if ip_address in self.assigned:
- return self.address_class(ip_address)
- return None
-
- @property
- def available(self):
- """Returns a list of all available addresses in the network"""
- for idx in range(self.num_bottom_reserved_ips,
- len(self.network) - self.num_top_reserved_ips):
- address = str(self.network[idx])
- if not address in self.assigned:
- yield address
-
- @property
- def num_bottom_reserved_ips(self):
- """Returns number of ips reserved at the bottom of the range"""
- return 2 # Network, Gateway
-
- @property
- def num_top_reserved_ips(self):
- """Returns number of ips reserved at the top of the range"""
- return 1 # Broadcast
-
- def allocate_ip(self, user_id, project_id, mac, hostname=None):
- """Allocates an ip to a mac address"""
- for address in self.available:
- logging.debug("Allocating IP %s to %s", address, project_id)
- self._add_host(user_id, project_id, address, mac, hostname)
- self.express(address=address)
- return address
- raise exception.NoMoreAddresses("Project %s with network %s" %
- (project_id, str(self.network)))
-
- def lease_ip(self, ip_str):
- """Called when DHCP lease is activated"""
- if not ip_str in self.assigned:
- raise exception.AddressNotAllocated()
- address = self.get_address(ip_str)
- if address:
- logging.debug("Leasing allocated IP %s", ip_str)
- address['state'] = 'leased'
- address.save()
-
- def release_ip(self, ip_str):
- """Called when DHCP lease expires
-
- Removes the ip from the assigned list"""
- if not ip_str in self.assigned:
- raise exception.AddressNotAllocated()
- logging.debug("Releasing IP %s", ip_str)
- self._rem_host(ip_str)
- self.deexpress(address=ip_str)
-
- def deallocate_ip(self, ip_str):
- """Deallocates an allocated ip"""
- if not ip_str in self.assigned:
- raise exception.AddressNotAllocated()
- address = self.get_address(ip_str)
- if address:
- if address['state'] != 'leased':
- # NOTE(vish): address hasn't been leased, so release it
- self.release_ip(ip_str)
- else:
- logging.debug("Deallocating allocated IP %s", ip_str)
- address['state'] == 'deallocated'
- address.save()
-
- def express(self, address=None):
- """Set up network. Implemented in subclasses"""
- pass
-
- def deexpress(self, address=None):
- """Tear down network. Implemented in subclasses"""
- pass
-
-
-class BridgedNetwork(BaseNetwork):
- """
- Virtual Network that can express itself to create a vlan and
- a bridge (with or without an IP address/netmask/gateway)
-
- properties:
- bridge_name - string (example value: br42)
- vlan - integer (example value: 42)
- bridge_dev - string (example: eth0)
- bridge_gets_ip - boolean used during bridge creation
-
- if bridge_gets_ip then network address for bridge uses the properties:
- gateway
- broadcast
- netmask
- """
-
- bridge_gets_ip = False
- override_type = 'network'
-
- @classmethod
- def get_network_for_project(cls,
- user_id,
- project_id,
- security_group='default'):
- """Returns network for a given project"""
- vlan = get_vlan_for_project(project_id)
- network_str = vlan.subnet()
- return cls.create(user_id, project_id, security_group, vlan.vlan_id,
- network_str)
-
- def __init__(self, *args, **kwargs):
- super(BridgedNetwork, self).__init__(*args, **kwargs)
- self['bridge_dev'] = FLAGS.bridge_dev
- self.save()
-
- def express(self, address=None):
- super(BridgedNetwork, self).express(address=address)
- linux_net.vlan_create(self)
- linux_net.bridge_create(self)
-
-
-class DHCPNetwork(BridgedNetwork):
- """Network supporting DHCP"""
- bridge_gets_ip = True
- override_type = 'network'
-
- def __init__(self, *args, **kwargs):
- super(DHCPNetwork, self).__init__(*args, **kwargs)
- if not(os.path.exists(FLAGS.networks_path)):
- os.makedirs(FLAGS.networks_path)
-
- @property
- def num_bottom_reserved_ips(self):
- # For cloudpipe
- return super(DHCPNetwork, self).num_bottom_reserved_ips + 1
-
- @property
- def num_top_reserved_ips(self):
- return super(DHCPNetwork, self).num_top_reserved_ips + \
- FLAGS.cnt_vpn_clients
-
- @property
- def dhcp_listen_address(self):
- """Address where dhcp server should listen"""
- return self.gateway
-
- @property
- def dhcp_range_start(self):
- """Starting address dhcp server should use"""
- return self.network[self.num_bottom_reserved_ips]
-
- def express(self, address=None):
- super(DHCPNetwork, self).express(address=address)
- if len(self.assigned) > 0:
- logging.debug("Starting dnsmasq server for network with vlan %s",
- self['vlan'])
- linux_net.start_dnsmasq(self)
- else:
- logging.debug("Not launching dnsmasq: no hosts.")
- self.express_vpn()
-
- def allocate_vpn_ip(self, user_id, project_id, mac, hostname=None):
- """Allocates the reserved ip to a vpn instance"""
- address = str(self.network[2])
- self._add_host(user_id, project_id, address, mac, hostname)
- self.express(address=address)
- return address
-
- def express_vpn(self):
- """Sets up routing rules for vpn"""
- private_ip = str(self.network[2])
- linux_net.confirm_rule("FORWARD -d %s -p udp --dport 1194 -j ACCEPT"
- % (private_ip, ))
- linux_net.confirm_rule(
- "PREROUTING -t nat -d %s -p udp --dport %s -j DNAT --to %s:1194"
- % (self.project.vpn_ip, self.project.vpn_port, private_ip))
-
- def deexpress(self, address=None):
- # if this is the last address, stop dns
- super(DHCPNetwork, self).deexpress(address=address)
- if len(self.assigned) == 0:
- linux_net.stop_dnsmasq(self)
- else:
- linux_net.start_dnsmasq(self)
-
-DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)]
-
-
-class PublicNetworkController(BaseNetwork):
- """Handles elastic ips"""
- override_type = 'network'
- address_class = ElasticIp
-
- def __init__(self, *args, **kwargs):
- network_id = "public:default"
- super(PublicNetworkController, self).__init__(network_id,
- FLAGS.public_range, *args, **kwargs)
- self['user_id'] = "public"
- self['project_id'] = "public"
- self["create_time"] = time.strftime('%Y-%m-%dT%H:%M:%SZ',
- time.gmtime())
- self["vlan"] = FLAGS.public_vlan
- self.save()
- self.express()
-
- def deallocate_ip(self, ip_str):
- # NOTE(vish): cleanup is now done on release by the parent class
- self.release_ip(ip_str)
-
- def associate_address(self, public_ip, private_ip, instance_id):
- """Associates a public ip to a private ip and instance id"""
- if not public_ip in self.assigned:
- raise exception.AddressNotAllocated()
- # TODO(josh): Keep an index going both ways
- for addr in self.assigned_objs:
- if addr.get('private_ip', None) == private_ip:
- raise exception.AddressAlreadyAssociated()
- addr = self.get_address(public_ip)
- if addr.get('private_ip', 'available') != 'available':
- raise exception.AddressAlreadyAssociated()
- addr['private_ip'] = private_ip
- addr['instance_id'] = instance_id
- addr.save()
- self.express(address=public_ip)
-
- def disassociate_address(self, public_ip):
- """Disassociates a public ip with its private ip"""
- if not public_ip in self.assigned:
- raise exception.AddressNotAllocated()
- addr = self.get_address(public_ip)
- if addr.get('private_ip', 'available') == 'available':
- raise exception.AddressNotAssociated()
- self.deexpress(address=public_ip)
- addr['private_ip'] = 'available'
- addr['instance_id'] = 'available'
- addr.save()
-
- def express(self, address=None):
- if address:
- if not address in self.assigned:
- raise exception.AddressNotAllocated()
- addresses = [self.get_address(address)]
- else:
- addresses = self.assigned_objs
- for addr in addresses:
- if addr.get('private_ip', 'available') == 'available':
- continue
- public_ip = addr['address']
- private_ip = addr['private_ip']
- linux_net.bind_public_ip(public_ip, FLAGS.public_interface)
- linux_net.confirm_rule("PREROUTING -t nat -d %s -j DNAT --to %s"
- % (public_ip, private_ip))
- linux_net.confirm_rule("POSTROUTING -t nat -s %s -j SNAT --to %s"
- % (private_ip, public_ip))
- # TODO(joshua): Get these from the secgroup datastore entries
- linux_net.confirm_rule("FORWARD -d %s -p icmp -j ACCEPT"
- % (private_ip))
- for (protocol, port) in DEFAULT_PORTS:
- linux_net.confirm_rule(
- "FORWARD -d %s -p %s --dport %s -j ACCEPT"
- % (private_ip, protocol, port))
-
- def deexpress(self, address=None):
- addr = self.get_address(address)
- private_ip = addr['private_ip']
- linux_net.unbind_public_ip(address, FLAGS.public_interface)
- linux_net.remove_rule("PREROUTING -t nat -d %s -j DNAT --to %s"
- % (address, private_ip))
- linux_net.remove_rule("POSTROUTING -t nat -s %s -j SNAT --to %s"
- % (private_ip, address))
- linux_net.remove_rule("FORWARD -d %s -p icmp -j ACCEPT"
- % (private_ip))
- for (protocol, port) in DEFAULT_PORTS:
- linux_net.remove_rule("FORWARD -d %s -p %s --dport %s -j ACCEPT"
- % (private_ip, protocol, port))
-
-
-# FIXME(todd): does this present a race condition, or is there some
-# piece of architecture that mitigates it (only one queue
-# listener per net)?
-def get_vlan_for_project(project_id):
- """Allocate vlan IDs to individual users"""
- vlan = Vlan.lookup(project_id)
- if vlan:
- return vlan
- known_vlans = Vlan.dict_by_vlan()
- for vnum in range(FLAGS.vlan_start, FLAGS.vlan_end):
- vstr = str(vnum)
- if not vstr in known_vlans:
- return Vlan.create(project_id, vnum)
- old_project_id = known_vlans[vstr]
- if not manager.AuthManager().get_project(old_project_id):
- vlan = Vlan.lookup(old_project_id)
- if vlan:
- # NOTE(todd): This doesn't check for vlan id match, because
- # it seems to be assumed that vlan<=>project is
- # always a 1:1 mapping. It could be made way
- # sexier if it didn't fight against the way
- # BasicModel worked and used associate_with
- # to build connections to projects.
- # NOTE(josh): This is here because we want to make sure we
- # don't orphan any VLANs. It is basically
- # garbage collection for after projects abandoned
- # their reference.
- vlan.destroy()
- vlan.project_id = project_id
- vlan.save()
- return vlan
- else:
- return Vlan.create(project_id, vnum)
- raise exception.AddressNotAllocated("Out of VLANs")
-
-
-def get_project_network(project_id, security_group='default'):
- """Gets a project's private network, allocating one if needed"""
- project = manager.AuthManager().get_project(project_id)
- if not project:
- raise nova_exception.NotFound("Project %s doesn't exist." % project_id)
- manager_id = project.project_manager_id
- return DHCPNetwork.get_network_for_project(manager_id,
- project.id,
- security_group)
-
-
-def get_network_by_address(address):
- """Gets the network for a given private ip"""
- address_record = FixedIp.lookup(address)
- if not address_record:
- raise exception.AddressNotAllocated()
- return get_project_network(address_record['project_id'])
-
-
-def get_network_by_interface(iface, security_group='default'):
- """Gets the network for a given interface"""
- vlan = iface.rpartition("br")[2]
- project_id = Vlan.dict_by_vlan().get(vlan)
- return get_project_network(project_id, security_group)
-
-
-def get_public_ip_for_instance(instance_id):
- """Gets the public ip for a given instance"""
- # FIXME(josh): this should be a lookup - iteration won't scale
- for address_record in ElasticIp.all():
- if address_record.get('instance_id', 'available') == instance_id:
- return address_record['address']
diff --git a/nova/network/service.py b/nova/network/service.py
deleted file mode 100644
index d3aa1c46f..000000000
--- a/nova/network/service.py
+++ /dev/null
@@ -1,257 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Network Hosts are responsible for allocating ips and setting up network
-"""
-
-from nova import datastore
-from nova import exception
-from nova import flags
-from nova import service
-from nova import utils
-from nova.auth import manager
-from nova.network import exception
-from nova.network import model
-from nova.network import vpn
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('network_type',
- 'flat',
- 'Service Class for Networking')
-flags.DEFINE_string('flat_network_bridge', 'br100',
- 'Bridge for simple network instances')
-flags.DEFINE_list('flat_network_ips',
- ['192.168.0.2', '192.168.0.3', '192.168.0.4'],
- 'Available ips for simple network')
-flags.DEFINE_string('flat_network_network', '192.168.0.0',
- 'Network for simple network')
-flags.DEFINE_string('flat_network_netmask', '255.255.255.0',
- 'Netmask for simple network')
-flags.DEFINE_string('flat_network_gateway', '192.168.0.1',
- 'Broadcast for simple network')
-flags.DEFINE_string('flat_network_broadcast', '192.168.0.255',
- 'Broadcast for simple network')
-flags.DEFINE_string('flat_network_dns', '8.8.4.4',
- 'Dns for simple network')
-
-
-def type_to_class(network_type):
- """Convert a network_type string into an actual Python class"""
- if network_type == 'flat':
- return FlatNetworkService
- elif network_type == 'vlan':
- return VlanNetworkService
- raise exception.NotFound("Couldn't find %s network type" % network_type)
-
-
-def setup_compute_network(network_type, user_id, project_id, security_group):
- """Sets up the network on a compute host"""
- srv = type_to_class(network_type)
- srv.setup_compute_network(network_type,
- user_id,
- project_id,
- security_group)
-
-
-def get_host_for_project(project_id):
- """Get host allocated to project from datastore"""
- redis = datastore.Redis.instance()
- return redis.get(_host_key(project_id))
-
-
-def _host_key(project_id):
- """Returns redis host key for network"""
- return "networkhost:%s" % project_id
-
-
-class BaseNetworkService(service.Service):
- """Implements common network service functionality
-
- This class must be subclassed.
- """
- def __init__(self, *args, **kwargs):
- self.network = model.PublicNetworkController()
- super(BaseNetworkService, self).__init__(*args, **kwargs)
-
- def set_network_host(self, user_id, project_id, *args, **kwargs):
- """Safely sets the host of the projects network"""
- redis = datastore.Redis.instance()
- key = _host_key(project_id)
- if redis.setnx(key, FLAGS.node_name):
- self._on_set_network_host(user_id, project_id,
- security_group='default',
- *args, **kwargs)
- return FLAGS.node_name
- else:
- return redis.get(key)
-
- def allocate_fixed_ip(self, user_id, project_id,
- security_group='default',
- *args, **kwargs):
- """Subclass implements getting fixed ip from the pool"""
- raise NotImplementedError()
-
- def deallocate_fixed_ip(self, fixed_ip, *args, **kwargs):
- """Subclass implements return of ip to the pool"""
- raise NotImplementedError()
-
- def _on_set_network_host(self, user_id, project_id,
- *args, **kwargs):
- """Called when this host becomes the host for a project"""
- pass
-
- @classmethod
- def setup_compute_network(cls, user_id, project_id, security_group,
- *args, **kwargs):
- """Sets up matching network for compute hosts"""
- raise NotImplementedError()
-
- def allocate_elastic_ip(self, user_id, project_id):
- """Gets a elastic ip from the pool"""
- # NOTE(vish): Replicating earlier decision to use 'public' as
- # mac address name, although this should probably
- # be done inside of the PublicNetworkController
- return self.network.allocate_ip(user_id, project_id, 'public')
-
- def associate_elastic_ip(self, elastic_ip, fixed_ip, instance_id):
- """Associates an elastic ip to a fixed ip"""
- self.network.associate_address(elastic_ip, fixed_ip, instance_id)
-
- def disassociate_elastic_ip(self, elastic_ip):
- """Disassociates a elastic ip"""
- self.network.disassociate_address(elastic_ip)
-
- def deallocate_elastic_ip(self, elastic_ip):
- """Returns a elastic ip to the pool"""
- self.network.deallocate_ip(elastic_ip)
-
-
-class FlatNetworkService(BaseNetworkService):
- """Basic network where no vlans are used"""
-
- @classmethod
- def setup_compute_network(cls, user_id, project_id, security_group,
- *args, **kwargs):
- """Network is created manually"""
- pass
-
- def allocate_fixed_ip(self,
- user_id,
- project_id,
- security_group='default',
- *args, **kwargs):
- """Gets a fixed ip from the pool
-
- Flat network just grabs the next available ip from the pool
- """
- # NOTE(vish): Some automation could be done here. For example,
- # creating the flat_network_bridge and setting up
- # a gateway. This is all done manually atm.
- redis = datastore.Redis.instance()
- if not redis.exists('ips') and not len(redis.keys('instances:*')):
- for fixed_ip in FLAGS.flat_network_ips:
- redis.sadd('ips', fixed_ip)
- fixed_ip = redis.spop('ips')
- if not fixed_ip:
- raise exception.NoMoreAddresses()
- # TODO(vish): some sort of dns handling for hostname should
- # probably be done here.
- return {'inject_network': True,
- 'network_type': FLAGS.network_type,
- 'mac_address': utils.generate_mac(),
- 'private_dns_name': str(fixed_ip),
- 'bridge_name': FLAGS.flat_network_bridge,
- 'network_network': FLAGS.flat_network_network,
- 'network_netmask': FLAGS.flat_network_netmask,
- 'network_gateway': FLAGS.flat_network_gateway,
- 'network_broadcast': FLAGS.flat_network_broadcast,
- 'network_dns': FLAGS.flat_network_dns}
-
- def deallocate_fixed_ip(self, fixed_ip, *args, **kwargs):
- """Returns an ip to the pool"""
- datastore.Redis.instance().sadd('ips', fixed_ip)
-
-
-class VlanNetworkService(BaseNetworkService):
- """Vlan network with dhcp"""
- # NOTE(vish): A lot of the interactions with network/model.py can be
- # simplified and improved. Also there it may be useful
- # to support vlans separately from dhcp, instead of having
- # both of them together in this class.
- # pylint: disable-msg=W0221
- def allocate_fixed_ip(self,
- user_id,
- project_id,
- security_group='default',
- is_vpn=False,
- hostname=None,
- *args, **kwargs):
- """Gets a fixed ip from the pool"""
- mac = utils.generate_mac()
- net = model.get_project_network(project_id)
- if is_vpn:
- fixed_ip = net.allocate_vpn_ip(user_id,
- project_id,
- mac,
- hostname)
- else:
- fixed_ip = net.allocate_ip(user_id,
- project_id,
- mac,
- hostname)
- return {'network_type': FLAGS.network_type,
- 'bridge_name': net['bridge_name'],
- 'mac_address': mac,
- 'private_dns_name': fixed_ip}
-
- def deallocate_fixed_ip(self, fixed_ip,
- *args, **kwargs):
- """Returns an ip to the pool"""
- return model.get_network_by_address(fixed_ip).deallocate_ip(fixed_ip)
-
- def lease_ip(self, fixed_ip):
- """Called by bridge when ip is leased"""
- return model.get_network_by_address(fixed_ip).lease_ip(fixed_ip)
-
- def release_ip(self, fixed_ip):
- """Called by bridge when ip is released"""
- return model.get_network_by_address(fixed_ip).release_ip(fixed_ip)
-
- def restart_nets(self):
- """Ensure the network for each user is enabled"""
- for project in manager.AuthManager().get_projects():
- model.get_project_network(project.id).express()
-
- def _on_set_network_host(self, user_id, project_id,
- *args, **kwargs):
- """Called when this host becomes the host for a project"""
- vpn.NetworkData.create(project_id)
-
- @classmethod
- def setup_compute_network(cls, user_id, project_id, security_group,
- *args, **kwargs):
- """Sets up matching network for compute hosts"""
- # NOTE(vish): Use BridgedNetwork instead of DHCPNetwork because
- # we don't want to run dnsmasq on the client machines
- net = model.BridgedNetwork.get_network_for_project(
- user_id,
- project_id,
- security_group)
- net.express()
diff --git a/nova/network/vpn.py b/nova/network/vpn.py
deleted file mode 100644
index 85366ed89..000000000
--- a/nova/network/vpn.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Network Data for projects"""
-
-from nova import datastore
-from nova import exception
-from nova import flags
-from nova import utils
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('vpn_ip', utils.get_my_ip(),
- 'Public IP for the cloudpipe VPN servers')
-flags.DEFINE_integer('vpn_start_port', 1000,
- 'Start port for the cloudpipe VPN servers')
-flags.DEFINE_integer('vpn_end_port', 2000,
- 'End port for the cloudpipe VPN servers')
-
-
-class NoMorePorts(exception.Error):
- """No ports available to allocate for the given ip"""
- pass
-
-
-class NetworkData(datastore.BasicModel):
- """Manages network host, and vpn ip and port for projects"""
- def __init__(self, project_id):
- self.project_id = project_id
- super(NetworkData, self).__init__()
-
- @property
- def identifier(self):
- """Identifier used for key in redis"""
- return self.project_id
-
- @classmethod
- def create(cls, project_id):
- """Creates a vpn for project
-
- This method finds a free ip and port and stores the associated
- values in the datastore.
- """
- # TODO(vish): will we ever need multiiple ips per host?
- port = cls.find_free_port_for_ip(FLAGS.vpn_ip)
- network_data = cls(project_id)
- # save ip for project
- network_data['host'] = FLAGS.node_name
- network_data['project'] = project_id
- network_data['ip'] = FLAGS.vpn_ip
- network_data['port'] = port
- network_data.save()
- return network_data
-
- @classmethod
- def find_free_port_for_ip(cls, vpn_ip):
- """Finds a free port for a given ip from the redis set"""
- # TODO(vish): these redis commands should be generalized and
- # placed into a base class. Conceptually, it is
- # similar to an association, but we are just
- # storing a set of values instead of keys that
- # should be turned into objects.
- cls._ensure_set_exists(vpn_ip)
-
- port = datastore.Redis.instance().spop(cls._redis_ports_key(vpn_ip))
- if not port:
- raise NoMorePorts()
- return port
-
- @classmethod
- def _redis_ports_key(cls, vpn_ip):
- """Key that ports are stored under in redis"""
- return 'ip:%s:ports' % vpn_ip
-
- @classmethod
- def _ensure_set_exists(cls, vpn_ip):
- """Creates the set of ports for the ip if it doesn't already exist"""
- # TODO(vish): these ports should be allocated through an admin
- # command instead of a flag
- redis = datastore.Redis.instance()
- if (not redis.exists(cls._redis_ports_key(vpn_ip)) and
- not redis.exists(cls._redis_association_name('ip', vpn_ip))):
- for i in range(FLAGS.vpn_start_port, FLAGS.vpn_end_port + 1):
- redis.sadd(cls._redis_ports_key(vpn_ip), i)
-
- @classmethod
- def num_ports_for_ip(cls, vpn_ip):
- """Calculates the number of free ports for a given ip"""
- cls._ensure_set_exists(vpn_ip)
- return datastore.Redis.instance().scard('ip:%s:ports' % vpn_ip)
-
- @property
- def ip(self): # pylint: disable-msg=C0103
- """The ip assigned to the project"""
- return self['ip']
-
- @property
- def port(self):
- """The port assigned to the project"""
- return int(self['port'])
-
- def save(self):
- """Saves the association to the given ip"""
- self.associate_with('ip', self.ip)
- super(NetworkData, self).save()
-
- def destroy(self):
- """Cleans up datastore and adds port back to pool"""
- self.unassociate_with('ip', self.ip)
- datastore.Redis.instance().sadd('ip:%s:ports' % self.ip, self.port)
- super(NetworkData, self).destroy()
diff --git a/nova/objectstore/__init__.py b/nova/objectstore/__init__.py
index b8890ac03..ecad9be7c 100644
--- a/nova/objectstore/__init__.py
+++ b/nova/objectstore/__init__.py
@@ -22,7 +22,7 @@
.. automodule:: nova.objectstore
:platform: Unix
- :synopsis: Currently a trivial file-based system, getting extended w/ mongo.
+ :synopsis: Currently a trivial file-based system, getting extended w/ swift.
.. moduleauthor:: Jesse Andrews <jesse@ansolabs.com>
.. moduleauthor:: Devin Carlen <devin.carlen@gmail.com>
.. moduleauthor:: Vishvananda Ishaya <vishvananda@yahoo.com>
diff --git a/nova/objectstore/handler.py b/nova/objectstore/handler.py
index 035e342ca..dfee64aca 100644
--- a/nova/objectstore/handler.py
+++ b/nova/objectstore/handler.py
@@ -55,7 +55,7 @@ from twisted.web import static
from nova import exception
from nova import flags
from nova.auth import manager
-from nova.endpoint import api
+from nova.api.ec2 import context
from nova.objectstore import bucket
from nova.objectstore import image
@@ -64,6 +64,7 @@ FLAGS = flags.FLAGS
def render_xml(request, value):
+ """Writes value as XML string to request"""
assert isinstance(value, dict) and len(value) == 1
request.setHeader("Content-Type", "application/xml; charset=UTF-8")
@@ -77,12 +78,14 @@ def render_xml(request, value):
def finish(request, content=None):
+ """Finalizer method for request"""
if content:
request.write(content)
request.finish()
def _render_parts(value, write_cb):
+ """Helper method to render different Python objects to XML"""
if isinstance(value, basestring):
write_cb(escape.xhtml_escape(value))
elif isinstance(value, int) or isinstance(value, long):
@@ -102,37 +105,48 @@ def _render_parts(value, write_cb):
def get_argument(request, key, default_value):
+ """Returns the request's value at key, or default_value
+ if not found
+ """
if key in request.args:
return request.args[key][0]
return default_value
def get_context(request):
+ """Returns the supplied request's context object"""
try:
# Authorization Header format: 'AWS <access>:<secret>'
authorization_header = request.getHeader('Authorization')
if not authorization_header:
raise exception.NotAuthorized
- access, sep, secret = authorization_header.split(' ')[1].rpartition(':')
- (user, project) = manager.AuthManager().authenticate(access,
- secret,
- {},
- request.method,
- request.getRequestHostname(),
- request.uri,
- headers=request.getAllHeaders(),
- check_type='s3')
- return api.APIRequestContext(None, user, project)
+ auth_header_value = authorization_header.split(' ')[1]
+ access, _ignored, secret = auth_header_value.rpartition(':')
+ am = manager.AuthManager()
+ (user, project) = am.authenticate(access,
+ secret,
+ {},
+ request.method,
+ request.getRequestHostname(),
+ request.uri,
+ headers=request.getAllHeaders(),
+ check_type='s3')
+ return context.APIRequestContext(user, project)
except exception.Error as ex:
- logging.debug("Authentication Failure: %s" % ex)
+ logging.debug("Authentication Failure: %s", ex)
raise exception.NotAuthorized
-
class ErrorHandlingResource(resource.Resource):
- """Maps exceptions to 404 / 401 codes. Won't work for exceptions thrown after NOT_DONE_YET is returned."""
- # TODO(unassigned) (calling-all-twisted-experts): This needs to be plugged in to the right place in twisted...
- # This doesn't look like it's the right place (consider exceptions in getChild; or after NOT_DONE_YET is returned
+ """Maps exceptions to 404 / 401 codes. Won't work for
+ exceptions thrown after NOT_DONE_YET is returned.
+ """
+ # TODO(unassigned) (calling-all-twisted-experts): This needs to be
+ # plugged in to the right place in twisted...
+ # This doesn't look like it's the right place
+ # (consider exceptions in getChild; or after
+ # NOT_DONE_YET is returned
def render(self, request):
+ """Renders the response as XML"""
try:
return resource.Resource.render(self, request)
except exception.NotFound:
@@ -145,7 +159,11 @@ class ErrorHandlingResource(resource.Resource):
class S3(ErrorHandlingResource):
"""Implementation of an S3-like storage server based on local files."""
- def getChild(self, name, request):
+ def __init__(self):
+ ErrorHandlingResource.__init__(self)
+
+ def getChild(self, name, request): # pylint: disable-msg=C0103
+ """Returns either the image or bucket resource"""
request.context = get_context(request)
if name == '':
return self
@@ -154,9 +172,11 @@ class S3(ErrorHandlingResource):
else:
return BucketResource(name)
- def render_GET(self, request):
+ def render_GET(self, request): # pylint: disable-msg=R0201
+ """Renders the GET request for a list of buckets as XML"""
logging.debug('List of buckets requested')
- buckets = [b for b in bucket.Bucket.all() if b.is_authorized(request.context)]
+ buckets = [b for b in bucket.Bucket.all() \
+ if b.is_authorized(request.context)]
render_xml(request, {"ListAllMyBucketsResult": {
"Buckets": {"Bucket": [b.metadata for b in buckets]},
@@ -165,22 +185,27 @@ class S3(ErrorHandlingResource):
class BucketResource(ErrorHandlingResource):
+ """A web resource containing an S3-like bucket"""
def __init__(self, name):
resource.Resource.__init__(self)
self.name = name
def getChild(self, name, request):
+ """Returns the bucket resource itself, or the object resource
+ the bucket contains if a name is supplied
+ """
if name == '':
return self
else:
return ObjectResource(bucket.Bucket(self.name), name)
def render_GET(self, request):
- logging.debug("List keys for bucket %s" % (self.name))
+ "Returns the keys for the bucket resource"""
+ logging.debug("List keys for bucket %s", self.name)
try:
bucket_object = bucket.Bucket(self.name)
- except exception.NotFound, e:
+ except exception.NotFound:
return error.NoResource(message="No such bucket").render(request)
if not bucket_object.is_authorized(request.context):
@@ -191,19 +216,26 @@ class BucketResource(ErrorHandlingResource):
max_keys = int(get_argument(request, "max-keys", 1000))
terse = int(get_argument(request, "terse", 0))
- results = bucket_object.list_keys(prefix=prefix, marker=marker, max_keys=max_keys, terse=terse)
+ results = bucket_object.list_keys(prefix=prefix,
+ marker=marker,
+ max_keys=max_keys,
+ terse=terse)
render_xml(request, {"ListBucketResult": results})
return server.NOT_DONE_YET
def render_PUT(self, request):
- logging.debug("Creating bucket %s" % (self.name))
- logging.debug("calling bucket.Bucket.create(%r, %r)" % (self.name, request.context))
+ "Creates the bucket resource"""
+ logging.debug("Creating bucket %s", self.name)
+ logging.debug("calling bucket.Bucket.create(%r, %r)",
+ self.name,
+ request.context)
bucket.Bucket.create(self.name, request.context)
request.finish()
return server.NOT_DONE_YET
def render_DELETE(self, request):
- logging.debug("Deleting bucket %s" % (self.name))
+ """Deletes the bucket resource"""
+ logging.debug("Deleting bucket %s", self.name)
bucket_object = bucket.Bucket(self.name)
if not bucket_object.is_authorized(request.context):
@@ -215,25 +247,37 @@ class BucketResource(ErrorHandlingResource):
class ObjectResource(ErrorHandlingResource):
+ """The resource returned from a bucket"""
def __init__(self, bucket, name):
resource.Resource.__init__(self)
self.bucket = bucket
self.name = name
def render_GET(self, request):
- logging.debug("Getting object: %s / %s" % (self.bucket.name, self.name))
+ """Returns the object
+
+ Raises NotAuthorized if user in request context is not
+ authorized to delete the object.
+ """
+ logging.debug("Getting object: %s / %s", self.bucket.name, self.name)
if not self.bucket.is_authorized(request.context):
raise exception.NotAuthorized
obj = self.bucket[urllib.unquote(self.name)]
request.setHeader("Content-Type", "application/unknown")
- request.setHeader("Last-Modified", datetime.datetime.utcfromtimestamp(obj.mtime))
+ request.setHeader("Last-Modified",
+ datetime.datetime.utcfromtimestamp(obj.mtime))
request.setHeader("Etag", '"' + obj.md5 + '"')
return static.File(obj.path).render_GET(request)
def render_PUT(self, request):
- logging.debug("Putting object: %s / %s" % (self.bucket.name, self.name))
+ """Modifies/inserts the object and returns a result code
+
+ Raises NotAuthorized if user in request context is not
+ authorized to delete the object.
+ """
+ logging.debug("Putting object: %s / %s", self.bucket.name, self.name)
if not self.bucket.is_authorized(request.context):
raise exception.NotAuthorized
@@ -246,7 +290,15 @@ class ObjectResource(ErrorHandlingResource):
return server.NOT_DONE_YET
def render_DELETE(self, request):
- logging.debug("Deleting object: %s / %s" % (self.bucket.name, self.name))
+ """Deletes the object and returns a result code
+
+ Raises NotAuthorized if user in request context is not
+ authorized to delete the object.
+ """
+
+ logging.debug("Deleting object: %s / %s",
+ self.bucket.name,
+ self.name)
if not self.bucket.is_authorized(request.context):
raise exception.NotAuthorized
@@ -257,6 +309,7 @@ class ObjectResource(ErrorHandlingResource):
class ImageResource(ErrorHandlingResource):
+ """A web resource representing a single image"""
isLeaf = True
def __init__(self, name):
@@ -264,17 +317,21 @@ class ImageResource(ErrorHandlingResource):
self.img = image.Image(name)
def render_GET(self, request):
- return static.File(self.img.image_path, defaultType='application/octet-stream').render_GET(request)
-
+ """Returns the image file"""
+ return static.File(self.img.image_path,
+ defaultType='application/octet-stream'
+ ).render_GET(request)
class ImagesResource(resource.Resource):
- def getChild(self, name, request):
+ """A web resource representing a list of images"""
+ def getChild(self, name, _request):
+ """Returns itself or an ImageResource if no name given"""
if name == '':
return self
else:
return ImageResource(name)
- def render_GET(self, request):
+ def render_GET(self, request): # pylint: disable-msg=R0201
""" returns a json listing of all images
that a user has permissions to see """
@@ -295,13 +352,15 @@ class ImagesResource(resource.Resource):
m[u'imageType'] = m['type']
elif 'imageType' in m:
m[u'type'] = m['imageType']
+ if 'displayName' not in m:
+ m[u'displayName'] = u''
return m
request.write(json.dumps([decorate(i.metadata) for i in images]))
request.finish()
return server.NOT_DONE_YET
- def render_PUT(self, request):
+ def render_PUT(self, request): # pylint: disable-msg=R0201
""" create a new registered image """
image_id = get_argument(request, 'image_id', u'')
@@ -313,7 +372,6 @@ class ImagesResource(resource.Resource):
raise exception.NotAuthorized
bucket_object = bucket.Bucket(image_location.split("/")[0])
- manifest = image_location[len(image_location.split('/')[0])+1:]
if not bucket_object.is_authorized(request.context):
raise exception.NotAuthorized
@@ -323,23 +381,32 @@ class ImagesResource(resource.Resource):
p.start()
return ''
- def render_POST(self, request):
- """ update image attributes: public/private """
+ def render_POST(self, request): # pylint: disable-msg=R0201
+ """Update image attributes: public/private"""
+ # image_id required for all requests
image_id = get_argument(request, 'image_id', u'')
- operation = get_argument(request, 'operation', u'')
-
image_object = image.Image(image_id)
-
if not image_object.is_authorized(request.context):
+ logging.debug("not authorized for render_POST in images")
raise exception.NotAuthorized
- image_object.set_public(operation=='add')
-
+ operation = get_argument(request, 'operation', u'')
+ if operation:
+ # operation implies publicity toggle
+ logging.debug("handling publicity toggle")
+ image_object.set_public(operation=='add')
+ else:
+ # other attributes imply update
+ logging.debug("update user fields")
+ clean_args = {}
+ for arg in request.args.keys():
+ clean_args[arg] = request.args[arg][0]
+ image_object.update_user_editable_fields(clean_args)
return ''
- def render_DELETE(self, request):
- """ delete a registered image """
+ def render_DELETE(self, request): # pylint: disable-msg=R0201
+ """Delete a registered image"""
image_id = get_argument(request, "image_id", u"")
image_object = image.Image(image_id)
@@ -353,14 +420,19 @@ class ImagesResource(resource.Resource):
def get_site():
+ """Support for WSGI-like interfaces"""
root = S3()
site = server.Site(root)
return site
def get_application():
+ """Support WSGI-like interfaces"""
factory = get_site()
application = service.Application("objectstore")
+ # Disabled because of lack of proper introspection in Twisted
+ # or possibly different versions of twisted?
+ # pylint: disable-msg=E1101
objectStoreService = internet.TCPServer(FLAGS.s3_port, factory)
objectStoreService.setServiceParent(application)
return application
diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py
index f3c02a425..c01b041bb 100644
--- a/nova/objectstore/image.py
+++ b/nova/objectstore/image.py
@@ -82,6 +82,16 @@ class Image(object):
with open(os.path.join(self.path, 'info.json'), 'w') as f:
json.dump(md, f)
+ def update_user_editable_fields(self, args):
+ """args is from the request parameters, so requires extra cleaning"""
+ fields = {'display_name': 'displayName', 'description': 'description'}
+ info = self.metadata
+ for field in fields.keys():
+ if field in args:
+ info[fields[field]] = args[field]
+ with open(os.path.join(self.path, 'info.json'), 'w') as f:
+ json.dump(info, f)
+
@staticmethod
def all():
images = []
@@ -181,14 +191,14 @@ class Image(object):
if kernel_id == 'true':
image_type = 'kernel'
except:
- pass
+ kernel_id = None
try:
ramdisk_id = manifest.find("machine_configuration/ramdisk_id").text
if ramdisk_id == 'true':
image_type = 'ramdisk'
except:
- pass
+ ramdisk_id = None
info = {
'imageId': image_id,
@@ -199,6 +209,12 @@ class Image(object):
'imageType' : image_type
}
+ if kernel_id:
+ info['kernelId'] = kernel_id
+
+ if ramdisk_id:
+ info['ramdiskId'] = ramdisk_id
+
def write_state(state):
info['imageState'] = state
with open(os.path.join(image_path, 'info.json'), "w") as f:
diff --git a/nova/process.py b/nova/process.py
index 425d9f162..13cb90e82 100644
--- a/nova/process.py
+++ b/nova/process.py
@@ -18,9 +18,10 @@
# under the License.
"""
-Process pool, still buggy right now.
+Process pool using twisted threading
"""
+import logging
import StringIO
from twisted.internet import defer
@@ -29,30 +30,14 @@ from twisted.internet import protocol
from twisted.internet import reactor
from nova import flags
+from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
flags.DEFINE_integer('process_pool_size', 4,
'Number of processes to use in the process pool')
-
-# NOTE(termie): this is copied from twisted.internet.utils but since
-# they don't export it I've copied and modified
-class UnexpectedErrorOutput(IOError):
- """
- Standard error data was received where it was not expected. This is a
- subclass of L{IOError} to preserve backward compatibility with the previous
- error behavior of L{getProcessOutput}.
-
- @ivar processEnded: A L{Deferred} which will fire when the process which
- produced the data on stderr has ended (exited and all file descriptors
- closed).
- """
- def __init__(self, stdout=None, stderr=None):
- IOError.__init__(self, "got stdout: %r\nstderr: %r" % (stdout, stderr))
-
-
-# This is based on _BackRelay from twister.internal.utils, but modified to
-# capture both stdout and stderr, without odd stderr handling, and also to
+# This is based on _BackRelay from twister.internal.utils, but modified to
+# capture both stdout and stderr, without odd stderr handling, and also to
# handle stdin
class BackRelayWithInput(protocol.ProcessProtocol):
"""
@@ -62,22 +47,23 @@ class BackRelayWithInput(protocol.ProcessProtocol):
@ivar deferred: A L{Deferred} which will be called back with all of stdout
and all of stderr as well (as a tuple). C{terminate_on_stderr} is true
and any bytes are received over stderr, this will fire with an
- L{_UnexpectedErrorOutput} instance and the attribute will be set to
+ L{_ProcessExecutionError} instance and the attribute will be set to
C{None}.
- @ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
- received over stderr, this attribute will refer to a L{Deferred} which
- will be called back when the process ends. This C{Deferred} is also
- associated with the L{_UnexpectedErrorOutput} which C{deferred} fires
- with earlier in this case so that users can determine when the process
+ @ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
+ received over stderr, this attribute will refer to a L{Deferred} which
+ will be called back when the process ends. This C{Deferred} is also
+ associated with the L{_ProcessExecutionError} which C{deferred} fires
+ with earlier in this case so that users can determine when the process
has actually ended, in addition to knowing when bytes have been received
via stderr.
"""
- def __init__(self, deferred, started_deferred=None,
- terminate_on_stderr=False, check_exit_code=True,
- process_input=None):
+ def __init__(self, deferred, cmd, started_deferred=None,
+ terminate_on_stderr=False, check_exit_code=True,
+ process_input=None):
self.deferred = deferred
+ self.cmd = cmd
self.stdout = StringIO.StringIO()
self.stderr = StringIO.StringIO()
self.started_deferred = started_deferred
@@ -85,14 +71,18 @@ class BackRelayWithInput(protocol.ProcessProtocol):
self.check_exit_code = check_exit_code
self.process_input = process_input
self.on_process_ended = None
-
+
+ def _build_execution_error(self, exit_code=None):
+ return ProcessExecutionError(cmd=self.cmd,
+ exit_code=exit_code,
+ stdout=self.stdout.getvalue(),
+ stderr=self.stderr.getvalue())
+
def errReceived(self, text):
self.stderr.write(text)
if self.terminate_on_stderr and (self.deferred is not None):
self.on_process_ended = defer.Deferred()
- self.deferred.errback(UnexpectedErrorOutput(
- stdout=self.stdout.getvalue(),
- stderr=self.stderr.getvalue()))
+ self.deferred.errback(self._build_execution_error())
self.deferred = None
self.transport.loseConnection()
@@ -102,15 +92,19 @@ class BackRelayWithInput(protocol.ProcessProtocol):
def processEnded(self, reason):
if self.deferred is not None:
stdout, stderr = self.stdout.getvalue(), self.stderr.getvalue()
- try:
- if self.check_exit_code:
- reason.trap(error.ProcessDone)
- self.deferred.callback((stdout, stderr))
- except:
- # NOTE(justinsb): This logic is a little suspicious to me...
- # If the callback throws an exception, then errback will be
- # called also. However, this is what the unit tests test for...
- self.deferred.errback(UnexpectedErrorOutput(stdout, stderr))
+ exit_code = reason.value.exitCode
+ if self.check_exit_code and exit_code <> 0:
+ self.deferred.errback(self._build_execution_error(exit_code))
+ else:
+ try:
+ if self.check_exit_code:
+ reason.trap(error.ProcessDone)
+ self.deferred.callback((stdout, stderr))
+ except:
+ # NOTE(justinsb): This logic is a little suspicious to me...
+ # If the callback throws an exception, then errback will be
+ # called also. However, this is what the unit tests test for...
+ self.deferred.errback(self._build_execution_error(exit_code))
elif self.on_process_ended is not None:
self.on_process_ended.errback(reason)
@@ -119,11 +113,11 @@ class BackRelayWithInput(protocol.ProcessProtocol):
if self.started_deferred:
self.started_deferred.callback(self)
if self.process_input:
- self.transport.write(self.process_input)
+ self.transport.write(str(self.process_input))
self.transport.closeStdin()
-def get_process_output(executable, args=None, env=None, path=None,
- process_reactor=None, check_exit_code=True,
+def get_process_output(executable, args=None, env=None, path=None,
+ process_reactor=None, check_exit_code=True,
process_input=None, started_deferred=None,
terminate_on_stderr=False):
if process_reactor is None:
@@ -131,10 +125,15 @@ def get_process_output(executable, args=None, env=None, path=None,
args = args and args or ()
env = env and env and {}
deferred = defer.Deferred()
+ cmd = executable
+ if args:
+ cmd = " ".join([cmd] + args)
+ logging.debug("Running cmd: %s", cmd)
process_handler = BackRelayWithInput(
- deferred,
- started_deferred=started_deferred,
- check_exit_code=check_exit_code,
+ deferred,
+ cmd,
+ started_deferred=started_deferred,
+ check_exit_code=check_exit_code,
process_input=process_input,
terminate_on_stderr=terminate_on_stderr)
# NOTE(vish): commands come in as unicode, but self.executes needs
@@ -142,8 +141,8 @@ def get_process_output(executable, args=None, env=None, path=None,
executable = str(executable)
if not args is None:
args = [str(x) for x in args]
- process_reactor.spawnProcess( process_handler, executable,
- (executable,)+tuple(args), env, path)
+ process_reactor.spawnProcess(process_handler, executable,
+ (executable,)+tuple(args), env, path)
return deferred
diff --git a/nova/quota.py b/nova/quota.py
new file mode 100644
index 000000000..edbb83111
--- /dev/null
+++ b/nova/quota.py
@@ -0,0 +1,92 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Quotas for instances, volumes, and floating ips
+"""
+
+from nova import db
+from nova import exception
+from nova import flags
+from nova.compute import instance_types
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer('quota_instances', 10,
+ 'number of instances allowed per project')
+flags.DEFINE_integer('quota_cores', 20,
+ 'number of instance cores allowed per project')
+flags.DEFINE_integer('quota_volumes', 10,
+ 'number of volumes allowed per project')
+flags.DEFINE_integer('quota_gigabytes', 1000,
+ 'number of volume gigabytes allowed per project')
+flags.DEFINE_integer('quota_floating_ips', 10,
+ 'number of floating ips allowed per project')
+
+def get_quota(context, project_id):
+ rval = {'instances': FLAGS.quota_instances,
+ 'cores': FLAGS.quota_cores,
+ 'volumes': FLAGS.quota_volumes,
+ 'gigabytes': FLAGS.quota_gigabytes,
+ 'floating_ips': FLAGS.quota_floating_ips}
+ try:
+ quota = db.quota_get(context, project_id)
+ for key in rval.keys():
+ if quota[key] is not None:
+ rval[key] = quota[key]
+ except exception.NotFound:
+ pass
+ return rval
+
+def allowed_instances(context, num_instances, instance_type):
+ """Check quota and return min(num_instances, allowed_instances)"""
+ project_id = context.project.id
+ used_instances, used_cores = db.instance_data_get_for_project(context,
+ project_id)
+ quota = get_quota(context, project_id)
+ allowed_instances = quota['instances'] - used_instances
+ allowed_cores = quota['cores'] - used_cores
+ type_cores = instance_types.INSTANCE_TYPES[instance_type]['vcpus']
+ num_cores = num_instances * type_cores
+ allowed_instances = min(allowed_instances,
+ int(allowed_cores // type_cores))
+ return min(num_instances, allowed_instances)
+
+
+def allowed_volumes(context, num_volumes, size):
+ """Check quota and return min(num_volumes, allowed_volumes)"""
+ project_id = context.project.id
+ used_volumes, used_gigabytes = db.volume_data_get_for_project(context,
+ project_id)
+ quota = get_quota(context, project_id)
+ allowed_volumes = quota['volumes'] - used_volumes
+ allowed_gigabytes = quota['gigabytes'] - used_gigabytes
+ size = int(size)
+ num_gigabytes = num_volumes * size
+ allowed_volumes = min(allowed_volumes,
+ int(allowed_gigabytes // size))
+ return min(num_volumes, allowed_volumes)
+
+
+def allowed_floating_ips(context, num_floating_ips):
+ """Check quota and return min(num_floating_ips, allowed_floating_ips)"""
+ project_id = context.project.id
+ used_floating_ips = db.floating_ip_count_by_project(context, project_id)
+ quota = get_quota(context, project_id)
+ allowed_floating_ips = quota['floating_ips'] - used_floating_ips
+ return min(num_floating_ips, allowed_floating_ips)
+
diff --git a/nova/rpc.py b/nova/rpc.py
index 84a9b5590..447ad3b93 100644
--- a/nova/rpc.py
+++ b/nova/rpc.py
@@ -28,6 +28,7 @@ import uuid
from carrot import connection as carrot_connection
from carrot import messaging
+from eventlet import greenthread
from twisted.internet import defer
from twisted.internet import task
@@ -46,9 +47,9 @@ LOG.setLevel(logging.DEBUG)
class Connection(carrot_connection.BrokerConnection):
"""Connection instance object"""
@classmethod
- def instance(cls):
+ def instance(cls, new=False):
"""Returns the instance"""
- if not hasattr(cls, '_instance'):
+ if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port,
userid=FLAGS.rabbit_userid,
@@ -60,7 +61,10 @@ class Connection(carrot_connection.BrokerConnection):
# NOTE(vish): magic is fun!
# pylint: disable-msg=W0142
- cls._instance = cls(**params)
+ if new:
+ return cls(**params)
+ else:
+ cls._instance = cls(**params)
return cls._instance
@classmethod
@@ -81,21 +85,6 @@ class Consumer(messaging.Consumer):
self.failed_connection = False
super(Consumer, self).__init__(*args, **kwargs)
- # TODO(termie): it would be nice to give these some way of automatically
- # cleaning up after themselves
- def attach_to_tornado(self, io_inst=None):
- """Attach a callback to tornado that fires 10 times a second"""
- from tornado import ioloop
- if io_inst is None:
- io_inst = ioloop.IOLoop.instance()
-
- injected = ioloop.PeriodicCallback(
- lambda: self.fetch(enable_callbacks=True), 100, io_loop=io_inst)
- injected.start()
- return injected
-
- attachToTornado = attach_to_tornado
-
def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
"""Wraps the parent fetch with some logic for failed connections"""
# TODO(vish): the logic for failed connections and logging should be
@@ -119,10 +108,19 @@ class Consumer(messaging.Consumer):
logging.exception("Failed to fetch message from queue")
self.failed_connection = True
+ def attach_to_eventlet(self):
+ """Only needed for unit tests!"""
+ def fetch_repeatedly():
+ while True:
+ self.fetch(enable_callbacks=True)
+ greenthread.sleep(0.1)
+ greenthread.spawn(fetch_repeatedly)
+
def attach_to_twisted(self):
"""Attach a callback to twisted that fires 10 times a second"""
loop = task.LoopingCall(self.fetch, enable_callbacks=True)
loop.start(interval=0.1)
+ return loop
class Publisher(messaging.Publisher):
@@ -265,6 +263,41 @@ def call(topic, msg):
msg.update({'_msg_id': msg_id})
LOG.debug("MSG_ID is %s" % (msg_id))
+ class WaitMessage(object):
+
+ def __call__(self, data, message):
+ """Acks message and sets result."""
+ message.ack()
+ if data['failure']:
+ self.result = RemoteError(*data['failure'])
+ else:
+ self.result = data['result']
+
+ wait_msg = WaitMessage()
+ conn = Connection.instance(True)
+ consumer = DirectConsumer(connection=conn, msg_id=msg_id)
+ consumer.register_callback(wait_msg)
+
+ conn = Connection.instance()
+ publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher.send(msg)
+ publisher.close()
+
+ try:
+ consumer.wait(limit=1)
+ except StopIteration:
+ pass
+ consumer.close()
+ return wait_msg.result
+
+
+def call_twisted(topic, msg):
+ """Sends a message on a topic and wait for a response"""
+ LOG.debug("Making asynchronous call...")
+ msg_id = uuid.uuid4().hex
+ msg.update({'_msg_id': msg_id})
+ LOG.debug("MSG_ID is %s" % (msg_id))
+
conn = Connection.instance()
d = defer.Deferred()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
@@ -278,7 +311,7 @@ def call(topic, msg):
return d.callback(data['result'])
consumer.register_callback(deferred_receive)
- injected = consumer.attach_to_tornado()
+ injected = consumer.attach_to_twisted()
# clean up after the injected listened and return x
d.addCallback(lambda x: injected.stop() and x or x)
diff --git a/nova/scheduler/__init__.py b/nova/scheduler/__init__.py
new file mode 100644
index 000000000..8359a7aeb
--- /dev/null
+++ b/nova/scheduler/__init__.py
@@ -0,0 +1,25 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2010 Openstack, LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+:mod:`nova.scheduler` -- Scheduler Nodes
+=====================================================
+
+.. automodule:: nova.scheduler
+ :platform: Unix
+ :synopsis: Module that picks a compute node to run a VM instance.
+.. moduleauthor:: Chris Behrens <cbehrens@codestud.com>
+"""
diff --git a/nova/network/exception.py b/nova/scheduler/chance.py
index 2a3f5ec14..7fd09b053 100644
--- a/nova/network/exception.py
+++ b/nova/scheduler/chance.py
@@ -1,5 +1,6 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
+# Copyright (c) 2010 Openstack, LLC.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
@@ -17,32 +18,21 @@
# under the License.
"""
-Exceptions for network errors.
+Chance (Random) Scheduler implementation
"""
-from nova import exception
+import random
+from nova.scheduler import driver
-class NoMoreAddresses(exception.Error):
- """No More Addresses are available in the network"""
- pass
+class ChanceScheduler(driver.Scheduler):
+ """Implements Scheduler as a random node selector."""
-class AddressNotAllocated(exception.Error):
- """The specified address has not been allocated"""
- pass
+ def schedule(self, context, topic, *_args, **_kwargs):
+ """Picks a host that is up at random."""
-
-class AddressAlreadyAssociated(exception.Error):
- """The specified address has already been associated"""
- pass
-
-
-class AddressNotAssociated(exception.Error):
- """The specified address is not associated"""
- pass
-
-
-class NotValidNetworkSize(exception.Error):
- """The network size is not valid"""
- pass
+ hosts = self.hosts_up(context, topic)
+ if not hosts:
+ raise driver.NoValidHost("No hosts found")
+ return hosts[int(random.random() * len(hosts))]
diff --git a/nova/scheduler/driver.py b/nova/scheduler/driver.py
new file mode 100644
index 000000000..c89d25a47
--- /dev/null
+++ b/nova/scheduler/driver.py
@@ -0,0 +1,59 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2010 Openstack, LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Scheduler base class that all Schedulers should inherit from
+"""
+
+import datetime
+
+from nova import db
+from nova import exception
+from nova import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer('service_down_time', 60,
+ 'maximum time since last checkin for up service')
+
+class NoValidHost(exception.Error):
+ """There is no valid host for the command."""
+ pass
+
+class Scheduler(object):
+ """The base class that all Scheduler clases should inherit from."""
+
+ @staticmethod
+ def service_is_up(service):
+ """Check whether a service is up based on last heartbeat."""
+ last_heartbeat = service['updated_at'] or service['created_at']
+ # Timestamps in DB are UTC.
+ elapsed = datetime.datetime.utcnow() - last_heartbeat
+ return elapsed < datetime.timedelta(seconds=FLAGS.service_down_time)
+
+ def hosts_up(self, context, topic):
+ """Return the list of hosts that have a running service for topic."""
+
+ services = db.service_get_all_by_topic(context, topic)
+ return [service.host
+ for service in services
+ if self.service_is_up(service)]
+
+ def schedule(self, context, topic, *_args, **_kwargs):
+ """Must override at least this method for scheduler to work."""
+ raise NotImplementedError("Must implement a fallback schedule")
diff --git a/nova/scheduler/manager.py b/nova/scheduler/manager.py
new file mode 100644
index 000000000..0ad7ca86b
--- /dev/null
+++ b/nova/scheduler/manager.py
@@ -0,0 +1,66 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2010 Openstack, LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Scheduler Service
+"""
+
+import logging
+import functools
+
+from nova import db
+from nova import flags
+from nova import manager
+from nova import rpc
+from nova import utils
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('scheduler_driver',
+ 'nova.scheduler.chance.ChanceScheduler',
+ 'Driver to use for the scheduler')
+
+
+class SchedulerManager(manager.Manager):
+ """Chooses a host to run instances on."""
+ def __init__(self, scheduler_driver=None, *args, **kwargs):
+ if not scheduler_driver:
+ scheduler_driver = FLAGS.scheduler_driver
+ self.driver = utils.import_object(scheduler_driver)
+ super(SchedulerManager, self).__init__(*args, **kwargs)
+
+ def __getattr__(self, key):
+ """Converts all method calls to use the schedule method"""
+ return functools.partial(self._schedule, key)
+
+ def _schedule(self, method, context, topic, *args, **kwargs):
+ """Tries to call schedule_* method on the driver to retrieve host.
+
+ Falls back to schedule(context, topic) if method doesn't exist.
+ """
+ driver_method = 'schedule_%s' % method
+ try:
+ host = getattr(self.driver, driver_method)(context, *args, **kwargs)
+ except AttributeError:
+ host = self.driver.schedule(context, topic, *args, **kwargs)
+
+ kwargs.update({"context": None})
+ rpc.cast(db.queue_get_for(context, topic, host),
+ {"method": method,
+ "args": kwargs})
+ logging.debug("Casting to %s %s for %s", topic, host, method)
diff --git a/nova/scheduler/simple.py b/nova/scheduler/simple.py
new file mode 100644
index 000000000..fdaff74d8
--- /dev/null
+++ b/nova/scheduler/simple.py
@@ -0,0 +1,90 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2010 Openstack, LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Simple Scheduler
+"""
+
+import datetime
+
+from nova import db
+from nova import flags
+from nova.scheduler import driver
+from nova.scheduler import chance
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("max_cores", 16,
+ "maximum number of instance cores to allow per host")
+flags.DEFINE_integer("max_gigabytes", 10000,
+ "maximum number of volume gigabytes to allow per host")
+flags.DEFINE_integer("max_networks", 1000,
+ "maximum number of networks to allow per host")
+
+class SimpleScheduler(chance.ChanceScheduler):
+ """Implements Naive Scheduler that tries to find least loaded host."""
+
+ def schedule_run_instance(self, context, instance_id, *_args, **_kwargs):
+ """Picks a host that is up and has the fewest running instances."""
+ instance_ref = db.instance_get(context, instance_id)
+ results = db.service_get_all_compute_sorted(context)
+ for result in results:
+ (service, instance_cores) = result
+ if instance_cores + instance_ref['vcpus'] > FLAGS.max_cores:
+ raise driver.NoValidHost("All hosts have too many cores")
+ if self.service_is_up(service):
+ # NOTE(vish): this probably belongs in the manager, if we
+ # can generalize this somehow
+ now = datetime.datetime.utcnow()
+ db.instance_update(context,
+ instance_id,
+ {'host': service['host'],
+ 'scheduled_at': now})
+ return service['host']
+ raise driver.NoValidHost("No hosts found")
+
+ def schedule_create_volume(self, context, volume_id, *_args, **_kwargs):
+ """Picks a host that is up and has the fewest volumes."""
+ volume_ref = db.volume_get(context, volume_id)
+ results = db.service_get_all_volume_sorted(context)
+ for result in results:
+ (service, volume_gigabytes) = result
+ if volume_gigabytes + volume_ref['size'] > FLAGS.max_gigabytes:
+ raise driver.NoValidHost("All hosts have too many gigabytes")
+ if self.service_is_up(service):
+ # NOTE(vish): this probably belongs in the manager, if we
+ # can generalize this somehow
+ now = datetime.datetime.utcnow()
+ db.volume_update(context,
+ volume_id,
+ {'host': service['host'],
+ 'scheduled_at': now})
+ return service['host']
+ raise driver.NoValidHost("No hosts found")
+
+ def schedule_set_network_host(self, context, *_args, **_kwargs):
+ """Picks a host that is up and has the fewest networks."""
+
+ results = db.service_get_all_network_sorted(context)
+ for result in results:
+ (service, instance_count) = result
+ if instance_count >= FLAGS.max_networks:
+ raise driver.NoValidHost("All hosts have too many networks")
+ if self.service_is_up(service):
+ return service['host']
+ raise driver.NoValidHost("No hosts found")
diff --git a/nova/server.py b/nova/server.py
index 96550f078..c58a15041 100644
--- a/nova/server.py
+++ b/nova/server.py
@@ -44,6 +44,8 @@ flags.DEFINE_bool('use_syslog', True, 'output to syslog when daemonizing')
flags.DEFINE_string('logfile', None, 'log file to output to')
flags.DEFINE_string('pidfile', None, 'pid file to output to')
flags.DEFINE_string('working_directory', './', 'working directory...')
+flags.DEFINE_integer('uid', os.getuid(), 'uid under which to run')
+flags.DEFINE_integer('gid', os.getgid(), 'gid under which to run')
def stop(pidfile):
@@ -58,7 +60,7 @@ def stop(pidfile):
sys.stderr.write(message % pidfile)
return # not an error in a restart
- # Try killing the daemon process
+ # Try killing the daemon process
try:
while 1:
os.kill(pid, signal.SIGTERM)
@@ -104,6 +106,7 @@ def serve(name, main):
def daemonize(args, name, main):
"""Does the work of daemonizing the process"""
logging.getLogger('amqplib').setLevel(logging.WARN)
+ files_to_keep = []
if FLAGS.daemonize:
logger = logging.getLogger()
formatter = logging.Formatter(
@@ -112,12 +115,14 @@ def daemonize(args, name, main):
syslog = logging.handlers.SysLogHandler(address='/dev/log')
syslog.setFormatter(formatter)
logger.addHandler(syslog)
+ files_to_keep.append(syslog.socket)
else:
if not FLAGS.logfile:
FLAGS.logfile = '%s.log' % name
logfile = logging.FileHandler(FLAGS.logfile)
logfile.setFormatter(formatter)
logger.addHandler(logfile)
+ files_to_keep.append(logfile.stream)
stdin, stdout, stderr = None, None, None
else:
stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr
@@ -135,6 +140,9 @@ def daemonize(args, name, main):
threaded=False),
stdin=stdin,
stdout=stdout,
- stderr=stderr
+ stderr=stderr,
+ uid=FLAGS.uid,
+ gid=FLAGS.gid,
+ files_preserve=files_to_keep
):
main(args)
diff --git a/nova/service.py b/nova/service.py
index 96281bc6b..115e0ff32 100644
--- a/nova/service.py
+++ b/nova/service.py
@@ -28,75 +28,155 @@ from twisted.internet import defer
from twisted.internet import task
from twisted.application import service
-from nova import datastore
+from nova import db
+from nova import exception
from nova import flags
from nova import rpc
-from nova.compute import model
+from nova import utils
FLAGS = flags.FLAGS
-
flags.DEFINE_integer('report_interval', 10,
- 'seconds between nodes reporting state to cloud',
+ 'seconds between nodes reporting state to datastore',
+ lower_bound=1)
+
+flags.DEFINE_integer('periodic_interval', 60,
+ 'seconds between running periodic tasks',
lower_bound=1)
+
class Service(object, service.Service):
- """Base class for workers that run on hosts"""
+ """Base class for workers that run on hosts."""
+
+ def __init__(self, host, binary, topic, manager, *args, **kwargs):
+ self.host = host
+ self.binary = binary
+ self.topic = topic
+ self.manager_class_name = manager
+ super(Service, self).__init__(*args, **kwargs)
+ self.saved_args, self.saved_kwargs = args, kwargs
+
+
+ def startService(self): # pylint: disable-msg C0103
+ manager_class = utils.import_class(self.manager_class_name)
+ self.manager = manager_class(host=self.host, *self.saved_args,
+ **self.saved_kwargs)
+ self.manager.init_host()
+ self.model_disconnected = False
+ try:
+ service_ref = db.service_get_by_args(None,
+ self.host,
+ self.binary)
+ self.service_id = service_ref['id']
+ except exception.NotFound:
+ self._create_service_ref()
+
+
+ def _create_service_ref(self):
+ service_ref = db.service_create(None, {'host': self.host,
+ 'binary': self.binary,
+ 'topic': self.topic,
+ 'report_count': 0})
+ self.service_id = service_ref['id']
+
+ def __getattr__(self, key):
+ try:
+ return super(Service, self).__getattr__(key)
+ except AttributeError:
+ return getattr(self.manager, key)
@classmethod
def create(cls,
- report_interval=None, # defaults to flag
- bin_name=None, # defaults to basename of executable
- topic=None): # defaults to basename - "nova-" part
- """Instantiates class and passes back application object"""
+ host=None,
+ binary=None,
+ topic=None,
+ manager=None,
+ report_interval=None,
+ periodic_interval=None):
+ """Instantiates class and passes back application object.
+
+ Args:
+ host, defaults to FLAGS.host
+ binary, defaults to basename of executable
+ topic, defaults to bin_name - "nova-" part
+ manager, defaults to FLAGS.<topic>_manager
+ report_interval, defaults to FLAGS.report_interval
+ periodic_interval, defaults to FLAGS.periodic_interval
+ """
+ if not host:
+ host = FLAGS.host
+ if not binary:
+ binary = os.path.basename(inspect.stack()[-1][1])
+ if not topic:
+ topic = binary.rpartition("nova-")[2]
+ if not manager:
+ manager = FLAGS.get('%s_manager' % topic, None)
if not report_interval:
- # NOTE(vish): set here because if it is set to flag in the
- # parameter list, it wrongly uses the default
report_interval = FLAGS.report_interval
- # NOTE(vish): magic to automatically determine bin_name and topic
- if not bin_name:
- bin_name = os.path.basename(inspect.stack()[-1][1])
- if not topic:
- topic = bin_name.rpartition("nova-")[2]
- logging.warn("Starting %s node" % topic)
- node_instance = cls()
-
+ if not periodic_interval:
+ periodic_interval = FLAGS.periodic_interval
+ logging.warn("Starting %s node", topic)
+ service_obj = cls(host, binary, topic, manager)
conn = rpc.Connection.instance()
consumer_all = rpc.AdapterConsumer(
connection=conn,
- topic='%s' % topic,
- proxy=node_instance)
-
+ topic=topic,
+ proxy=service_obj)
consumer_node = rpc.AdapterConsumer(
connection=conn,
- topic='%s.%s' % (topic, FLAGS.node_name),
- proxy=node_instance)
-
- pulse = task.LoopingCall(node_instance.report_state,
- FLAGS.node_name,
- bin_name)
- pulse.start(interval=report_interval, now=False)
+ topic='%s.%s' % (topic, host),
+ proxy=service_obj)
consumer_all.attach_to_twisted()
consumer_node.attach_to_twisted()
+ pulse = task.LoopingCall(service_obj.report_state)
+ pulse.start(interval=report_interval, now=False)
+
+ pulse = task.LoopingCall(service_obj.periodic_tasks)
+ pulse.start(interval=periodic_interval, now=False)
+
# This is the parent service that twistd will be looking for when it
- # parses this file, return it so that we can get it into globals below
- application = service.Application(bin_name)
- node_instance.setServiceParent(application)
+ # parses this file, return it so that we can get it into globals.
+ application = service.Application(binary)
+ service_obj.setServiceParent(application)
return application
+ def kill(self, context=None):
+ """Destroy the service object in the datastore"""
+ try:
+ db.service_destroy(context, self.service_id)
+ except exception.NotFound:
+ logging.warn("Service killed that has no database entry")
+
+ @defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield self.manager.periodic_tasks(context)
+
@defer.inlineCallbacks
- def report_state(self, nodename, daemon):
- # TODO(termie): make this pattern be more elegant. -todd
+ def report_state(self, context=None):
+ """Update the state of this service in the datastore."""
try:
- record = model.Daemon(nodename, daemon)
- record.heartbeat()
+ try:
+ service_ref = db.service_get(context, self.service_id)
+ except exception.NotFound:
+ logging.debug("The service database object disappeared, "
+ "Recreating it.")
+ self._create_service_ref()
+ service_ref = db.service_get(context, self.service_id)
+
+ db.service_update(context,
+ self.service_id,
+ {'report_count': service_ref['report_count'] + 1})
+
+ # TODO(termie): make this pattern be more elegant.
if getattr(self, "model_disconnected", False):
self.model_disconnected = False
logging.error("Recovered model server connection!")
- except datastore.ConnectionError, ex:
+ # TODO(vish): this should probably only catch connection errors
+ except Exception: # pylint: disable-msg=W0702
if not getattr(self, "model_disconnected", False):
self.model_disconnected = True
logging.exception("model server went away")
diff --git a/nova/test.py b/nova/test.py
index c392c8a84..f6485377d 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -24,6 +24,7 @@ and some black magic for inline callbacks.
import sys
import time
+import datetime
import mox
import stubout
@@ -31,8 +32,11 @@ from tornado import ioloop
from twisted.internet import defer
from twisted.trial import unittest
+from nova import db
from nova import fakerabbit
from nova import flags
+from nova import rpc
+from nova.network import manager as network_manager
FLAGS = flags.FLAGS
@@ -56,24 +60,47 @@ class TrialTestCase(unittest.TestCase):
def setUp(self): # pylint: disable-msg=C0103
"""Run before each test method to initialize test environment"""
super(TrialTestCase, self).setUp()
+ # NOTE(vish): We need a better method for creating fixtures for tests
+ # now that we have some required db setup for the system
+ # to work properly.
+ self.start = datetime.datetime.utcnow()
+ if db.network_count(None) != 5:
+ network_manager.VlanManager().create_networks(None,
+ FLAGS.fixed_range,
+ 5, 16,
+ FLAGS.vlan_start,
+ FLAGS.vpn_start)
# emulate some of the mox stuff, we can't use the metaclass
# because it screws with our generators
self.mox = mox.Mox()
self.stubs = stubout.StubOutForTesting()
self.flag_overrides = {}
+ self.injected = []
+ self._monkeyPatchAttach()
def tearDown(self): # pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
- super(TrialTestCase, self).tearDown()
self.reset_flags()
self.mox.UnsetStubs()
self.stubs.UnsetAll()
self.stubs.SmartUnsetAll()
self.mox.VerifyAll()
+ # NOTE(vish): Clean up any ips associated during the test.
+ db.fixed_ip_disassociate_all_by_timeout(None, FLAGS.host, self.start)
+ db.network_disassociate_all(None)
+ rpc.Consumer.attach_to_twisted = self.originalAttach
+ for x in self.injected:
+ try:
+ x.stop()
+ except AssertionError:
+ pass
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
+ db.security_group_destroy_all(None)
+
+ super(TrialTestCase, self).tearDown()
def flags(self, **kw):
"""Override flag variables for a test"""
@@ -90,16 +117,51 @@ class TrialTestCase(unittest.TestCase):
for k, v in self.flag_overrides.iteritems():
setattr(FLAGS, k, v)
+ def run(self, result=None):
+ test_method = getattr(self, self._testMethodName)
+ setattr(self,
+ self._testMethodName,
+ self._maybeInlineCallbacks(test_method, result))
+ rv = super(TrialTestCase, self).run(result)
+ setattr(self, self._testMethodName, test_method)
+ return rv
+
+ def _maybeInlineCallbacks(self, func, result):
+ def _wrapped():
+ g = func()
+ if isinstance(g, defer.Deferred):
+ return g
+ if not hasattr(g, 'send'):
+ return defer.succeed(g)
+
+ inlined = defer.inlineCallbacks(func)
+ d = inlined()
+ return d
+ _wrapped.func_name = func.func_name
+ return _wrapped
+
+ def _monkeyPatchAttach(self):
+ self.originalAttach = rpc.Consumer.attach_to_twisted
+ def _wrapped(innerSelf):
+ rv = self.originalAttach(innerSelf)
+ self.injected.append(rv)
+ return rv
+
+ _wrapped.func_name = self.originalAttach.func_name
+ rpc.Consumer.attach_to_twisted = _wrapped
+
class BaseTestCase(TrialTestCase):
# TODO(jaypipes): Can this be moved into the TrialTestCase class?
- """Base test case class for all unit tests."""
+ """Base test case class for all unit tests.
+
+ DEPRECATED: This is being removed once Tornado is gone, use TrialTestCase.
+ """
def setUp(self): # pylint: disable-msg=C0103
"""Run before each test method to initialize test environment"""
super(BaseTestCase, self).setUp()
# TODO(termie): we could possibly keep a more global registry of
# the injected listeners... this is fine for now though
- self.injected = []
self.ioloop = ioloop.IOLoop.instance()
self._waiting = None
@@ -109,8 +171,6 @@ class BaseTestCase(TrialTestCase):
def tearDown(self):# pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
super(BaseTestCase, self).tearDown()
- for x in self.injected:
- x.stop()
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
diff --git a/nova/tests/access_unittest.py b/nova/tests/access_unittest.py
index fa0a090a0..4b40ffd0a 100644
--- a/nova/tests/access_unittest.py
+++ b/nova/tests/access_unittest.py
@@ -18,23 +18,22 @@
import unittest
import logging
+import webob
from nova import exception
from nova import flags
from nova import test
+from nova.api import ec2
from nova.auth import manager
-from nova.auth import rbac
FLAGS = flags.FLAGS
class Context(object):
pass
-class AccessTestCase(test.BaseTestCase):
+class AccessTestCase(test.TrialTestCase):
def setUp(self):
super(AccessTestCase, self).setUp()
- FLAGS.connection_type = 'fake'
- FLAGS.fake_storage = True
um = manager.AuthManager()
# Make test users
try:
@@ -74,9 +73,17 @@ class AccessTestCase(test.BaseTestCase):
try:
self.project.add_role(self.testsys, 'sysadmin')
except: pass
- self.context = Context()
- self.context.project = self.project
#user is set in each test
+ def noopWSGIApp(environ, start_response):
+ start_response('200 OK', [])
+ return ['']
+ self.mw = ec2.Authorizer(noopWSGIApp)
+ self.mw.action_roles = {'str': {
+ '_allow_all': ['all'],
+ '_allow_none': [],
+ '_allow_project_manager': ['projectmanager'],
+ '_allow_sys_and_net': ['sysadmin', 'netadmin'],
+ '_allow_sysadmin': ['sysadmin']}}
def tearDown(self):
um = manager.AuthManager()
@@ -89,76 +96,46 @@ class AccessTestCase(test.BaseTestCase):
um.delete_user('testsys')
super(AccessTestCase, self).tearDown()
+ def response_status(self, user, methodName):
+ context = Context()
+ context.project = self.project
+ context.user = user
+ environ = {'ec2.context' : context,
+ 'ec2.controller': 'some string',
+ 'ec2.action': methodName}
+ req = webob.Request.blank('/', environ)
+ resp = req.get_response(self.mw)
+ return resp.status_int
+
+ def shouldAllow(self, user, methodName):
+ self.assertEqual(200, self.response_status(user, methodName))
+
+ def shouldDeny(self, user, methodName):
+ self.assertEqual(401, self.response_status(user, methodName))
+
def test_001_allow_all(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testpmsys
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testnet
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testsys
- self.assertTrue(self._allow_all(self.context))
+ users = [self.testadmin, self.testpmsys, self.testnet, self.testsys]
+ for user in users:
+ self.shouldAllow(user, '_allow_all')
def test_002_allow_none(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_none(self.context))
- self.context.user = self.testpmsys
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
- self.context.user = self.testsys
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
+ self.shouldAllow(self.testadmin, '_allow_none')
+ users = [self.testpmsys, self.testnet, self.testsys]
+ for user in users:
+ self.shouldDeny(user, '_allow_none')
def test_003_allow_project_manager(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_project_manager(self.context))
- self.context.user = self.testpmsys
- self.assertTrue(self._allow_project_manager(self.context))
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
- self.context.user = self.testsys
- self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
+ for user in [self.testadmin, self.testpmsys]:
+ self.shouldAllow(user, '_allow_project_manager')
+ for user in [self.testnet, self.testsys]:
+ self.shouldDeny(user, '_allow_project_manager')
def test_004_allow_sys_and_net(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_sys_and_net(self.context))
- self.context.user = self.testpmsys # doesn't have the per project sysadmin
- self.assertRaises(exception.NotAuthorized, self._allow_sys_and_net, self.context)
- self.context.user = self.testnet
- self.assertTrue(self._allow_sys_and_net(self.context))
- self.context.user = self.testsys
- self.assertTrue(self._allow_sys_and_net(self.context))
-
- def test_005_allow_sys_no_pm(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_sys_no_pm(self.context))
- self.context.user = self.testpmsys
- self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
- self.context.user = self.testsys
- self.assertTrue(self._allow_sys_no_pm(self.context))
-
- @rbac.allow('all')
- def _allow_all(self, context):
- return True
-
- @rbac.allow('none')
- def _allow_none(self, context):
- return True
-
- @rbac.allow('projectmanager')
- def _allow_project_manager(self, context):
- return True
-
- @rbac.allow('sysadmin', 'netadmin')
- def _allow_sys_and_net(self, context):
- return True
-
- @rbac.allow('sysadmin')
- @rbac.deny('projectmanager')
- def _allow_sys_no_pm(self, context):
- return True
+ for user in [self.testadmin, self.testnet, self.testsys]:
+ self.shouldAllow(user, '_allow_sys_and_net')
+ # denied because it doesn't have the per project sysadmin
+ for user in [self.testpmsys]:
+ self.shouldDeny(user, '_allow_sys_and_net')
if __name__ == "__main__":
# TODO: Implement use_fake as an option
diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py
new file mode 100644
index 000000000..f051e2390
--- /dev/null
+++ b/nova/tests/api/__init__.py
@@ -0,0 +1,83 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Test for the root WSGI middleware for all API controllers.
+"""
+
+import unittest
+
+import stubout
+import webob
+import webob.dec
+
+import nova.exception
+from nova import api
+from nova.tests.api.fakes import APIStub
+
+
+class Test(unittest.TestCase):
+
+ def setUp(self): # pylint: disable-msg=C0103
+ self.stubs = stubout.StubOutForTesting()
+
+ def tearDown(self): # pylint: disable-msg=C0103
+ self.stubs.UnsetAll()
+
+ def _request(self, url, subdomain, **kwargs):
+ environ_keys = {'HTTP_HOST': '%s.example.com' % subdomain}
+ environ_keys.update(kwargs)
+ req = webob.Request.blank(url, environ_keys)
+ return req.get_response(api.API())
+
+ def test_openstack(self):
+ self.stubs.Set(api.openstack, 'API', APIStub)
+ result = self._request('/v1.0/cloud', 'api')
+ self.assertEqual(result.body, "/cloud")
+
+ def test_ec2(self):
+ self.stubs.Set(api.ec2, 'API', APIStub)
+ result = self._request('/services/cloud', 'ec2')
+ self.assertEqual(result.body, "/cloud")
+
+ def test_not_found(self):
+ self.stubs.Set(api.ec2, 'API', APIStub)
+ self.stubs.Set(api.openstack, 'API', APIStub)
+ result = self._request('/test/cloud', 'ec2')
+ self.assertNotEqual(result.body, "/cloud")
+
+ def test_query_api_versions(self):
+ result = self._request('/', 'api')
+ self.assertTrue('CURRENT' in result.body)
+
+ def test_metadata(self):
+ def go(url):
+ result = self._request(url, 'ec2',
+ REMOTE_ADDR='128.192.151.2')
+ # Each should get to the ORM layer and fail to find the IP
+ self.assertRaises(nova.exception.NotFound, go, '/latest/')
+ self.assertRaises(nova.exception.NotFound, go, '/2009-04-04/')
+ self.assertRaises(nova.exception.NotFound, go, '/1.0/')
+
+ def test_ec2_root(self):
+ result = self._request('/', 'ec2')
+ self.assertTrue('2007-12-15\n' in result.body)
+
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/fakes.py b/nova/tests/api/fakes.py
new file mode 100644
index 000000000..d0a2cc027
--- /dev/null
+++ b/nova/tests/api/fakes.py
@@ -0,0 +1,8 @@
+import webob.dec
+from nova import wsgi
+
+class APIStub(object):
+ """Class to verify request and mark it was called."""
+ @webob.dec.wsgify
+ def __call__(self, req):
+ return req.path_info
diff --git a/nova/tests/api/openstack/__init__.py b/nova/tests/api/openstack/__init__.py
new file mode 100644
index 000000000..b534897f5
--- /dev/null
+++ b/nova/tests/api/openstack/__init__.py
@@ -0,0 +1,108 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from nova.api.openstack import limited
+from nova.api.openstack import RateLimitingMiddleware
+from nova.tests.api.fakes import APIStub
+from webob import Request
+
+
+class RateLimitingMiddlewareTest(unittest.TestCase):
+
+ def test_get_action_name(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ def verify(method, url, action_name):
+ req = Request.blank(url)
+ req.method = method
+ action = middleware.get_action_name(req)
+ self.assertEqual(action, action_name)
+ verify('PUT', '/servers/4', 'PUT')
+ verify('DELETE', '/servers/4', 'DELETE')
+ verify('POST', '/images/4', 'POST')
+ verify('POST', '/servers/4', 'POST servers')
+ verify('GET', '/foo?a=4&changes-since=never&b=5', 'GET changes-since')
+ verify('GET', '/foo?a=4&monkeys-since=never&b=5', None)
+ verify('GET', '/servers/4', None)
+ verify('HEAD', '/servers/4', None)
+
+ def exhaust(self, middleware, method, url, username, times):
+ req = Request.blank(url, dict(REQUEST_METHOD=method),
+ headers={'X-Auth-User': username})
+ for i in range(times):
+ resp = req.get_response(middleware)
+ self.assertEqual(resp.status_int, 200)
+ resp = req.get_response(middleware)
+ self.assertEqual(resp.status_int, 413)
+ self.assertTrue('Retry-After' in resp.headers)
+
+ def test_single_action(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.exhaust(middleware, 'DELETE', '/servers/4', 'usr1', 100)
+ self.exhaust(middleware, 'DELETE', '/servers/4', 'usr2', 100)
+
+ def test_POST_servers_action_implies_POST_action(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10)
+ self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10)
+ self.assertTrue(set(middleware.limiter._levels) ==
+ set(['usr1:POST', 'usr1:POST servers', 'usr2:POST']))
+
+ def test_POST_servers_action_correctly_ratelimited(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ # Use up all of our "POST" allowance for the minute, 5 times
+ for i in range(5):
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10)
+ # Reset the 'POST' action counter.
+ del middleware.limiter._levels['usr1:POST']
+ # All 50 daily "POST servers" actions should be all used up
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 0)
+
+ def test_proxy_ctor_works(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.assertEqual(middleware.limiter.__class__.__name__, "Limiter")
+ middleware = RateLimitingMiddleware(APIStub(), service_host='foobar')
+ self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy")
+
+
+class LimiterTest(unittest.TestCase):
+
+ def test_limiter(self):
+ items = range(2000)
+ req = Request.blank('/')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=0')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=3')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=2005')
+ self.assertEqual(limited(items, req), [])
+ req = Request.blank('/?limit=10')
+ self.assertEqual(limited(items, req), items[ :10])
+ req = Request.blank('/?limit=0')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?limit=3000')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=1&limit=3')
+ self.assertEqual(limited(items, req), items[1:4])
+ req = Request.blank('/?offset=3&limit=0')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=3&limit=1500')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=3000&limit=10')
+ self.assertEqual(limited(items, req), [])
diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py
new file mode 100644
index 000000000..71da2fd21
--- /dev/null
+++ b/nova/tests/api/openstack/fakes.py
@@ -0,0 +1,210 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import json
+import random
+import string
+
+import webob
+import webob.dec
+
+from nova import auth
+from nova import utils
+from nova import flags
+from nova import exception as exc
+import nova.api.openstack.auth
+from nova.image import service
+from nova.wsgi import Router
+
+
+FLAGS = flags.FLAGS
+
+
+class Context(object):
+ pass
+
+
+class FakeRouter(Router):
+ def __init__(self):
+ pass
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ res = webob.Response()
+ res.status = '200'
+ res.headers['X-Test-Success'] = 'True'
+ return res
+
+
+def fake_auth_init(self):
+ self.db = FakeAuthDatabase()
+ self.context = Context()
+ self.auth = FakeAuthManager()
+ self.host = 'foo'
+
+
+@webob.dec.wsgify
+def fake_wsgi(self, req):
+ req.environ['nova.context'] = dict(user=dict(id=1))
+ if req.body:
+ req.environ['inst_dict'] = json.loads(req.body)
+ return self.application
+
+
+def stub_out_key_pair_funcs(stubs):
+ def key_pair(context, user_id):
+ return [dict(name='key', public_key='public_key')]
+ stubs.Set(nova.db.api, 'key_pair_get_all_by_user',
+ key_pair)
+
+
+def stub_out_image_service(stubs):
+ def fake_image_show(meh, id):
+ return dict(kernelId=1, ramdiskId=1)
+
+ stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show)
+
+def stub_out_auth(stubs):
+ def fake_auth_init(self, app):
+ self.application = app
+
+ stubs.Set(nova.api.openstack.AuthMiddleware,
+ '__init__', fake_auth_init)
+ stubs.Set(nova.api.openstack.AuthMiddleware,
+ '__call__', fake_wsgi)
+
+
+def stub_out_rate_limiting(stubs):
+ def fake_rate_init(self, app):
+ super(nova.api.openstack.RateLimitingMiddleware, self).__init__(app)
+ self.application = app
+
+ stubs.Set(nova.api.openstack.RateLimitingMiddleware,
+ '__init__', fake_rate_init)
+
+ stubs.Set(nova.api.openstack.RateLimitingMiddleware,
+ '__call__', fake_wsgi)
+
+
+def stub_out_networking(stubs):
+ def get_my_ip():
+ return '127.0.0.1'
+ stubs.Set(nova.utils, 'get_my_ip', get_my_ip)
+ FLAGS.FAKE_subdomain = 'api'
+
+
+def stub_out_glance(stubs):
+
+ class FakeParallaxClient:
+
+ def __init__(self):
+ self.fixtures = {}
+
+ def fake_get_images(self):
+ return self.fixtures
+
+ def fake_get_image_metadata(self, image_id):
+ for k, f in self.fixtures.iteritems():
+ if k == image_id:
+ return f
+ return None
+
+ def fake_add_image_metadata(self, image_data):
+ id = ''.join(random.choice(string.letters) for _ in range(20))
+ image_data['id'] = id
+ self.fixtures[id] = image_data
+ return id
+
+ def fake_update_image_metadata(self, image_id, image_data):
+
+ if image_id not in self.fixtures.keys():
+ raise exc.NotFound
+
+ self.fixtures[image_id].update(image_data)
+
+ def fake_delete_image_metadata(self, image_id):
+
+ if image_id not in self.fixtures.keys():
+ raise exc.NotFound
+
+ del self.fixtures[image_id]
+
+ def fake_delete_all(self):
+ self.fixtures = {}
+
+ fake_parallax_client = FakeParallaxClient()
+ stubs.Set(nova.image.service.ParallaxClient, 'get_images',
+ fake_parallax_client.fake_get_images)
+ stubs.Set(nova.image.service.ParallaxClient, 'get_image_metadata',
+ fake_parallax_client.fake_get_image_metadata)
+ stubs.Set(nova.image.service.ParallaxClient, 'add_image_metadata',
+ fake_parallax_client.fake_add_image_metadata)
+ stubs.Set(nova.image.service.ParallaxClient, 'update_image_metadata',
+ fake_parallax_client.fake_update_image_metadata)
+ stubs.Set(nova.image.service.ParallaxClient, 'delete_image_metadata',
+ fake_parallax_client.fake_delete_image_metadata)
+ stubs.Set(nova.image.service.GlanceImageService, 'delete_all',
+ fake_parallax_client.fake_delete_all)
+
+class FakeToken(object):
+ def __init__(self, **kwargs):
+ for k,v in kwargs.iteritems():
+ setattr(self, k, v)
+
+class FakeAuthDatabase(object):
+ data = {}
+
+ @staticmethod
+ def auth_get_token(context, token_hash):
+ return FakeAuthDatabase.data.get(token_hash, None)
+
+ @staticmethod
+ def auth_create_token(context, token):
+ fake_token = FakeToken(created_at=datetime.datetime.now(), **token)
+ FakeAuthDatabase.data[fake_token.token_hash] = fake_token
+ return fake_token
+
+ @staticmethod
+ def auth_destroy_token(context, token):
+ if token.token_hash in FakeAuthDatabase.data:
+ del FakeAuthDatabase.data['token_hash']
+
+
+class FakeAuthManager(object):
+ auth_data = {}
+
+ def add_user(self, key, user):
+ FakeAuthManager.auth_data[key] = user
+
+ def get_user(self, uid):
+ for k, v in FakeAuthManager.auth_data.iteritems():
+ if v.id == uid:
+ return v
+ return None
+
+ def get_user_from_access_key(self, key):
+ return FakeAuthManager.auth_data.get(key, None)
+
+
+class FakeRateLimiter(object):
+ def __init__(self, application):
+ self.application = application
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ return self.application
diff --git a/nova/tests/api/openstack/test_auth.py b/nova/tests/api/openstack/test_auth.py
new file mode 100644
index 000000000..bbfb0fcea
--- /dev/null
+++ b/nova/tests/api/openstack/test_auth.py
@@ -0,0 +1,110 @@
+import datetime
+import unittest
+
+import stubout
+import webob
+import webob.dec
+
+import nova.api
+import nova.api.openstack.auth
+import nova.auth.manager
+from nova import auth
+from nova.tests.api.openstack import fakes
+
+class Test(unittest.TestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ self.stubs.Set(nova.api.openstack.auth.BasicApiAuthManager,
+ '__init__', fakes.fake_auth_init)
+ fakes.FakeAuthManager.auth_data = {}
+ fakes.FakeAuthDatabase.data = {}
+ fakes.stub_out_rate_limiting(self.stubs)
+ fakes.stub_out_networking(self.stubs)
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ fakes.fake_data_store = {}
+
+ def test_authorize_user(self):
+ f = fakes.FakeAuthManager()
+ f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None))
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '204 No Content')
+ self.assertEqual(len(result.headers['X-Auth-Token']), 40)
+ self.assertEqual(result.headers['X-CDN-Management-Url'],
+ "")
+ self.assertEqual(result.headers['X-Storage-Url'], "")
+
+ def test_authorize_token(self):
+ f = fakes.FakeAuthManager()
+ f.add_user('derp', nova.auth.manager.User(1, 'herp', None, None, None))
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '204 No Content')
+ self.assertEqual(len(result.headers['X-Auth-Token']), 40)
+ self.assertEqual(result.headers['X-Server-Management-Url'],
+ "https://foo/v1.0/")
+ self.assertEqual(result.headers['X-CDN-Management-Url'],
+ "")
+ self.assertEqual(result.headers['X-Storage-Url'], "")
+
+ token = result.headers['X-Auth-Token']
+ self.stubs.Set(nova.api.openstack, 'APIRouter',
+ fakes.FakeRouter)
+ req = webob.Request.blank('/v1.0/fake')
+ req.headers['X-Auth-Token'] = token
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '200 OK')
+ self.assertEqual(result.headers['X-Test-Success'], 'True')
+
+ def test_token_expiry(self):
+ self.destroy_called = False
+ token_hash = 'bacon'
+
+ def destroy_token_mock(meh, context, token):
+ self.destroy_called = True
+
+ def bad_token(meh, context, token_hash):
+ return fakes.FakeToken(
+ token_hash=token_hash,
+ created_at=datetime.datetime(1990, 1, 1))
+
+ self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token',
+ destroy_token_mock)
+
+ self.stubs.Set(fakes.FakeAuthDatabase, 'auth_get_token',
+ bad_token)
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-Token'] = 'bacon'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+ self.assertEqual(self.destroy_called, True)
+
+ def test_bad_user(self):
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+ def test_no_user(self):
+ req = webob.Request.blank('/v1.0/')
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+ def test_bad_token(self):
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-Token'] = 'baconbaconbacon'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/openstack/test_faults.py b/nova/tests/api/openstack/test_faults.py
new file mode 100644
index 000000000..70a811469
--- /dev/null
+++ b/nova/tests/api/openstack/test_faults.py
@@ -0,0 +1,40 @@
+import unittest
+import webob
+import webob.dec
+import webob.exc
+
+from nova.api.openstack import faults
+
+class TestFaults(unittest.TestCase):
+
+ def test_fault_parts(self):
+ req = webob.Request.blank('/.xml')
+ f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram'))
+ resp = req.get_response(f)
+
+ first_two_words = resp.body.strip().split()[:2]
+ self.assertEqual(first_two_words, ['<badRequest', 'code="400">'])
+ body_without_spaces = ''.join(resp.body.split())
+ self.assertTrue('<message>scram</message>' in body_without_spaces)
+
+ def test_retry_header(self):
+ req = webob.Request.blank('/.xml')
+ exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry',
+ headers={'Retry-After': 4})
+ f = faults.Fault(exc)
+ resp = req.get_response(f)
+ first_two_words = resp.body.strip().split()[:2]
+ self.assertEqual(first_two_words, ['<overLimit', 'code="413">'])
+ body_sans_spaces = ''.join(resp.body.split())
+ self.assertTrue('<message>sorry</message>' in body_sans_spaces)
+ self.assertTrue('<retryAfter>4</retryAfter>' in body_sans_spaces)
+ self.assertEqual(resp.headers['Retry-After'], 4)
+
+ def test_raise(self):
+ @webob.dec.wsgify
+ def raiser(req):
+ raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?'))
+ req = webob.Request.blank('/.xml')
+ resp = req.get_response(raiser)
+ self.assertEqual(resp.status_int, 404)
+ self.assertTrue('whut?' in resp.body)
diff --git a/nova/tests/api/openstack/test_flavors.py b/nova/tests/api/openstack/test_flavors.py
new file mode 100644
index 000000000..8dd4d1f29
--- /dev/null
+++ b/nova/tests/api/openstack/test_flavors.py
@@ -0,0 +1,48 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import stubout
+import webob
+
+import nova.api
+from nova.api.openstack import flavors
+from nova.tests.api.openstack import fakes
+
+
+class FlavorsTest(unittest.TestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ fakes.FakeAuthManager.auth_data = {}
+ fakes.FakeAuthDatabase.data = {}
+ fakes.stub_out_networking(self.stubs)
+ fakes.stub_out_rate_limiting(self.stubs)
+ fakes.stub_out_auth(self.stubs)
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+
+ def test_get_flavor_list(self):
+ req = webob.Request.blank('/v1.0/flavors')
+ res = req.get_response(nova.api.API())
+
+ def test_get_flavor_by_id(self):
+ pass
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/openstack/test_images.py b/nova/tests/api/openstack/test_images.py
new file mode 100644
index 000000000..505fea3e2
--- /dev/null
+++ b/nova/tests/api/openstack/test_images.py
@@ -0,0 +1,141 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import logging
+import unittest
+
+import stubout
+
+from nova import exception
+from nova import utils
+from nova.api.openstack import images
+from nova.tests.api.openstack import fakes
+
+
+class BaseImageServiceTests():
+
+ """Tasks to test for all image services"""
+
+ def test_create(self):
+
+ fixture = {'name': 'test image',
+ 'updated': None,
+ 'created': None,
+ 'status': None,
+ 'serverId': None,
+ 'progress': None}
+
+ num_images = len(self.service.index())
+
+ id = self.service.create(fixture)
+
+ self.assertNotEquals(None, id)
+ self.assertEquals(num_images + 1, len(self.service.index()))
+
+ def test_create_and_show_non_existing_image(self):
+
+ fixture = {'name': 'test image',
+ 'updated': None,
+ 'created': None,
+ 'status': None,
+ 'serverId': None,
+ 'progress': None}
+
+ num_images = len(self.service.index())
+
+ id = self.service.create(fixture)
+
+ self.assertNotEquals(None, id)
+
+ self.assertRaises(exception.NotFound,
+ self.service.show,
+ 'bad image id')
+
+ def test_update(self):
+
+ fixture = {'name': 'test image',
+ 'updated': None,
+ 'created': None,
+ 'status': None,
+ 'serverId': None,
+ 'progress': None}
+
+ id = self.service.create(fixture)
+
+ fixture['status'] = 'in progress'
+
+ self.service.update(id, fixture)
+ new_image_data = self.service.show(id)
+ self.assertEquals('in progress', new_image_data['status'])
+
+ def test_delete(self):
+
+ fixtures = [
+ {'name': 'test image 1',
+ 'updated': None,
+ 'created': None,
+ 'status': None,
+ 'serverId': None,
+ 'progress': None},
+ {'name': 'test image 2',
+ 'updated': None,
+ 'created': None,
+ 'status': None,
+ 'serverId': None,
+ 'progress': None}]
+
+ ids = []
+ for fixture in fixtures:
+ new_id = self.service.create(fixture)
+ ids.append(new_id)
+
+ num_images = len(self.service.index())
+ self.assertEquals(2, num_images)
+
+ self.service.delete(ids[0])
+
+ num_images = len(self.service.index())
+ self.assertEquals(1, num_images)
+
+
+class LocalImageServiceTest(unittest.TestCase,
+ BaseImageServiceTests):
+
+ """Tests the local image service"""
+
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ self.service = utils.import_object('nova.image.service.LocalImageService')
+
+ def tearDown(self):
+ self.service.delete_all()
+ self.stubs.UnsetAll()
+
+
+class GlanceImageServiceTest(unittest.TestCase,
+ BaseImageServiceTests):
+
+ """Tests the local image service"""
+
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ fakes.stub_out_glance(self.stubs)
+ self.service = utils.import_object('nova.image.service.GlanceImageService')
+
+ def tearDown(self):
+ self.service.delete_all()
+ self.stubs.UnsetAll()
diff --git a/nova/tests/api/openstack/test_ratelimiting.py b/nova/tests/api/openstack/test_ratelimiting.py
new file mode 100644
index 000000000..ad9e67454
--- /dev/null
+++ b/nova/tests/api/openstack/test_ratelimiting.py
@@ -0,0 +1,237 @@
+import httplib
+import StringIO
+import time
+import unittest
+import webob
+
+import nova.api.openstack.ratelimiting as ratelimiting
+
+class LimiterTest(unittest.TestCase):
+
+ def setUp(self):
+ self.limits = {
+ 'a': (5, ratelimiting.PER_SECOND),
+ 'b': (5, ratelimiting.PER_MINUTE),
+ 'c': (5, ratelimiting.PER_HOUR),
+ 'd': (1, ratelimiting.PER_SECOND),
+ 'e': (100, ratelimiting.PER_SECOND)}
+ self.rl = ratelimiting.Limiter(self.limits)
+
+ def exhaust(self, action, times_until_exhausted, **kwargs):
+ for i in range(times_until_exhausted):
+ when = self.rl.perform(action, **kwargs)
+ self.assertEqual(when, None)
+ num, period = self.limits[action]
+ delay = period * 1.0 / num
+ # Verify that we are now thoroughly delayed
+ for i in range(10):
+ when = self.rl.perform(action, **kwargs)
+ self.assertAlmostEqual(when, delay, 2)
+
+ def test_second(self):
+ self.exhaust('a', 5)
+ time.sleep(0.2)
+ self.exhaust('a', 1)
+ time.sleep(1)
+ self.exhaust('a', 5)
+
+ def test_minute(self):
+ self.exhaust('b', 5)
+
+ def test_one_per_period(self):
+ def allow_once_and_deny_once():
+ when = self.rl.perform('d')
+ self.assertEqual(when, None)
+ when = self.rl.perform('d')
+ self.assertAlmostEqual(when, 1, 2)
+ return when
+ time.sleep(allow_once_and_deny_once())
+ time.sleep(allow_once_and_deny_once())
+ allow_once_and_deny_once()
+
+ def test_we_can_go_indefinitely_if_we_spread_out_requests(self):
+ for i in range(200):
+ when = self.rl.perform('e')
+ self.assertEqual(when, None)
+ time.sleep(0.01)
+
+ def test_users_get_separate_buckets(self):
+ self.exhaust('c', 5, username='alice')
+ self.exhaust('c', 5, username='bob')
+ self.exhaust('c', 5, username='chuck')
+ self.exhaust('c', 0, username='chuck')
+ self.exhaust('c', 0, username='bob')
+ self.exhaust('c', 0, username='alice')
+
+
+class FakeLimiter(object):
+ """Fake Limiter class that you can tell how to behave."""
+ def __init__(self, test):
+ self._action = self._username = self._delay = None
+ self.test = test
+ def mock(self, action, username, delay):
+ self._action = action
+ self._username = username
+ self._delay = delay
+ def perform(self, action, username):
+ self.test.assertEqual(action, self._action)
+ self.test.assertEqual(username, self._username)
+ return self._delay
+
+
+class WSGIAppTest(unittest.TestCase):
+
+ def setUp(self):
+ self.limiter = FakeLimiter(self)
+ self.app = ratelimiting.WSGIApp(self.limiter)
+
+ def test_invalid_methods(self):
+ requests = []
+ for method in ['GET', 'PUT', 'DELETE']:
+ req = webob.Request.blank('/limits/michael/breakdance',
+ dict(REQUEST_METHOD=method))
+ requests.append(req)
+ for req in requests:
+ self.assertEqual(req.get_response(self.app).status_int, 405)
+
+ def test_invalid_urls(self):
+ requests = []
+ for prefix in ['limit', '', 'limiter2', 'limiter/limits', 'limiter/1']:
+ req = webob.Request.blank('/%s/michael/breakdance' % prefix,
+ dict(REQUEST_METHOD='POST'))
+ requests.append(req)
+ for req in requests:
+ self.assertEqual(req.get_response(self.app).status_int, 404)
+
+ def verify(self, url, username, action, delay=None):
+ """Make sure that POSTing to the given url causes the given username
+ to perform the given action. Make the internal rate limiter return
+ delay and make sure that the WSGI app returns the correct response.
+ """
+ req = webob.Request.blank(url, dict(REQUEST_METHOD='POST'))
+ self.limiter.mock(action, username, delay)
+ resp = req.get_response(self.app)
+ if not delay:
+ self.assertEqual(resp.status_int, 200)
+ else:
+ self.assertEqual(resp.status_int, 403)
+ self.assertEqual(resp.headers['X-Wait-Seconds'], "%.2f" % delay)
+
+ def test_good_urls(self):
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot')
+
+ def test_escaping(self):
+ self.verify('/limiter/michael/jump%20up', 'michael', 'jump up')
+
+ def test_response_to_delays(self):
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1)
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1.56)
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1000)
+
+
+class FakeHttplibSocket(object):
+ """a fake socket implementation for httplib.HTTPResponse, trivial"""
+
+ def __init__(self, response_string):
+ self._buffer = StringIO.StringIO(response_string)
+
+ def makefile(self, _mode, _other):
+ """Returns the socket's internal buffer"""
+ return self._buffer
+
+
+class FakeHttplibConnection(object):
+ """A fake httplib.HTTPConnection
+
+ Requests made via this connection actually get translated and routed into
+ our WSGI app, we then wait for the response and turn it back into
+ an httplib.HTTPResponse.
+ """
+ def __init__(self, app, host, is_secure=False):
+ self.app = app
+ self.host = host
+
+ def request(self, method, path, data='', headers={}):
+ req = webob.Request.blank(path)
+ req.method = method
+ req.body = data
+ req.headers = headers
+ req.host = self.host
+ # Call the WSGI app, get the HTTP response
+ resp = str(req.get_response(self.app))
+ # For some reason, the response doesn't have "HTTP/1.0 " prepended; I
+ # guess that's a function the web server usually provides.
+ resp = "HTTP/1.0 %s" % resp
+ sock = FakeHttplibSocket(resp)
+ self.http_response = httplib.HTTPResponse(sock)
+ self.http_response.begin()
+
+ def getresponse(self):
+ return self.http_response
+
+
+def wire_HTTPConnection_to_WSGI(host, app):
+ """Monkeypatches HTTPConnection so that if you try to connect to host, you
+ are instead routed straight to the given WSGI app.
+
+ After calling this method, when any code calls
+
+ httplib.HTTPConnection(host)
+
+ the connection object will be a fake. Its requests will be sent directly
+ to the given WSGI app rather than through a socket.
+
+ Code connecting to hosts other than host will not be affected.
+
+ This method may be called multiple times to map different hosts to
+ different apps.
+ """
+ class HTTPConnectionDecorator(object):
+ """Wraps the real HTTPConnection class so that when you instantiate
+ the class you might instead get a fake instance."""
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+ def __call__(self, connection_host, *args, **kwargs):
+ if connection_host == host:
+ return FakeHttplibConnection(app, host)
+ else:
+ return self.wrapped(connection_host, *args, **kwargs)
+ httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)
+
+
+class WSGIAppProxyTest(unittest.TestCase):
+
+ def setUp(self):
+ """Our WSGIAppProxy is going to call across an HTTPConnection to a
+ WSGIApp running a limiter. The proxy will send input, and the proxy
+ should receive that same input, pass it to the limiter who gives a
+ result, and send the expected result back.
+
+ The HTTPConnection isn't real -- it's monkeypatched to point straight
+ at the WSGIApp. And the limiter isn't real -- it's a fake that
+ behaves the way we tell it to.
+ """
+ self.limiter = FakeLimiter(self)
+ app = ratelimiting.WSGIApp(self.limiter)
+ wire_HTTPConnection_to_WSGI('100.100.100.100:80', app)
+ self.proxy = ratelimiting.WSGIAppProxy('100.100.100.100:80')
+
+ def test_200(self):
+ self.limiter.mock('conquer', 'caesar', None)
+ when = self.proxy.perform('conquer', 'caesar')
+ self.assertEqual(when, None)
+
+ def test_403(self):
+ self.limiter.mock('grumble', 'proletariat', 1.5)
+ when = self.proxy.perform('grumble', 'proletariat')
+ self.assertEqual(when, 1.5)
+
+ def test_failure(self):
+ def shouldRaise():
+ self.limiter.mock('murder', 'brutus', None)
+ self.proxy.perform('stab', 'brutus')
+ self.assertRaises(AssertionError, shouldRaise)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/openstack/test_servers.py b/nova/tests/api/openstack/test_servers.py
new file mode 100644
index 000000000..d1ee533b6
--- /dev/null
+++ b/nova/tests/api/openstack/test_servers.py
@@ -0,0 +1,249 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import json
+import unittest
+
+import stubout
+import webob
+
+from nova import db
+from nova import flags
+import nova.api.openstack
+from nova.api.openstack import servers
+import nova.db.api
+from nova.db.sqlalchemy.models import Instance
+import nova.rpc
+from nova.tests.api.openstack import fakes
+
+
+FLAGS = flags.FLAGS
+
+FLAGS.verbose = True
+
+def return_server(context, id):
+ return stub_instance(id)
+
+
+def return_servers(context, user_id=1):
+ return [stub_instance(i, user_id) for i in xrange(5)]
+
+
+def stub_instance(id, user_id=1):
+ return Instance(
+ id=id, state=0, image_id=10, server_name='server%s'%id,
+ user_id=user_id
+ )
+
+
+class ServersTest(unittest.TestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ fakes.FakeAuthManager.auth_data = {}
+ fakes.FakeAuthDatabase.data = {}
+ fakes.stub_out_networking(self.stubs)
+ fakes.stub_out_rate_limiting(self.stubs)
+ fakes.stub_out_auth(self.stubs)
+ fakes.stub_out_key_pair_funcs(self.stubs)
+ fakes.stub_out_image_service(self.stubs)
+ self.stubs.Set(nova.db.api, 'instance_get_all', return_servers)
+ self.stubs.Set(nova.db.api, 'instance_get_by_internal_id', return_server)
+ self.stubs.Set(nova.db.api, 'instance_get_all_by_user',
+ return_servers)
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+
+ def test_get_server_by_id(self):
+ req = webob.Request.blank('/v1.0/servers/1')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+ self.assertEqual(res_dict['server']['id'], 1)
+ self.assertEqual(res_dict['server']['name'], 'server1')
+
+ def test_get_server_list(self):
+ req = webob.Request.blank('/v1.0/servers')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+
+ i = 0
+ for s in res_dict['servers']:
+ self.assertEqual(s['id'], i)
+ self.assertEqual(s['name'], 'server%d'%i)
+ self.assertEqual(s.get('imageId', None), None)
+ i += 1
+
+ def test_create_instance(self):
+ def server_update(context, id, params):
+ pass
+
+ def instance_create(context, inst):
+ class Foo(object):
+ internal_id = 1
+ return Foo()
+
+ def fake_method(*args, **kwargs):
+ pass
+
+ def project_get_network(context, user_id):
+ return dict(id='1', host='localhost')
+
+ def queue_get_for(context, *args):
+ return 'network_topic'
+
+ self.stubs.Set(nova.db.api, 'project_get_network', project_get_network)
+ self.stubs.Set(nova.db.api, 'instance_create', instance_create)
+ self.stubs.Set(nova.rpc, 'cast', fake_method)
+ self.stubs.Set(nova.rpc, 'call', fake_method)
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+ self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for)
+ self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip',
+ fake_method)
+
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers')
+ req.method = 'POST'
+ req.body = json.dumps(body)
+
+ res = req.get_response(nova.api.API())
+
+ self.assertEqual(res.status_int, 200)
+
+ def test_update_no_body(self):
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status_int, 422)
+
+ def test_update_bad_params(self):
+ """ Confirm that update is filtering params """
+ inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon')
+ self.body = json.dumps(dict(server=inst_dict))
+
+ def server_update(context, id, params):
+ self.update_called = True
+ filtered_dict = dict(name='server_test', admin_pass='bacon')
+ self.assertEqual(params, filtered_dict)
+
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ req.body = self.body
+ req.get_response(nova.api.API())
+
+ def test_update_server(self):
+ inst_dict = dict(name='server_test', adminPass='bacon')
+ self.body = json.dumps(dict(server=inst_dict))
+
+ def server_update(context, id, params):
+ filtered_dict = dict(name='server_test', admin_pass='bacon')
+ self.assertEqual(params, filtered_dict)
+
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ req.body = self.body
+ req.get_response(nova.api.API())
+
+ def test_create_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ req.method = 'POST'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_delete_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ req.method = 'DELETE'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_get_server_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_get_all_server_details(self):
+ req = webob.Request.blank('/v1.0/servers/detail')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+
+ i = 0
+ for s in res_dict['servers']:
+ self.assertEqual(s['id'], i)
+ self.assertEqual(s['name'], 'server%d'%i)
+ self.assertEqual(s['imageId'], 10)
+ i += 1
+
+ def test_server_reboot(self):
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
+
+ def test_server_rebuild(self):
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
+
+ def test_server_resize(self):
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
+
+ def test_delete_server_instance(self):
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'DELETE'
+
+ self.server_delete_called = False
+ def instance_destroy_mock(context, id):
+ self.server_delete_called = True
+
+ self.stubs.Set(nova.db.api, 'instance_destroy',
+ instance_destroy_mock)
+
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '202 Accepted')
+ self.assertEqual(self.server_delete_called, True)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/nova/tests/api/openstack/test_sharedipgroups.py b/nova/tests/api/openstack/test_sharedipgroups.py
new file mode 100644
index 000000000..d199951d8
--- /dev/null
+++ b/nova/tests/api/openstack/test_sharedipgroups.py
@@ -0,0 +1,39 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import stubout
+
+from nova.api.openstack import sharedipgroups
+
+
+class SharedIpGroupsTest(unittest.TestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+
+ def test_get_shared_ip_groups(self):
+ pass
+
+ def test_create_shared_ip_group(self):
+ pass
+
+ def test_delete_shared_ip_group(self):
+ pass
diff --git a/nova/tests/api/test_wsgi.py b/nova/tests/api/test_wsgi.py
new file mode 100644
index 000000000..9425b01d0
--- /dev/null
+++ b/nova/tests/api/test_wsgi.py
@@ -0,0 +1,147 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Test WSGI basics and provide some helper functions for other WSGI tests.
+"""
+
+import unittest
+
+import routes
+import webob
+
+from nova import wsgi
+
+
+class Test(unittest.TestCase):
+
+ def test_debug(self):
+
+ class Application(wsgi.Application):
+ """Dummy application to test debug."""
+
+ def __call__(self, environ, start_response):
+ start_response("200", [("X-Test", "checking")])
+ return ['Test result']
+
+ application = wsgi.Debug(Application())
+ result = webob.Request.blank('/').get_response(application)
+ self.assertEqual(result.body, "Test result")
+
+ def test_router(self):
+
+ class Application(wsgi.Application):
+ """Test application to call from router."""
+
+ def __call__(self, environ, start_response):
+ start_response("200", [])
+ return ['Router result']
+
+ class Router(wsgi.Router):
+ """Test router."""
+
+ def __init__(self):
+ mapper = routes.Mapper()
+ mapper.connect("/test", controller=Application())
+ super(Router, self).__init__(mapper)
+
+ result = webob.Request.blank('/test').get_response(Router())
+ self.assertEqual(result.body, "Router result")
+ result = webob.Request.blank('/bad').get_response(Router())
+ self.assertNotEqual(result.body, "Router result")
+
+ def test_controller(self):
+
+ class Controller(wsgi.Controller):
+ """Test controller to call from router."""
+ test = self
+
+ def show(self, req, id): # pylint: disable-msg=W0622,C0103
+ """Default action called for requests with an ID."""
+ self.test.assertEqual(req.path_info, '/tests/123')
+ self.test.assertEqual(id, '123')
+ return id
+
+ class Router(wsgi.Router):
+ """Test router."""
+
+ def __init__(self):
+ mapper = routes.Mapper()
+ mapper.resource("test", "tests", controller=Controller())
+ super(Router, self).__init__(mapper)
+
+ result = webob.Request.blank('/tests/123').get_response(Router())
+ self.assertEqual(result.body, "123")
+ result = webob.Request.blank('/test/123').get_response(Router())
+ self.assertNotEqual(result.body, "123")
+
+
+class SerializerTest(unittest.TestCase):
+
+ def match(self, url, accept, expect):
+ input_dict = dict(servers=dict(a=(2,3)))
+ expected_xml = '<servers><a>(2,3)</a></servers>'
+ expected_json = '{"servers":{"a":[2,3]}}'
+ req = webob.Request.blank(url, headers=dict(Accept=accept))
+ result = wsgi.Serializer(req.environ).to_content_type(input_dict)
+ result = result.replace('\n', '').replace(' ', '')
+ if expect == 'xml':
+ self.assertEqual(result, expected_xml)
+ elif expect == 'json':
+ self.assertEqual(result, expected_json)
+ else:
+ raise "Bad expect value"
+
+ def test_basic(self):
+ self.match('/servers/4.json', None, expect='json')
+ self.match('/servers/4', 'application/json', expect='json')
+ self.match('/servers/4', 'application/xml', expect='xml')
+ self.match('/servers/4.xml', None, expect='xml')
+
+ def test_defaults_to_json(self):
+ self.match('/servers/4', None, expect='json')
+ self.match('/servers/4', 'text/html', expect='json')
+
+ def test_suffix_takes_precedence_over_accept_header(self):
+ self.match('/servers/4.xml', 'application/json', expect='xml')
+ self.match('/servers/4.xml.', 'application/json', expect='json')
+
+ def test_deserialize(self):
+ xml = """
+ <a a1="1" a2="2">
+ <bs><b>1</b><b>2</b><b>3</b><b><c c1="1"/></b></bs>
+ <d><e>1</e></d>
+ <f>1</f>
+ </a>
+ """.strip()
+ as_dict = dict(a={
+ 'a1': '1',
+ 'a2': '2',
+ 'bs': ['1', '2', '3', {'c': dict(c1='1')}],
+ 'd': {'e': '1'},
+ 'f': '1'})
+ metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})}
+ serializer = wsgi.Serializer({}, metadata)
+ self.assertEqual(serializer.deserialize(xml), as_dict)
+
+ def test_deserialize_empty_xml(self):
+ xml = """<a></a>"""
+ as_dict = {"a": {}}
+ serializer = wsgi.Serializer({})
+ self.assertEqual(serializer.deserialize(xml), as_dict)
diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py
index 9d072866c..7ab27e000 100644
--- a/nova/tests/api_unittest.py
+++ b/nova/tests/api_unittest.py
@@ -16,167 +16,102 @@
# License for the specific language governing permissions and limitations
# under the License.
+"""Unit tests for the API endpoint"""
+
import boto
from boto.ec2 import regioninfo
import httplib
import random
import StringIO
-from tornado import httpserver
-from twisted.internet import defer
+import webob
from nova import flags
from nova import test
+from nova import api
+from nova.api.ec2 import cloud
from nova.auth import manager
-from nova.endpoint import api
-from nova.endpoint import cloud
FLAGS = flags.FLAGS
-
-
-# NOTE(termie): These are a bunch of helper methods and classes to short
-# circuit boto calls and feed them into our tornado handlers,
-# it's pretty damn circuitous so apologies if you have to fix
-# a bug in it
-def boto_to_tornado(method, path, headers, data, host, connection=None):
- """ translate boto requests into tornado requests
-
- connection should be a FakeTornadoHttpConnection instance
- """
- try:
- headers = httpserver.HTTPHeaders()
- except AttributeError:
- from tornado import httputil
- headers = httputil.HTTPHeaders()
- for k, v in headers.iteritems():
- headers[k] = v
-
- req = httpserver.HTTPRequest(method=method,
- uri=path,
- headers=headers,
- body=data,
- host=host,
- remote_ip='127.0.0.1',
- connection=connection)
- return req
-
-
-def raw_to_httpresponse(s):
- """ translate a raw tornado http response into an httplib.HTTPResponse """
- sock = FakeHttplibSocket(s)
- resp = httplib.HTTPResponse(sock)
- resp.begin()
- return resp
+FLAGS.FAKE_subdomain = 'ec2'
class FakeHttplibSocket(object):
- """ a fake socket implementation for httplib.HTTPResponse, trivial """
- def __init__(self, s):
- self.fp = StringIO.StringIO(s)
-
- def makefile(self, mode, other):
- return self.fp
-
-
-class FakeTornadoStream(object):
- """ a fake stream to satisfy tornado's assumptions, trivial """
- def set_close_callback(self, f):
- pass
-
-
-class FakeTornadoConnection(object):
- """ a fake connection object for tornado to pass to its handlers
-
- web requests are expected to write to this as they get data and call
- finish when they are done with the request, we buffer the writes and
- kick off a callback when it is done so that we can feed the result back
- into boto.
- """
- def __init__(self, d):
- self.d = d
- self._buffer = StringIO.StringIO()
-
- def write(self, chunk):
- self._buffer.write(chunk)
-
- def finish(self):
- s = self._buffer.getvalue()
- self.d.callback(s)
+ """a fake socket implementation for httplib.HTTPResponse, trivial"""
+ def __init__(self, response_string):
+ self._buffer = StringIO.StringIO(response_string)
- xheaders = None
-
- @property
- def stream(self):
- return FakeTornadoStream()
+ def makefile(self, _mode, _other):
+ """Returns the socket's internal buffer"""
+ return self._buffer
class FakeHttplibConnection(object):
- """ a fake httplib.HTTPConnection for boto to use
+ """A fake httplib.HTTPConnection for boto to use
requests made via this connection actually get translated and routed into
- our tornado app, we then wait for the response and turn it back into
+ our WSGI app, we then wait for the response and turn it back into
the httplib.HTTPResponse that boto expects.
"""
def __init__(self, app, host, is_secure=False):
self.app = app
self.host = host
- self.deferred = defer.Deferred()
def request(self, method, path, data, headers):
- req = boto_to_tornado
- conn = FakeTornadoConnection(self.deferred)
- request = boto_to_tornado(connection=conn,
- method=method,
- path=path,
- headers=headers,
- data=data,
- host=self.host)
- handler = self.app(request)
- self.deferred.addCallback(raw_to_httpresponse)
+ req = webob.Request.blank(path)
+ req.method = method
+ req.body = data
+ req.headers = headers
+ req.headers['Accept'] = 'text/html'
+ req.host = self.host
+ # Call the WSGI app, get the HTTP response
+ resp = str(req.get_response(self.app))
+ # For some reason, the response doesn't have "HTTP/1.0 " prepended; I
+ # guess that's a function the web server usually provides.
+ resp = "HTTP/1.0 %s" % resp
+ sock = FakeHttplibSocket(resp)
+ self.http_response = httplib.HTTPResponse(sock)
+ self.http_response.begin()
def getresponse(self):
- @defer.inlineCallbacks
- def _waiter():
- result = yield self.deferred
- defer.returnValue(result)
- d = _waiter()
- # NOTE(termie): defer.returnValue above should ensure that
- # this deferred has already been called by the time
- # we get here, we are going to cheat and return
- # the result of the callback
- return d.result
+ return self.http_response
def close(self):
+ """Required for compatibility with boto/tornado"""
pass
class ApiEc2TestCase(test.BaseTestCase):
- def setUp(self):
+ """Unit test for the cloud controller on an EC2 API"""
+ def setUp(self): # pylint: disable-msg=C0103,C0111
super(ApiEc2TestCase, self).setUp()
self.manager = manager.AuthManager()
- self.cloud = cloud.CloudController()
self.host = '127.0.0.1'
- self.app = api.APIServerApplication({'Cloud': self.cloud})
+ self.app = api.API()
+
+ def expect_http(self, host=None, is_secure=False):
+ """Returns a new EC2 connection"""
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
is_secure=False,
region=regioninfo.RegionInfo(None, 'test', self.host),
- port=FLAGS.cc_port,
+ port=8773,
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
-
- def expect_http(self, host=None, is_secure=False):
http = FakeHttplibConnection(
- self.app, '%s:%d' % (self.host, FLAGS.cc_port), False)
+ self.app, '%s:8773' % (self.host), False)
+ # pylint: disable-msg=E1103
self.ec2.new_http_connection(host, is_secure).AndReturn(http)
return http
def test_describe_instances(self):
+ """Test that, after creating a user and a project, the describe
+ instances call to the API works properly"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake')
@@ -187,14 +122,201 @@ class ApiEc2TestCase(test.BaseTestCase):
def test_get_all_key_pairs(self):
+ """Test that, after creating a user and project and generating
+ a key pair, that the API call to list key pairs works properly"""
self.expect_http()
self.mox.ReplayAll()
- keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") for x in range(random.randint(4, 8)))
+ keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
- self.manager.generate_key_pair(user.id, keyname)
+ # NOTE(vish): create depends on pool, so call helper directly
+ cloud._gen_key(None, user.id, keyname)
rv = self.ec2.get_all_key_pairs()
- self.assertTrue(filter(lambda k: k.name == keyname, rv))
+ results = [k for k in rv if k.name == keyname]
+ self.assertEquals(len(results), 1)
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_get_all_security_groups(self):
+ """Test that we can retrieve security groups"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEquals(len(rv), 1)
+ self.assertEquals(rv[0].name, 'default')
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_create_delete_security_group(self):
+ """Test that we can create a security group"""
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ self.assertEquals(len(rv), 2)
+ self.assertTrue(security_group_name in [group.name for group in rv])
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ def test_authorize_revoke_security_group_cidr(self):
+ """
+ Test that we can add and remove CIDR based rules
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake')
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(int(group.rules[0].from_port), 80)
+ self.assertEquals(int(group.rules[0].to_port), 81)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.revoke('tcp', 80, 81, '0.0.0.0/0')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ rv = self.ec2.get_all_security_groups()
+
+ self.assertEqual(len(rv), 1)
+ self.assertEqual(rv[0].name, 'default')
+
self.manager.delete_project(project)
self.manager.delete_user(user)
+
+ return
+
+ def test_authorize_revoke_security_group_foreign_group(self):
+ """
+ Test that we can grant and revoke another security group access
+ to a security group
+ """
+ self.expect_http()
+ self.mox.ReplayAll()
+ user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ project = self.manager.create_project('fake', 'fake', 'fake')
+
+ # At the moment, you need both of these to actually be netadmin
+ self.manager.add_role('fake', 'netadmin')
+ project.add_role('fake', 'netadmin')
+
+ security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+ other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
+ for x in range(random.randint(4, 8)))
+
+ group = self.ec2.create_security_group(security_group_name, 'test group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ other_group = self.ec2.create_security_group(other_security_group_name,
+ 'some other group')
+
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+
+ group.authorize(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ # I don't bother checkng that we actually find it here,
+ # because the create/delete unit test further up should
+ # be good enough for that.
+ for group in rv:
+ if group.name == security_group_name:
+ self.assertEquals(len(group.rules), 1)
+ self.assertEquals(len(group.rules[0].grants), 1)
+ self.assertEquals(str(group.rules[0].grants[0]),
+ '%s-%s' % (other_security_group_name, 'fake'))
+
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ rv = self.ec2.get_all_security_groups()
+
+ for group in rv:
+ if group.name == security_group_name:
+ self.expect_http()
+ self.mox.ReplayAll()
+ group.connection = self.ec2
+ group.revoke(src_group=other_group)
+
+ self.expect_http()
+ self.mox.ReplayAll()
+
+ self.ec2.delete_security_group(security_group_name)
+
+ self.manager.delete_project(project)
+ self.manager.delete_user(user)
+
+ return
diff --git a/nova/tests/auth_unittest.py b/nova/tests/auth_unittest.py
index 0b404bfdc..99f7ab599 100644
--- a/nova/tests/auth_unittest.py
+++ b/nova/tests/auth_unittest.py
@@ -17,8 +17,6 @@
# under the License.
import logging
-from M2Crypto import BIO
-from M2Crypto import RSA
from M2Crypto import X509
import unittest
@@ -26,31 +24,76 @@ from nova import crypto
from nova import flags
from nova import test
from nova.auth import manager
-from nova.endpoint import cloud
+from nova.api.ec2 import cloud
FLAGS = flags.FLAGS
-
-class AuthTestCase(test.BaseTestCase):
- flush_db = False
+class user_generator(object):
+ def __init__(self, manager, **user_state):
+ if 'name' not in user_state:
+ user_state['name'] = 'test1'
+ self.manager = manager
+ self.user = manager.create_user(**user_state)
+
+ def __enter__(self):
+ return self.user
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_user(self.user)
+
+class project_generator(object):
+ def __init__(self, manager, **project_state):
+ if 'name' not in project_state:
+ project_state['name'] = 'testproj'
+ if 'manager_user' not in project_state:
+ project_state['manager_user'] = 'test1'
+ self.manager = manager
+ self.project = manager.create_project(**project_state)
+
+ def __enter__(self):
+ return self.project
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_project(self.project)
+
+class user_and_project_generator(object):
+ def __init__(self, manager, user_state={}, project_state={}):
+ self.manager = manager
+ if 'name' not in user_state:
+ user_state['name'] = 'test1'
+ if 'name' not in project_state:
+ project_state['name'] = 'testproj'
+ if 'manager_user' not in project_state:
+ project_state['manager_user'] = 'test1'
+ self.user = manager.create_user(**user_state)
+ self.project = manager.create_project(**project_state)
+
+ def __enter__(self):
+ return (self.user, self.project)
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_user(self.user)
+ self.manager.delete_project(self.project)
+
+class AuthManagerTestCase(object):
def setUp(self):
- super(AuthTestCase, self).setUp()
- self.flags(connection_type='fake',
- fake_storage=True)
+ FLAGS.auth_driver = self.auth_driver
+ super(AuthManagerTestCase, self).setUp()
+ self.flags(connection_type='fake')
self.manager = manager.AuthManager()
- def test_001_can_create_users(self):
- self.manager.create_user('test1', 'access', 'secret')
- self.manager.create_user('test2')
-
- def test_002_can_get_user(self):
- user = self.manager.get_user('test1')
+ def test_create_and_find_user(self):
+ with user_generator(self.manager):
+ self.assert_(self.manager.get_user('test1'))
- def test_003_can_retreive_properties(self):
- user = self.manager.get_user('test1')
- self.assertEqual('test1', user.id)
- self.assertEqual('access', user.access)
- self.assertEqual('secret', user.secret)
+ def test_create_and_find_with_properties(self):
+ with user_generator(self.manager, name="herbert", secret="classified",
+ access="private-party"):
+ u = self.manager.get_user('herbert')
+ self.assertEqual('herbert', u.id)
+ self.assertEqual('herbert', u.name)
+ self.assertEqual('classified', u.secret)
+ self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? ))
@@ -67,156 +110,222 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n')
- def test_006_test_key_storage(self):
- user = self.manager.get_user('test1')
- user.create_key_pair('public', 'key', 'fingerprint')
- key = user.get_key_pair('public')
- self.assertEqual('key', key.public_key)
- self.assertEqual('fingerprint', key.fingerprint)
-
- def test_007_test_key_generation(self):
- user = self.manager.get_user('test1')
- private_key, fingerprint = user.generate_key_pair('public2')
- key = RSA.load_key_string(private_key, callback=lambda: None)
- bio = BIO.MemoryBuffer()
- public_key = user.get_key_pair('public2').public_key
- key.save_pub_key_bio(bio)
- converted = crypto.ssl_pub_to_ssh_pub(bio.read())
- # assert key fields are equal
- self.assertEqual(public_key.split(" ")[1].strip(),
- converted.split(" ")[1].strip())
-
- def test_008_can_list_key_pairs(self):
- keys = self.manager.get_user('test1').get_key_pairs()
- self.assertTrue(filter(lambda k: k.name == 'public', keys))
- self.assertTrue(filter(lambda k: k.name == 'public2', keys))
-
- def test_009_can_delete_key_pair(self):
- self.manager.get_user('test1').delete_key_pair('public')
- keys = self.manager.get_user('test1').get_key_pairs()
- self.assertFalse(filter(lambda k: k.name == 'public', keys))
-
- def test_010_can_list_users(self):
- users = self.manager.get_users()
- logging.warn(users)
- self.assertTrue(filter(lambda u: u.id == 'test1', users))
-
- def test_101_can_add_user_role(self):
- self.assertFalse(self.manager.has_role('test1', 'itsec'))
- self.manager.add_role('test1', 'itsec')
- self.assertTrue(self.manager.has_role('test1', 'itsec'))
-
- def test_199_can_remove_user_role(self):
- self.assertTrue(self.manager.has_role('test1', 'itsec'))
- self.manager.remove_role('test1', 'itsec')
- self.assertFalse(self.manager.has_role('test1', 'itsec'))
-
- def test_201_can_create_project(self):
- project = self.manager.create_project('testproj', 'test1', 'A test project', ['test1'])
- self.assertTrue(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
- self.assertEqual(project.name, 'testproj')
- self.assertEqual(project.description, 'A test project')
- self.assertEqual(project.project_manager_id, 'test1')
- self.assertTrue(project.has_member('test1'))
-
- def test_202_user1_is_project_member(self):
- self.assertTrue(self.manager.get_user('test1').is_project_member('testproj'))
-
- def test_203_user2_is_not_project_member(self):
- self.assertFalse(self.manager.get_user('test2').is_project_member('testproj'))
-
- def test_204_user1_is_project_manager(self):
- self.assertTrue(self.manager.get_user('test1').is_project_manager('testproj'))
-
- def test_205_user2_is_not_project_manager(self):
- self.assertFalse(self.manager.get_user('test2').is_project_manager('testproj'))
-
- def test_206_can_add_user_to_project(self):
- self.manager.add_to_project('test2', 'testproj')
- self.assertTrue(self.manager.get_project('testproj').has_member('test2'))
-
- def test_207_can_remove_user_from_project(self):
- self.manager.remove_from_project('test2', 'testproj')
- self.assertFalse(self.manager.get_project('testproj').has_member('test2'))
-
- def test_208_can_remove_add_user_with_role(self):
- self.manager.add_to_project('test2', 'testproj')
- self.manager.add_role('test2', 'developer', 'testproj')
- self.manager.remove_from_project('test2', 'testproj')
- self.assertFalse(self.manager.has_role('test2', 'developer', 'testproj'))
- self.manager.add_to_project('test2', 'testproj')
- self.manager.remove_from_project('test2', 'testproj')
-
- def test_209_can_generate_x509(self):
- # MUST HAVE RUN CLOUD SETUP BY NOW
- self.cloud = cloud.CloudController()
- self.cloud.setup()
- _key, cert_str = self.manager._generate_x509_cert('test1', 'testproj')
- logging.debug(cert_str)
-
- # Need to verify that it's signed by the right intermediate CA
- full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
- int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
- cloud_cert = crypto.fetch_ca()
- logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
- signed_cert = X509.load_cert_string(cert_str)
- chain_cert = X509.load_cert_string(full_chain)
- int_cert = X509.load_cert_string(int_cert)
- cloud_cert = X509.load_cert_string(cloud_cert)
- self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
- self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
-
- if not FLAGS.use_intermediate_ca:
- self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
- else:
- self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
-
- def test_210_can_add_project_role(self):
- project = self.manager.get_project('testproj')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- self.manager.add_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- project.add_role('test1', 'sysadmin')
- self.assertTrue(project.has_role('test1', 'sysadmin'))
-
- def test_211_can_list_project_roles(self):
- project = self.manager.get_project('testproj')
- user = self.manager.get_user('test1')
- self.manager.add_role(user, 'netadmin', project)
- roles = self.manager.get_user_roles(user)
- self.assertTrue('sysadmin' in roles)
- self.assertFalse('netadmin' in roles)
- project_roles = self.manager.get_user_roles(user, project)
- self.assertTrue('sysadmin' in project_roles)
- self.assertTrue('netadmin' in project_roles)
- # has role should be false because global role is missing
- self.assertFalse(self.manager.has_role(user, 'netadmin', project))
-
-
- def test_212_can_remove_project_role(self):
- project = self.manager.get_project('testproj')
- self.assertTrue(project.has_role('test1', 'sysadmin'))
- project.remove_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- self.manager.remove_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
-
- def test_214_can_retrieve_project_by_user(self):
- project = self.manager.create_project('testproj2', 'test2', 'Another test project', ['test2'])
- self.assert_(len(self.manager.get_projects()) > 1)
- self.assertEqual(len(self.manager.get_projects('test2')), 1)
-
- def test_299_can_delete_project(self):
- self.manager.delete_project('testproj')
- self.assertFalse(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
- self.manager.delete_project('testproj2')
-
- def test_999_can_delete_users(self):
+ def test_can_list_users(self):
+ with user_generator(self.manager):
+ with user_generator(self.manager, name="test2"):
+ users = self.manager.get_users()
+ self.assert_(filter(lambda u: u.id == 'test1', users))
+ self.assert_(filter(lambda u: u.id == 'test2', users))
+ self.assert_(not filter(lambda u: u.id == 'test3', users))
+
+ def test_can_add_and_remove_user_role(self):
+ with user_generator(self.manager):
+ self.assertFalse(self.manager.has_role('test1', 'itsec'))
+ self.manager.add_role('test1', 'itsec')
+ self.assertTrue(self.manager.has_role('test1', 'itsec'))
+ self.manager.remove_role('test1', 'itsec')
+ self.assertFalse(self.manager.has_role('test1', 'itsec'))
+
+ def test_can_create_and_get_project(self):
+ with user_and_project_generator(self.manager) as (u,p):
+ self.assert_(self.manager.get_user('test1'))
+ self.assert_(self.manager.get_user('test1'))
+ self.assert_(self.manager.get_project('testproj'))
+
+ def test_can_list_projects(self):
+ with user_and_project_generator(self.manager):
+ with project_generator(self.manager, name="testproj2"):
+ projects = self.manager.get_projects()
+ self.assert_(filter(lambda p: p.name == 'testproj', projects))
+ self.assert_(filter(lambda p: p.name == 'testproj2', projects))
+ self.assert_(not filter(lambda p: p.name == 'testproj3',
+ projects))
+
+ def test_can_create_and_get_project_with_attributes(self):
+ with user_generator(self.manager):
+ with project_generator(self.manager, description='A test project'):
+ project = self.manager.get_project('testproj')
+ self.assertEqual('A test project', project.description)
+
+ def test_can_create_project_with_manager(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertEqual('test1', project.project_manager_id)
+ self.assertTrue(self.manager.is_project_manager(user, project))
+
+ def test_create_project_assigns_manager_to_members(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertTrue(self.manager.is_project_member(user, project))
+
+ def test_no_extra_project_members(self):
+ with user_generator(self.manager, name='test2') as baduser:
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.is_project_member(baduser,
+ project))
+
+ def test_no_extra_project_managers(self):
+ with user_generator(self.manager, name='test2') as baduser:
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.is_project_manager(baduser,
+ project))
+
+ def test_can_add_user_to_project(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertTrue(self.manager.is_project_member(user, project))
+
+ def test_can_remove_user_from_project(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertTrue(self.manager.is_project_member(user, project))
+ self.manager.remove_from_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertFalse(self.manager.is_project_member(user, project))
+
+ def test_can_add_remove_user_with_role(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ # NOTE(todd): after modifying users you must reload project
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.manager.add_role(user, 'developer', project)
+ self.assertTrue(self.manager.is_project_member(user, project))
+ self.manager.remove_from_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertFalse(self.manager.has_role(user, 'developer',
+ project))
+ self.assertFalse(self.manager.is_project_member(user, project))
+
+ def test_can_generate_x509(self):
+ # NOTE(todd): this doesn't assert against the auth manager
+ # so it probably belongs in crypto_unittest
+ # but I'm leaving it where I found it.
+ with user_and_project_generator(self.manager) as (user, project):
+ # NOTE(todd): Should mention why we must setup controller first
+ # (somebody please clue me in)
+ cloud_controller = cloud.CloudController()
+ cloud_controller.setup()
+ _key, cert_str = self.manager._generate_x509_cert('test1',
+ 'testproj')
+ logging.debug(cert_str)
+
+ # Need to verify that it's signed by the right intermediate CA
+ full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
+ int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
+ cloud_cert = crypto.fetch_ca()
+ logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
+ signed_cert = X509.load_cert_string(cert_str)
+ chain_cert = X509.load_cert_string(full_chain)
+ int_cert = X509.load_cert_string(int_cert)
+ cloud_cert = X509.load_cert_string(cloud_cert)
+ self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
+ self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
+ if not FLAGS.use_intermediate_ca:
+ self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
+ else:
+ self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
+
+ def test_adding_role_to_project_is_ignored_unless_added_to_user(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin', project)
+ # NOTE(todd): it will still show up in get_user_roles(u, project)
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin')
+ self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_add_user_role_doesnt_infect_project_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_can_list_user_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ roles = self.manager.get_user_roles(user)
+ self.assertTrue('sysadmin' in roles)
+ self.assertFalse('netadmin' in roles)
+
+ def test_can_list_project_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.manager.add_role(user, 'netadmin', project)
+ project_roles = self.manager.get_user_roles(user, project)
+ self.assertTrue('sysadmin' in project_roles)
+ self.assertTrue('netadmin' in project_roles)
+ # has role should be false user-level role is missing
+ self.assertFalse(self.manager.has_role(user, 'netadmin', project))
+
+ def test_can_remove_user_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+ self.manager.remove_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin'))
+
+ def test_removing_user_role_hides_it_from_project(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.remove_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_can_remove_project_role_but_keep_user_role(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+ self.manager.remove_role(user, 'sysadmin', project)
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+
+ def test_can_retrieve_project_by_user(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertEqual(1, len(self.manager.get_projects('test1')))
+
+ def test_can_modify_project(self):
+ with user_and_project_generator(self.manager):
+ with user_generator(self.manager, name='test2'):
+ self.manager.modify_project('testproj', 'test2', 'new desc')
+ project = self.manager.get_project('testproj')
+ self.assertEqual('test2', project.project_manager_id)
+ self.assertEqual('new desc', project.description)
+
+ def test_can_delete_project(self):
+ with user_generator(self.manager):
+ self.manager.create_project('testproj', 'test1')
+ self.assert_(self.manager.get_project('testproj'))
+ self.manager.delete_project('testproj')
+ projectlist = self.manager.get_projects()
+ self.assert_(not filter(lambda p: p.name == 'testproj',
+ projectlist))
+
+ def test_can_delete_user(self):
+ self.manager.create_user('test1')
+ self.assert_(self.manager.get_user('test1'))
self.manager.delete_user('test1')
- users = self.manager.get_users()
- self.assertFalse(filter(lambda u: u.id == 'test1', users))
- self.manager.delete_user('test2')
- self.assertEqual(self.manager.get_user('test2'), None)
+ userlist = self.manager.get_users()
+ self.assert_(not filter(lambda u: u.id == 'test1', userlist))
+
+ def test_can_modify_users(self):
+ with user_generator(self.manager):
+ self.manager.modify_user('test1', 'access', 'secret', True)
+ user = self.manager.get_user('test1')
+ self.assertEqual('access', user.access)
+ self.assertEqual('secret', user.secret)
+ self.assertTrue(user.is_admin())
+
+class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
+ auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
+
+class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
+ auth_driver = 'nova.auth.dbdriver.DbDriver'
if __name__ == "__main__":
diff --git a/nova/tests/bundle/1mb.manifest.xml b/nova/tests/bundle/1mb.manifest.xml
index dc3315957..01648a544 100644
--- a/nova/tests/bundle/1mb.manifest.xml
+++ b/nova/tests/bundle/1mb.manifest.xml
@@ -1 +1 @@
-<?xml version="1.0" ?><manifest><version>2007-10-10</version><bundler><name>euca-tools</name><version>1.2</version><release>31337</release></bundler><machine_configuration><architecture>x86_64</architecture></machine_configuration><image><name>1mb</name><user>42</user><type>machine</type><digest algorithm="SHA1">da39a3ee5e6b4b0d3255bfef95601890afd80709</digest><size>1048576</size><bundled_size>1136</bundled_size><ec2_encrypted_key algorithm="AES-128-CBC">33a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d9915</ec2_encrypted_key><user_encrypted_key algorithm="AES-128-CBC">4c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa3023</user_encrypted_key><ec2_encrypted_iv>14bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb3</ec2_encrypted_iv><user_encrypted_iv>8f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b</user_encrypted_iv><parts count="2"><part index="0"><filename>1mb.part.0</filename><digest algorithm="SHA1">c4413423cf7a57e71187e19bfd5cd4b514a64283</digest></part><part index="1"><filename>1mb.part.1</filename><digest algorithm="SHA1">9d4262e6589393d09a11a0332af169887bc2e57d</digest></part></parts></image><signature>4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e</signature></manifest> \ No newline at end of file
+<?xml version="1.0" ?><manifest><version>2007-10-10</version><bundler><name>euca-tools</name><version>1.2</version><release>31337</release></bundler><machine_configuration><architecture>x86_64</architecture><kernel_id>aki-test</kernel_id><ramdisk_id>ari-test</ramdisk_id></machine_configuration><image><name>1mb</name><user>42</user><type>machine</type><digest algorithm="SHA1">da39a3ee5e6b4b0d3255bfef95601890afd80709</digest><size>1048576</size><bundled_size>1136</bundled_size><ec2_encrypted_key algorithm="AES-128-CBC">33a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d9915</ec2_encrypted_key><user_encrypted_key algorithm="AES-128-CBC">4c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa3023</user_encrypted_key><ec2_encrypted_iv>14bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb3</ec2_encrypted_iv><user_encrypted_iv>8f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b</user_encrypted_iv><parts count="2"><part index="0"><filename>1mb.part.0</filename><digest algorithm="SHA1">c4413423cf7a57e71187e19bfd5cd4b514a64283</digest></part><part index="1"><filename>1mb.part.1</filename><digest algorithm="SHA1">9d4262e6589393d09a11a0332af169887bc2e57d</digest></part></parts></image><signature>4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e</signature></manifest>
diff --git a/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml b/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml
new file mode 100644
index 000000000..73d7ace00
--- /dev/null
+++ b/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml
@@ -0,0 +1 @@
+<?xml version="1.0" ?><manifest><version>2007-10-10</version><bundler><name>euca-tools</name><version>1.2</version><release>31337</release></bundler><machine_configuration><architecture>x86_64</architecture></machine_configuration><image><name>1mb</name><user>42</user><type>machine</type><digest algorithm="SHA1">da39a3ee5e6b4b0d3255bfef95601890afd80709</digest><size>1048576</size><bundled_size>1136</bundled_size><ec2_encrypted_key algorithm="AES-128-CBC">33a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d9915</ec2_encrypted_key><user_encrypted_key algorithm="AES-128-CBC">4c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa3023</user_encrypted_key><ec2_encrypted_iv>14bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb3</ec2_encrypted_iv><user_encrypted_iv>8f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b</user_encrypted_iv><parts count="2"><part index="0"><filename>1mb.part.0</filename><digest algorithm="SHA1">c4413423cf7a57e71187e19bfd5cd4b514a64283</digest></part><part index="1"><filename>1mb.part.1</filename><digest algorithm="SHA1">9d4262e6589393d09a11a0332af169887bc2e57d</digest></part></parts></image><signature>4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e</signature></manifest>
diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py
index 3501771cc..ff466135d 100644
--- a/nova/tests/cloud_unittest.py
+++ b/nova/tests/cloud_unittest.py
@@ -16,70 +16,119 @@
# License for the specific language governing permissions and limitations
# under the License.
+from base64 import b64decode
+import json
import logging
+from M2Crypto import BIO
+from M2Crypto import RSA
+import os
import StringIO
+import tempfile
import time
-from tornado import ioloop
+
from twisted.internet import defer
import unittest
from xml.etree import ElementTree
+from nova import crypto
+from nova import db
from nova import flags
from nova import rpc
from nova import test
+from nova import utils
from nova.auth import manager
-from nova.compute import service
-from nova.endpoint import api
-from nova.endpoint import cloud
+from nova.compute import power_state
+from nova.api.ec2 import context
+from nova.api.ec2 import cloud
+from nova.objectstore import image
FLAGS = flags.FLAGS
-class CloudTestCase(test.BaseTestCase):
+# Temp dirs for working with image attributes through the cloud controller
+# (stole this from objectstore_unittest.py)
+OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-')
+IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images')
+os.makedirs(IMAGES_PATH)
+
+class CloudTestCase(test.TrialTestCase):
def setUp(self):
super(CloudTestCase, self).setUp()
- self.flags(connection_type='fake',
- fake_storage=True)
+ self.flags(connection_type='fake', images_path=IMAGES_PATH)
self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG)
# set up our cloud
self.cloud = cloud.CloudController()
- self.cloud_consumer = rpc.AdapterConsumer(connection=self.conn,
- topic=FLAGS.cloud_topic,
- proxy=self.cloud)
- self.injected.append(self.cloud_consumer.attach_to_tornado(self.ioloop))
# set up a service
- self.compute = service.ComputeService()
+ self.compute = utils.import_object(FLAGS.compute_manager)
self.compute_consumer = rpc.AdapterConsumer(connection=self.conn,
- topic=FLAGS.compute_topic,
- proxy=self.compute)
- self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
+ topic=FLAGS.compute_topic,
+ proxy=self.compute)
+ self.compute_consumer.attach_to_eventlet()
+ self.network = utils.import_object(FLAGS.network_manager)
+ self.network_consumer = rpc.AdapterConsumer(connection=self.conn,
+ topic=FLAGS.network_topic,
+ proxy=self.network)
+ self.network_consumer.attach_to_eventlet()
- try:
- manager.AuthManager().create_user('admin', 'admin', 'admin')
- except: pass
- admin = manager.AuthManager().get_user('admin')
- project = manager.AuthManager().create_project('proj', 'admin', 'proj')
- self.context = api.APIRequestContext(handler=None,project=project,user=admin)
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('admin', 'admin', 'admin', True)
+ self.project = self.manager.create_project('proj', 'admin', 'proj')
+ self.context = context.APIRequestContext(user=self.user,
+ project=self.project)
def tearDown(self):
- manager.AuthManager().delete_project('proj')
- manager.AuthManager().delete_user('admin')
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+ super(CloudTestCase, self).tearDown()
+
+ def _create_key(self, name):
+ # NOTE(vish): create depends on pool, so just call helper directly
+ return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self):
- if FLAGS.connection_type == 'fake':
- logging.debug("Can't test instances without a real virtual env.")
- return
- instance_id = 'foo'
- inst = yield self.compute.run_instance(instance_id)
- output = yield self.cloud.get_console_output(self.context, [instance_id])
- logging.debug(output)
- self.assert_(output)
- rv = yield self.compute.terminate_instance(instance_id)
+ image_id = FLAGS.default_image
+ instance_type = FLAGS.default_instance_type
+ max_count = 1
+ kwargs = {'image_id': image_id,
+ 'instance_type': instance_type,
+ 'max_count': max_count }
+ rv = yield self.cloud.run_instances(self.context, **kwargs)
+ instance_id = rv['instancesSet'][0]['instanceId']
+ output = yield self.cloud.get_console_output(context=self.context, instance_id=[instance_id])
+ self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT')
+ rv = yield self.cloud.terminate_instances(self.context, [instance_id])
+
+
+ def test_key_generation(self):
+ result = self._create_key('test')
+ private_key = result['private_key']
+ key = RSA.load_key_string(private_key, callback=lambda: None)
+ bio = BIO.MemoryBuffer()
+ public_key = db.key_pair_get(self.context,
+ self.context.user.id,
+ 'test')['public_key']
+ key.save_pub_key_bio(bio)
+ converted = crypto.ssl_pub_to_ssh_pub(bio.read())
+ # assert key fields are equal
+ self.assertEqual(public_key.split(" ")[1].strip(),
+ converted.split(" ")[1].strip())
+
+ def test_describe_key_pairs(self):
+ self._create_key('test1')
+ self._create_key('test2')
+ result = self.cloud.describe_key_pairs(self.context)
+ keys = result["keypairsSet"]
+ self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
+ self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
+
+ def test_delete_key_pair(self):
+ self._create_key('test')
+ self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self):
if FLAGS.connection_type == 'fake':
@@ -99,7 +148,7 @@ class CloudTestCase(test.BaseTestCase):
rv = yield defer.succeed(time.sleep(1))
info = self.cloud._get_instance(instance['instance_id'])
logging.debug(info['state'])
- if info['state'] == node.Instance.RUNNING:
+ if info['state'] == power_state.RUNNING:
break
self.assert_(rv)
@@ -160,3 +209,68 @@ class CloudTestCase(test.BaseTestCase):
#for i in xrange(4):
# data = self.cloud.get_metadata(instance(i)['private_dns_name'])
# self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i)
+
+ @staticmethod
+ def _fake_set_image_description(ctxt, image_id, description):
+ from nova.objectstore import handler
+ class req:
+ pass
+ request = req()
+ request.context = ctxt
+ request.args = {'image_id': [image_id],
+ 'description': [description]}
+
+ resource = handler.ImagesResource()
+ resource.render_POST(request)
+
+ def test_user_editable_image_endpoint(self):
+ pathdir = os.path.join(FLAGS.images_path, 'ami-testing')
+ os.mkdir(pathdir)
+ info = {'isPublic': False}
+ with open(os.path.join(pathdir, 'info.json'), 'w') as f:
+ json.dump(info, f)
+ img = image.Image('ami-testing')
+ # self.cloud.set_image_description(self.context, 'ami-testing',
+ # 'Foo Img')
+ # NOTE(vish): Above won't work unless we start objectstore or create
+ # a fake version of api/ec2/images.py conn that can
+ # call methods directly instead of going through boto.
+ # for now, just cheat and call the method directly
+ self._fake_set_image_description(self.context, 'ami-testing',
+ 'Foo Img')
+ self.assertEqual('Foo Img', img.metadata['description'])
+ self._fake_set_image_description(self.context, 'ami-testing', '')
+ self.assertEqual('', img.metadata['description'])
+
+ def test_update_of_instance_display_fields(self):
+ inst = db.instance_create({}, {})
+ ec2_id = cloud.internal_id_to_ec2_id(inst['internal_id'])
+ self.cloud.update_instance(self.context, ec2_id,
+ display_name='c00l 1m4g3')
+ inst = db.instance_get({}, inst['id'])
+ self.assertEqual('c00l 1m4g3', inst['display_name'])
+ db.instance_destroy({}, inst['id'])
+
+ def test_update_of_instance_wont_update_private_fields(self):
+ inst = db.instance_create({}, {})
+ self.cloud.update_instance(self.context, inst['id'],
+ mac_address='DE:AD:BE:EF')
+ inst = db.instance_get({}, inst['id'])
+ self.assertEqual(None, inst['mac_address'])
+ db.instance_destroy({}, inst['id'])
+
+ def test_update_of_volume_display_fields(self):
+ vol = db.volume_create({}, {})
+ self.cloud.update_volume(self.context, vol['id'],
+ display_name='c00l v0lum3')
+ vol = db.volume_get({}, vol['id'])
+ self.assertEqual('c00l v0lum3', vol['display_name'])
+ db.volume_destroy({}, vol['id'])
+
+ def test_update_of_volume_wont_update_private_fields(self):
+ vol = db.volume_create({}, {})
+ self.cloud.update_volume(self.context, vol['id'],
+ mountpoint='/not/here')
+ vol = db.volume_get({}, vol['id'])
+ self.assertEqual(None, vol['mountpoint'])
+ db.volume_destroy({}, vol['id'])
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index da0f82e3a..5a7f170f3 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -15,113 +15,119 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
+"""
+Tests For Compute
+"""
+import datetime
import logging
-import time
+
from twisted.internet import defer
-from xml.etree import ElementTree
+from nova import db
from nova import exception
from nova import flags
from nova import test
from nova import utils
-from nova.compute import model
-from nova.compute import service
-
+from nova.auth import manager
+from nova.api import context
FLAGS = flags.FLAGS
-class InstanceXmlTestCase(test.TrialTestCase):
- # @defer.inlineCallbacks
- def test_serialization(self):
- # TODO: Reimplement this, it doesn't make sense in redis-land
- return
-
- # instance_id = 'foo'
- # first_node = node.Node()
- # inst = yield first_node.run_instance(instance_id)
- #
- # # force the state so that we can verify that it changes
- # inst._s['state'] = node.Instance.NOSTATE
- # xml = inst.toXml()
- # self.assert_(ElementTree.parse(StringIO.StringIO(xml)))
- #
- # second_node = node.Node()
- # new_inst = node.Instance.fromXml(second_node._conn, pool=second_node._pool, xml=xml)
- # self.assertEqual(new_inst.state, node.Instance.RUNNING)
- # rv = yield first_node.terminate_instance(instance_id)
-
-
-class ComputeConnectionTestCase(test.TrialTestCase):
- def setUp(self):
+class ComputeTestCase(test.TrialTestCase):
+ """Test case for compute"""
+ def setUp(self): # pylint: disable-msg=C0103
logging.getLogger().setLevel(logging.DEBUG)
- super(ComputeConnectionTestCase, self).setUp()
+ super(ComputeTestCase, self).setUp()
self.flags(connection_type='fake',
- fake_storage=True)
- self.compute = service.ComputeService()
-
- def create_instance(self):
- instdir = model.InstanceDirectory()
- inst = instdir.new()
- # TODO(ja): add ami, ari, aki, user_data
+ network_manager='nova.network.manager.FlatManager')
+ self.compute = utils.import_object(FLAGS.compute_manager)
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake')
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.context = None
+
+ def tearDown(self): # pylint: disable-msg=C0103
+ self.manager.delete_user(self.user)
+ self.manager.delete_project(self.project)
+ super(ComputeTestCase, self).tearDown()
+
+ def _create_instance(self):
+ """Create a test instance"""
+ inst = {}
+ inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10'
- inst['user_id'] = 'fake'
- inst['project_id'] = 'fake'
+ inst['user_id'] = self.user.id
+ inst['project_id'] = self.project.id
inst['instance_type'] = 'm1.tiny'
- inst['node_name'] = FLAGS.node_name
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
- inst.save()
- return inst['instance_id']
+ return db.instance_create(self.context, inst)['id']
@defer.inlineCallbacks
- def test_run_describe_terminate(self):
- instance_id = self.create_instance()
+ def test_run_terminate(self):
+ """Make sure it is possible to run and terminate instance"""
+ instance_id = self._create_instance()
- rv = yield self.compute.run_instance(instance_id)
+ yield self.compute.run_instance(self.context, instance_id)
- rv = yield self.compute.describe_instances()
- logging.info("Running instances: %s", rv)
- self.assertEqual(rv[instance_id].name, instance_id)
+ instances = db.instance_get_all(None)
+ logging.info("Running instances: %s", instances)
+ self.assertEqual(len(instances), 1)
- rv = yield self.compute.terminate_instance(instance_id)
+ yield self.compute.terminate_instance(self.context, instance_id)
- rv = yield self.compute.describe_instances()
- logging.info("After terminating instances: %s", rv)
- self.assertEqual(rv, {})
+ instances = db.instance_get_all(None)
+ logging.info("After terminating instances: %s", instances)
+ self.assertEqual(len(instances), 0)
@defer.inlineCallbacks
- def test_reboot(self):
- instance_id = self.create_instance()
- rv = yield self.compute.run_instance(instance_id)
-
- rv = yield self.compute.describe_instances()
- self.assertEqual(rv[instance_id].name, instance_id)
+ def test_run_terminate_timestamps(self):
+ """Make sure timestamps are set for launched and destroyed"""
+ instance_id = self._create_instance()
+ instance_ref = db.instance_get(self.context, instance_id)
+ self.assertEqual(instance_ref['launched_at'], None)
+ self.assertEqual(instance_ref['deleted_at'], None)
+ launch = datetime.datetime.utcnow()
+ yield self.compute.run_instance(self.context, instance_id)
+ instance_ref = db.instance_get(self.context, instance_id)
+ self.assert_(instance_ref['launched_at'] > launch)
+ self.assertEqual(instance_ref['deleted_at'], None)
+ terminate = datetime.datetime.utcnow()
+ yield self.compute.terminate_instance(self.context, instance_id)
+ self.context = context.get_admin_context(user=self.user,
+ read_deleted=True)
+ instance_ref = db.instance_get(self.context, instance_id)
+ self.assert_(instance_ref['launched_at'] < terminate)
+ self.assert_(instance_ref['deleted_at'] > terminate)
- yield self.compute.reboot_instance(instance_id)
-
- rv = yield self.compute.describe_instances()
- self.assertEqual(rv[instance_id].name, instance_id)
- rv = yield self.compute.terminate_instance(instance_id)
+ @defer.inlineCallbacks
+ def test_reboot(self):
+ """Ensure instance can be rebooted"""
+ instance_id = self._create_instance()
+ yield self.compute.run_instance(self.context, instance_id)
+ yield self.compute.reboot_instance(self.context, instance_id)
+ yield self.compute.terminate_instance(self.context, instance_id)
@defer.inlineCallbacks
def test_console_output(self):
- instance_id = self.create_instance()
- rv = yield self.compute.run_instance(instance_id)
+ """Make sure we can get console output from instance"""
+ instance_id = self._create_instance()
+ yield self.compute.run_instance(self.context, instance_id)
- console = yield self.compute.get_console_output(instance_id)
+ console = yield self.compute.get_console_output(self.context,
+ instance_id)
self.assert_(console)
- rv = yield self.compute.terminate_instance(instance_id)
+ yield self.compute.terminate_instance(self.context, instance_id)
@defer.inlineCallbacks
def test_run_instance_existing(self):
- instance_id = self.create_instance()
- rv = yield self.compute.run_instance(instance_id)
-
- rv = yield self.compute.describe_instances()
- self.assertEqual(rv[instance_id].name, instance_id)
-
- self.assertRaises(exception.Error, self.compute.run_instance, instance_id)
- rv = yield self.compute.terminate_instance(instance_id)
+ """Ensure failure when running an instance that already exists"""
+ instance_id = self._create_instance()
+ yield self.compute.run_instance(self.context, instance_id)
+ self.assertFailure(self.compute.run_instance(self.context,
+ instance_id),
+ exception.Error)
+ yield self.compute.terminate_instance(self.context, instance_id)
diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py
index a7310fb26..4bbef8832 100644
--- a/nova/tests/fake_flags.py
+++ b/nova/tests/fake_flags.py
@@ -20,9 +20,20 @@ from nova import flags
FLAGS = flags.FLAGS
+flags.DECLARE('volume_driver', 'nova.volume.manager')
+FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver'
FLAGS.connection_type = 'fake'
-FLAGS.fake_storage = True
FLAGS.fake_rabbit = True
+FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
+flags.DECLARE('network_size', 'nova.network.manager')
+flags.DECLARE('num_networks', 'nova.network.manager')
+flags.DECLARE('fake_network', 'nova.network.manager')
+FLAGS.network_size = 16
+FLAGS.num_networks = 5
FLAGS.fake_network = True
-FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
+flags.DECLARE('num_shelves', 'nova.volume.manager')
+flags.DECLARE('blades_per_shelf', 'nova.volume.manager')
+FLAGS.num_shelves = 2
+FLAGS.blades_per_shelf = 4
FLAGS.verbose = True
+FLAGS.sql_connection = 'sqlite:///nova.sqlite'
diff --git a/nova/tests/model_unittest.py b/nova/tests/model_unittest.py
deleted file mode 100644
index dc2441c24..000000000
--- a/nova/tests/model_unittest.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime, timedelta
-import logging
-import time
-
-from nova import flags
-from nova import test
-from nova import utils
-from nova.compute import model
-
-
-FLAGS = flags.FLAGS
-
-
-class ModelTestCase(test.TrialTestCase):
- def setUp(self):
- super(ModelTestCase, self).setUp()
- self.flags(connection_type='fake',
- fake_storage=True)
-
- def tearDown(self):
- model.Instance('i-test').destroy()
- model.Host('testhost').destroy()
- model.Daemon('testhost', 'nova-testdaemon').destroy()
-
- def create_instance(self):
- inst = model.Instance('i-test')
- inst['reservation_id'] = 'r-test'
- inst['launch_time'] = '10'
- inst['user_id'] = 'fake'
- inst['project_id'] = 'fake'
- inst['instance_type'] = 'm1.tiny'
- inst['mac_address'] = utils.generate_mac()
- inst['ami_launch_index'] = 0
- inst['private_dns_name'] = '10.0.0.1'
- inst.save()
- return inst
-
- def create_host(self):
- host = model.Host('testhost')
- host.save()
- return host
-
- def create_daemon(self):
- daemon = model.Daemon('testhost', 'nova-testdaemon')
- daemon.save()
- return daemon
-
- def create_session_token(self):
- session_token = model.SessionToken('tk12341234')
- session_token['user'] = 'testuser'
- session_token.save()
- return session_token
-
- def test_create_instance(self):
- """store with create_instace, then test that a load finds it"""
- instance = self.create_instance()
- old = model.Instance(instance.identifier)
- self.assertFalse(old.is_new_record())
-
- def test_delete_instance(self):
- """create, then destroy, then make sure loads a new record"""
- instance = self.create_instance()
- instance.destroy()
- newinst = model.Instance('i-test')
- self.assertTrue(newinst.is_new_record())
-
- def test_instance_added_to_set(self):
- """create, then check that it is listed in global set"""
- instance = self.create_instance()
- found = False
- for x in model.InstanceDirectory().all:
- if x.identifier == 'i-test':
- found = True
- self.assert_(found)
-
- def test_instance_associates_project(self):
- """create, then check that it is listed for the project"""
- instance = self.create_instance()
- found = False
- for x in model.InstanceDirectory().by_project(instance.project):
- if x.identifier == 'i-test':
- found = True
- self.assert_(found)
-
- def test_instance_associates_ip(self):
- """create, then check that it is listed for the ip"""
- instance = self.create_instance()
- found = False
- x = model.InstanceDirectory().by_ip(instance['private_dns_name'])
- self.assertEqual(x.identifier, 'i-test')
-
- def test_instance_associates_node(self):
- """create, then check that it is listed for the node_name"""
- instance = self.create_instance()
- found = False
- for x in model.InstanceDirectory().by_node(FLAGS.node_name):
- if x.identifier == 'i-test':
- found = True
- self.assertFalse(found)
- instance['node_name'] = 'test_node'
- instance.save()
- for x in model.InstanceDirectory().by_node('test_node'):
- if x.identifier == 'i-test':
- found = True
- self.assert_(found)
-
-
- def test_host_class_finds_hosts(self):
- host = self.create_host()
- self.assertEqual('testhost', model.Host.lookup('testhost').identifier)
-
- def test_host_class_doesnt_find_missing_hosts(self):
- rv = model.Host.lookup('woahnelly')
- self.assertEqual(None, rv)
-
- def test_create_host(self):
- """store with create_host, then test that a load finds it"""
- host = self.create_host()
- old = model.Host(host.identifier)
- self.assertFalse(old.is_new_record())
-
- def test_delete_host(self):
- """create, then destroy, then make sure loads a new record"""
- instance = self.create_host()
- instance.destroy()
- newinst = model.Host('testhost')
- self.assertTrue(newinst.is_new_record())
-
- def test_host_added_to_set(self):
- """create, then check that it is included in list"""
- instance = self.create_host()
- found = False
- for x in model.Host.all():
- if x.identifier == 'testhost':
- found = True
- self.assert_(found)
-
- def test_create_daemon_two_args(self):
- """create a daemon with two arguments"""
- d = self.create_daemon()
- d = model.Daemon('testhost', 'nova-testdaemon')
- self.assertFalse(d.is_new_record())
-
- def test_create_daemon_single_arg(self):
- """Create a daemon using the combined host:bin format"""
- d = model.Daemon("testhost:nova-testdaemon")
- d.save()
- d = model.Daemon('testhost:nova-testdaemon')
- self.assertFalse(d.is_new_record())
-
- def test_equality_of_daemon_single_and_double_args(self):
- """Create a daemon using the combined host:bin arg, find with 2"""
- d = model.Daemon("testhost:nova-testdaemon")
- d.save()
- d = model.Daemon('testhost', 'nova-testdaemon')
- self.assertFalse(d.is_new_record())
-
- def test_equality_daemon_of_double_and_single_args(self):
- """Create a daemon using the combined host:bin arg, find with 2"""
- d = self.create_daemon()
- d = model.Daemon('testhost:nova-testdaemon')
- self.assertFalse(d.is_new_record())
-
- def test_delete_daemon(self):
- """create, then destroy, then make sure loads a new record"""
- instance = self.create_daemon()
- instance.destroy()
- newinst = model.Daemon('testhost', 'nova-testdaemon')
- self.assertTrue(newinst.is_new_record())
-
- def test_daemon_heartbeat(self):
- """Create a daemon, sleep, heartbeat, check for update"""
- d = self.create_daemon()
- ts = d['updated_at']
- time.sleep(2)
- d.heartbeat()
- d2 = model.Daemon('testhost', 'nova-testdaemon')
- ts2 = d2['updated_at']
- self.assert_(ts2 > ts)
-
- def test_daemon_added_to_set(self):
- """create, then check that it is included in list"""
- instance = self.create_daemon()
- found = False
- for x in model.Daemon.all():
- if x.identifier == 'testhost:nova-testdaemon':
- found = True
- self.assert_(found)
-
- def test_daemon_associates_host(self):
- """create, then check that it is listed for the host"""
- instance = self.create_daemon()
- found = False
- for x in model.Daemon.by_host('testhost'):
- if x.identifier == 'testhost:nova-testdaemon':
- found = True
- self.assertTrue(found)
-
- def test_create_session_token(self):
- """create"""
- d = self.create_session_token()
- d = model.SessionToken(d.token)
- self.assertFalse(d.is_new_record())
-
- def test_delete_session_token(self):
- """create, then destroy, then make sure loads a new record"""
- instance = self.create_session_token()
- instance.destroy()
- newinst = model.SessionToken(instance.token)
- self.assertTrue(newinst.is_new_record())
-
- def test_session_token_added_to_set(self):
- """create, then check that it is included in list"""
- instance = self.create_session_token()
- found = False
- for x in model.SessionToken.all():
- if x.identifier == instance.token:
- found = True
- self.assert_(found)
-
- def test_session_token_associates_user(self):
- """create, then check that it is listed for the user"""
- instance = self.create_session_token()
- found = False
- for x in model.SessionToken.associated_to('user', 'testuser'):
- if x.identifier == instance.identifier:
- found = True
- self.assertTrue(found)
-
- def test_session_token_generation(self):
- instance = model.SessionToken.generate('username', 'TokenType')
- self.assertFalse(instance.is_new_record())
-
- def test_find_generated_session_token(self):
- instance = model.SessionToken.generate('username', 'TokenType')
- found = model.SessionToken.lookup(instance.identifier)
- self.assert_(found)
-
- def test_update_session_token_expiry(self):
- instance = model.SessionToken('tk12341234')
- oldtime = datetime.utcnow()
- instance['expiry'] = oldtime.strftime(utils.TIME_FORMAT)
- instance.update_expiry()
- expiry = utils.parse_isotime(instance['expiry'])
- self.assert_(expiry > datetime.utcnow())
-
- def test_session_token_lookup_when_expired(self):
- instance = model.SessionToken.generate("testuser")
- instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
- instance.save()
- inst = model.SessionToken.lookup(instance.identifier)
- self.assertFalse(inst)
-
- def test_session_token_lookup_when_not_expired(self):
- instance = model.SessionToken.generate("testuser")
- inst = model.SessionToken.lookup(instance.identifier)
- self.assert_(inst)
-
- def test_session_token_is_expired_when_expired(self):
- instance = model.SessionToken.generate("testuser")
- instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
- self.assert_(instance.is_expired())
-
- def test_session_token_is_expired_when_not_expired(self):
- instance = model.SessionToken.generate("testuser")
- self.assertFalse(instance.is_expired())
-
- def test_session_token_ttl(self):
- instance = model.SessionToken.generate("testuser")
- now = datetime.utcnow()
- delta = timedelta(hours=1)
- instance['expiry'] = (now + delta).strftime(utils.TIME_FORMAT)
- # give 5 seconds of fuzziness
- self.assert_(abs(instance.ttl() - FLAGS.auth_token_ttl) < 5)
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index 34b68f1ed..3afb4d19e 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -22,14 +22,13 @@ import IPy
import os
import logging
+from nova import db
+from nova import exception
from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
-from nova.network import model
-from nova.network import service
-from nova.network import vpn
-from nova.network.exception import NoMoreAddresses
+from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -41,169 +40,200 @@ class NetworkTestCase(test.TrialTestCase):
# NOTE(vish): if you change these flags, make sure to change the
# flags in the corresponding section in nova-dhcpbridge
self.flags(connection_type='fake',
- fake_storage=True,
fake_network=True,
auth_driver='nova.auth.ldapdriver.FakeLdapDriver',
- network_size=32)
+ network_size=16,
+ num_networks=5)
logging.getLogger().setLevel(logging.DEBUG)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('netuser', 'netuser', 'netuser')
self.projects = []
- self.projects.append(self.manager.create_project('netuser',
- 'netuser',
- 'netuser'))
- for i in range(0, 6):
+ self.network = utils.import_object(FLAGS.network_manager)
+ self.context = context.APIRequestContext(project=None, user=self.user)
+ for i in range(5):
name = 'project%s' % i
- self.projects.append(self.manager.create_project(name,
- 'netuser',
- name))
- vpn.NetworkData.create(self.projects[i].id)
- self.service = service.VlanNetworkService()
+ project = self.manager.create_project(name, 'netuser', name)
+ self.projects.append(project)
+ # create the necessary network data for the project
+ user_context = context.APIRequestContext(project=self.projects[i],
+ user=self.user)
+ network_ref = self.network.get_network(user_context)
+ self.network.set_network_host(context.get_admin_context(),
+ network_ref['id'])
+ instance_ref = self._create_instance(0)
+ self.instance_id = instance_ref['id']
+ instance_ref = self._create_instance(1)
+ self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
super(NetworkTestCase, self).tearDown()
+ # TODO(termie): this should really be instantiating clean datastores
+ # in between runs, one failure kills all the tests
+ db.instance_destroy(None, self.instance_id)
+ db.instance_destroy(None, self.instance2_id)
for project in self.projects:
self.manager.delete_project(project)
self.manager.delete_user(self.user)
- def test_public_network_allocation(self):
+ def _create_instance(self, project_num, mac=None):
+ if not mac:
+ mac = utils.generate_mac()
+ project = self.projects[project_num]
+ self.context.project = project
+ return db.instance_create(self.context,
+ {'project_id': project.id,
+ 'mac_address': mac})
+
+ def _create_address(self, project_num, instance_id=None):
+ """Create an address in given project num"""
+ if instance_id is None:
+ instance_id = self.instance_id
+ self.context.project = self.projects[project_num]
+ return self.network.allocate_fixed_ip(self.context, instance_id)
+
+ def _deallocate_address(self, project_num, address):
+ self.context.project = self.projects[project_num]
+ self.network.deallocate_fixed_ip(self.context, address)
+
+
+ def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
- pubnet = IPy.IP(flags.FLAGS.public_range)
- address = self.service.allocate_elastic_ip(self.user.id,
- self.projects[0].id)
- self.assertTrue(IPy.IP(address) in pubnet)
+ # TODO(vish): better way of adding floating ips
+ self.context.project = self.projects[0]
+ pubnet = IPy.IP(flags.FLAGS.floating_range)
+ address = str(pubnet[0])
+ try:
+ db.floating_ip_get_by_address(None, address)
+ except exception.NotFound:
+ db.floating_ip_create(None, {'address': address,
+ 'host': FLAGS.host})
+ float_addr = self.network.allocate_floating_ip(self.context,
+ self.projects[0].id)
+ fix_addr = self._create_address(0)
+ lease_ip(fix_addr)
+ self.assertEqual(float_addr, str(pubnet[0]))
+ self.network.associate_floating_ip(self.context, float_addr, fix_addr)
+ address = db.instance_get_floating_address(None, self.instance_id)
+ self.assertEqual(address, float_addr)
+ self.network.disassociate_floating_ip(self.context, float_addr)
+ address = db.instance_get_floating_address(None, self.instance_id)
+ self.assertEqual(address, None)
+ self.network.deallocate_floating_ip(self.context, float_addr)
+ self.network.deallocate_fixed_ip(self.context, fix_addr)
+ release_ip(fix_addr)
def test_allocate_deallocate_fixed_ip(self):
"""Makes sure that we can allocate and deallocate a fixed ip"""
- result = self.service.allocate_fixed_ip(
- self.user.id, self.projects[0].id)
- address = result['private_dns_name']
- mac = result['mac_address']
- net = model.get_project_network(self.projects[0].id, "default")
- self.assertEqual(True, is_in_project(address, self.projects[0].id))
- hostname = "test-host"
- issue_ip(mac, address, hostname, net.bridge_name)
- self.service.deallocate_fixed_ip(address)
+ address = self._create_address(0)
+ self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
+ lease_ip(address)
+ self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
- self.assertEqual(True, is_in_project(address, self.projects[0].id))
+ self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
- release_ip(mac, address, hostname, net.bridge_name)
- self.assertEqual(False, is_in_project(address, self.projects[0].id))
+ release_ip(address)
+ self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
def test_side_effects(self):
"""Ensures allocating and releasing has no side effects"""
- hostname = "side-effect-host"
- result = self.service.allocate_fixed_ip(self.user.id,
- self.projects[0].id)
- mac = result['mac_address']
- address = result['private_dns_name']
- result = self.service.allocate_fixed_ip(self.user,
- self.projects[1].id)
- secondmac = result['mac_address']
- secondaddress = result['private_dns_name']
-
- net = model.get_project_network(self.projects[0].id, "default")
- secondnet = model.get_project_network(self.projects[1].id, "default")
-
- self.assertEqual(True, is_in_project(address, self.projects[0].id))
- self.assertEqual(True, is_in_project(secondaddress,
- self.projects[1].id))
- self.assertEqual(False, is_in_project(address, self.projects[1].id))
+ address = self._create_address(0)
+ address2 = self._create_address(1, self.instance2_id)
+
+ self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
+ self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
+ self.assertFalse(is_allocated_in_project(address, self.projects[1].id))
# Addresses are allocated before they're issued
- issue_ip(mac, address, hostname, net.bridge_name)
- issue_ip(secondmac, secondaddress, hostname, secondnet.bridge_name)
+ lease_ip(address)
+ lease_ip(address2)
- self.service.deallocate_fixed_ip(address)
- release_ip(mac, address, hostname, net.bridge_name)
- self.assertEqual(False, is_in_project(address, self.projects[0].id))
+ self._deallocate_address(0, address)
+ release_ip(address)
+ self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
- self.assertEqual(True, is_in_project(secondaddress,
- self.projects[1].id))
+ self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
- self.service.deallocate_fixed_ip(secondaddress)
- release_ip(secondmac, secondaddress, hostname, secondnet.bridge_name)
- self.assertEqual(False, is_in_project(secondaddress,
- self.projects[1].id))
+ self._deallocate_address(1, address2)
+ release_ip(address2)
+ self.assertFalse(is_allocated_in_project(address2,
+ self.projects[1].id))
def test_subnet_edge(self):
"""Makes sure that private ips don't overlap"""
- result = self.service.allocate_fixed_ip(self.user.id,
- self.projects[0].id)
- firstaddress = result['private_dns_name']
- hostname = "toomany-hosts"
+ first = self._create_address(0)
+ lease_ip(first)
+ instance_ids = []
for i in range(1, 5):
- project_id = self.projects[i].id
- result = self.service.allocate_fixed_ip(
- self.user, project_id)
- mac = result['mac_address']
- address = result['private_dns_name']
- result = self.service.allocate_fixed_ip(
- self.user, project_id)
- mac2 = result['mac_address']
- address2 = result['private_dns_name']
- result = self.service.allocate_fixed_ip(
- self.user, project_id)
- mac3 = result['mac_address']
- address3 = result['private_dns_name']
- net = model.get_project_network(project_id, "default")
- issue_ip(mac, address, hostname, net.bridge_name)
- issue_ip(mac2, address2, hostname, net.bridge_name)
- issue_ip(mac3, address3, hostname, net.bridge_name)
- self.assertEqual(False, is_in_project(address,
- self.projects[0].id))
- self.assertEqual(False, is_in_project(address2,
- self.projects[0].id))
- self.assertEqual(False, is_in_project(address3,
- self.projects[0].id))
- self.service.deallocate_fixed_ip(address)
- self.service.deallocate_fixed_ip(address2)
- self.service.deallocate_fixed_ip(address3)
- release_ip(mac, address, hostname, net.bridge_name)
- release_ip(mac2, address2, hostname, net.bridge_name)
- release_ip(mac3, address3, hostname, net.bridge_name)
- net = model.get_project_network(self.projects[0].id, "default")
- self.service.deallocate_fixed_ip(firstaddress)
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
+ instance_ids.append(instance_ref['id'])
+ address = self._create_address(i, instance_ref['id'])
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
+ instance_ids.append(instance_ref['id'])
+ address2 = self._create_address(i, instance_ref['id'])
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
+ instance_ids.append(instance_ref['id'])
+ address3 = self._create_address(i, instance_ref['id'])
+ lease_ip(address)
+ lease_ip(address2)
+ lease_ip(address3)
+ self.context.project = self.projects[i]
+ self.assertFalse(is_allocated_in_project(address,
+ self.projects[0].id))
+ self.assertFalse(is_allocated_in_project(address2,
+ self.projects[0].id))
+ self.assertFalse(is_allocated_in_project(address3,
+ self.projects[0].id))
+ self.network.deallocate_fixed_ip(self.context, address)
+ self.network.deallocate_fixed_ip(self.context, address2)
+ self.network.deallocate_fixed_ip(self.context, address3)
+ release_ip(address)
+ release_ip(address2)
+ release_ip(address3)
+ for instance_id in instance_ids:
+ db.instance_destroy(None, instance_id)
+ self.context.project = self.projects[0]
+ self.network.deallocate_fixed_ip(self.context, first)
+ self._deallocate_address(0, first)
+ release_ip(first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
self.assert_(self.projects[0].vpn_ip)
- self.assert_(self.projects[0].vpn_port >= FLAGS.vpn_start_port)
- self.assert_(self.projects[0].vpn_port <= FLAGS.vpn_end_port)
-
- def test_too_many_vpns(self):
- """Ensure error is raised if we run out of vpn ports"""
- vpns = []
- for i in xrange(vpn.NetworkData.num_ports_for_ip(FLAGS.vpn_ip)):
- vpns.append(vpn.NetworkData.create("vpnuser%s" % i))
- self.assertRaises(vpn.NoMorePorts, vpn.NetworkData.create, "boom")
- for network_datum in vpns:
- network_datum.destroy()
+ self.assert_(self.projects[0].vpn_port >= FLAGS.vpn_start)
+ self.assert_(self.projects[0].vpn_port <= FLAGS.vpn_start +
+ FLAGS.num_networks)
+
+ def test_too_many_networks(self):
+ """Ensure error is raised if we run out of networks"""
+ projects = []
+ networks_left = FLAGS.num_networks - db.network_count(None)
+ for i in range(networks_left):
+ project = self.manager.create_project('many%s' % i, self.user)
+ projects.append(project)
+ db.project_get_network(None, project.id)
+ project = self.manager.create_project('last', self.user)
+ projects.append(project)
+ self.assertRaises(db.NoMoreNetworks,
+ db.project_get_network,
+ None,
+ project.id)
+ for project in projects:
+ self.manager.delete_project(project)
def test_ips_are_reused(self):
"""Makes sure that ip addresses that are deallocated get reused"""
- result = self.service.allocate_fixed_ip(
- self.user.id, self.projects[0].id)
- mac = result['mac_address']
- address = result['private_dns_name']
-
- hostname = "reuse-host"
- net = model.get_project_network(self.projects[0].id, "default")
-
- issue_ip(mac, address, hostname, net.bridge_name)
- self.service.deallocate_fixed_ip(address)
- release_ip(mac, address, hostname, net.bridge_name)
-
- result = self.service.allocate_fixed_ip(
- self.user, self.projects[0].id)
- secondmac = result['mac_address']
- secondaddress = result['private_dns_name']
- self.assertEqual(address, secondaddress)
- issue_ip(secondmac, secondaddress, hostname, net.bridge_name)
- self.service.deallocate_fixed_ip(secondaddress)
- release_ip(secondmac, secondaddress, hostname, net.bridge_name)
+ address = self._create_address(0)
+ lease_ip(address)
+ self.network.deallocate_fixed_ip(self.context, address)
+ release_ip(address)
+
+ address2 = self._create_address(0)
+ self.assertEqual(address, address2)
+ lease_ip(address)
+ self.network.deallocate_fixed_ip(self.context, address2)
+ release_ip(address)
def test_available_ips(self):
"""Make sure the number of available ips for the network is correct
@@ -216,44 +246,51 @@ class NetworkTestCase(test.TrialTestCase):
There are ips reserved at the bottom and top of the range.
services (network, gateway, CloudPipe, broadcast)
"""
- net = model.get_project_network(self.projects[0].id, "default")
- num_preallocated_ips = len(net.assigned)
+ network = db.project_get_network(None, self.projects[0].id)
net_size = flags.FLAGS.network_size
- num_available_ips = net_size - (net.num_bottom_reserved_ips +
- num_preallocated_ips +
- net.num_top_reserved_ips)
- self.assertEqual(num_available_ips, len(list(net.available)))
+ total_ips = (db.network_count_available_ips(None, network['id']) +
+ db.network_count_reserved_ips(None, network['id']) +
+ db.network_count_allocated_ips(None, network['id']))
+ self.assertEqual(total_ips, net_size)
def test_too_many_addresses(self):
"""Test for a NoMoreAddresses exception when all fixed ips are used.
"""
- net = model.get_project_network(self.projects[0].id, "default")
-
- hostname = "toomany-hosts"
- macs = {}
- addresses = {}
- # Number of availaible ips is len of the available list
- num_available_ips = len(list(net.available))
+ network = db.project_get_network(None, self.projects[0].id)
+ num_available_ips = db.network_count_available_ips(None,
+ network['id'])
+ addresses = []
+ instance_ids = []
for i in range(num_available_ips):
- result = self.service.allocate_fixed_ip(self.user.id,
- self.projects[0].id)
- macs[i] = result['mac_address']
- addresses[i] = result['private_dns_name']
- issue_ip(macs[i], addresses[i], hostname, net.bridge_name)
-
- self.assertEqual(len(list(net.available)), 0)
- self.assertRaises(NoMoreAddresses, self.service.allocate_fixed_ip,
- self.user.id, self.projects[0].id)
+ instance_ref = self._create_instance(0)
+ instance_ids.append(instance_ref['id'])
+ address = self._create_address(0, instance_ref['id'])
+ addresses.append(address)
+ lease_ip(address)
+
+ self.assertEqual(db.network_count_available_ips(None,
+ network['id']), 0)
+ self.assertRaises(db.NoMoreAddresses,
+ self.network.allocate_fixed_ip,
+ self.context,
+ 'foo')
- for i in range(len(addresses)):
- self.service.deallocate_fixed_ip(addresses[i])
- release_ip(macs[i], addresses[i], hostname, net.bridge_name)
- self.assertEqual(len(list(net.available)), num_available_ips)
+ for i in range(num_available_ips):
+ self.network.deallocate_fixed_ip(self.context, addresses[i])
+ release_ip(addresses[i])
+ db.instance_destroy(None, instance_ids[i])
+ self.assertEqual(db.network_count_available_ips(None,
+ network['id']),
+ num_available_ips)
-def is_in_project(address, project_id):
+def is_allocated_in_project(address, project_id):
"""Returns true if address is in specified project"""
- return address in model.get_project_network(project_id).assigned
+ project_net = db.project_get_network(None, project_id)
+ network = db.fixed_ip_get_network(None, address)
+ instance = db.fixed_ip_get_instance(None, address)
+ # instance exists until release
+ return instance is not None and network['id'] == project_net['id']
def binpath(script):
@@ -261,22 +298,28 @@ def binpath(script):
return os.path.abspath(os.path.join(__file__, "../../../bin", script))
-def issue_ip(mac, private_ip, hostname, interface):
+def lease_ip(private_ip):
"""Run add command on dhcpbridge"""
- cmd = "%s add %s %s %s" % (binpath('nova-dhcpbridge'),
- mac, private_ip, hostname)
- env = {'DNSMASQ_INTERFACE': interface,
+ network_ref = db.fixed_ip_get_network(None, private_ip)
+ instance_ref = db.fixed_ip_get_instance(None, private_ip)
+ cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'),
+ instance_ref['mac_address'],
+ private_ip)
+ env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}
(out, err) = utils.execute(cmd, addl_env=env)
logging.debug("ISSUE_IP: %s, %s ", out, err)
-def release_ip(mac, private_ip, hostname, interface):
+def release_ip(private_ip):
"""Run del command on dhcpbridge"""
- cmd = "%s del %s %s %s" % (binpath('nova-dhcpbridge'),
- mac, private_ip, hostname)
- env = {'DNSMASQ_INTERFACE': interface,
+ network_ref = db.fixed_ip_get_network(None, private_ip)
+ instance_ref = db.fixed_ip_get_instance(None, private_ip)
+ cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'),
+ instance_ref['mac_address'],
+ private_ip)
+ env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}
(out, err) = utils.execute(cmd, addl_env=env)
diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py
index dece4b5d5..872f1ab23 100644
--- a/nova/tests/objectstore_unittest.py
+++ b/nova/tests/objectstore_unittest.py
@@ -53,7 +53,7 @@ os.makedirs(os.path.join(OSS_TEMPDIR, 'images'))
os.makedirs(os.path.join(OSS_TEMPDIR, 'buckets'))
-class ObjectStoreTestCase(test.BaseTestCase):
+class ObjectStoreTestCase(test.TrialTestCase):
"""Test objectstore API directly."""
def setUp(self): # pylint: disable-msg=C0103
@@ -133,13 +133,22 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.assertRaises(NotFound, objectstore.bucket.Bucket, 'new_bucket')
def test_images(self):
+ self.do_test_images('1mb.manifest.xml', True,
+ 'image_bucket1', 'i-testing1')
+
+ def test_images_no_kernel_or_ramdisk(self):
+ self.do_test_images('1mb.no_kernel_or_ramdisk.manifest.xml',
+ False, 'image_bucket2', 'i-testing2')
+
+ def do_test_images(self, manifest_file, expect_kernel_and_ramdisk,
+ image_bucket, image_name):
"Test the image API."
self.context.user = self.auth_manager.get_user('user1')
self.context.project = self.auth_manager.get_project('proj1')
# create a bucket for our bundle
- objectstore.bucket.Bucket.create('image_bucket', self.context)
- bucket = objectstore.bucket.Bucket('image_bucket')
+ objectstore.bucket.Bucket.create(image_bucket, self.context)
+ bucket = objectstore.bucket.Bucket(image_bucket)
# upload an image manifest/parts
bundle_path = os.path.join(os.path.dirname(__file__), 'bundle')
@@ -147,23 +156,39 @@ class ObjectStoreTestCase(test.BaseTestCase):
bucket[os.path.basename(path)] = open(path, 'rb').read()
# register an image
- image.Image.register_aws_image('i-testing',
- 'image_bucket/1mb.manifest.xml',
+ image.Image.register_aws_image(image_name,
+ '%s/%s' % (image_bucket, manifest_file),
self.context)
# verify image
- my_img = image.Image('i-testing')
+ my_img = image.Image(image_name)
result_image_file = os.path.join(my_img.path, 'image')
self.assertEqual(os.stat(result_image_file).st_size, 1048576)
sha = hashlib.sha1(open(result_image_file).read()).hexdigest()
self.assertEqual(sha, '3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3')
+ if expect_kernel_and_ramdisk:
+ # Verify the default kernel and ramdisk are set
+ self.assertEqual(my_img.metadata['kernelId'], 'aki-test')
+ self.assertEqual(my_img.metadata['ramdiskId'], 'ari-test')
+ else:
+ # Verify that the default kernel and ramdisk (the one from FLAGS)
+ # doesn't get embedded in the metadata
+ self.assertFalse('kernelId' in my_img.metadata)
+ self.assertFalse('ramdiskId' in my_img.metadata)
+
# verify image permissions
self.context.user = self.auth_manager.get_user('user2')
self.context.project = self.auth_manager.get_project('proj2')
self.assertFalse(my_img.is_authorized(self.context))
+ # change user-editable fields
+ my_img.update_user_editable_fields({'display_name': 'my cool image'})
+ self.assertEqual('my cool image', my_img.metadata['displayName'])
+ my_img.update_user_editable_fields({'display_name': ''})
+ self.assert_(not my_img.metadata['displayName'])
+
class TestHTTPChannel(http.HTTPChannel):
"""Dummy site required for twisted.web"""
@@ -185,7 +210,7 @@ class S3APITestCase(test.TrialTestCase):
"""Setup users, projects, and start a test server."""
super(S3APITestCase, self).setUp()
- FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver',
+ FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets')
self.auth_manager = manager.AuthManager()
diff --git a/nova/tests/quota_unittest.py b/nova/tests/quota_unittest.py
new file mode 100644
index 000000000..370ccd506
--- /dev/null
+++ b/nova/tests/quota_unittest.py
@@ -0,0 +1,152 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import logging
+
+from nova import db
+from nova import exception
+from nova import flags
+from nova import quota
+from nova import test
+from nova import utils
+from nova.auth import manager
+from nova.api.ec2 import cloud
+from nova.api.ec2 import context
+
+
+FLAGS = flags.FLAGS
+
+
+class QuotaTestCase(test.TrialTestCase):
+ def setUp(self): # pylint: disable-msg=C0103
+ logging.getLogger().setLevel(logging.DEBUG)
+ super(QuotaTestCase, self).setUp()
+ self.flags(connection_type='fake',
+ quota_instances=2,
+ quota_cores=4,
+ quota_volumes=2,
+ quota_gigabytes=20,
+ quota_floating_ips=1)
+
+ self.cloud = cloud.CloudController()
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('admin', 'admin', 'admin', True)
+ self.project = self.manager.create_project('admin', 'admin', 'admin')
+ self.network = utils.import_object(FLAGS.network_manager)
+ self.context = context.APIRequestContext(project=self.project,
+ user=self.user)
+
+ def tearDown(self): # pylint: disable-msg=C0103
+ manager.AuthManager().delete_project(self.project)
+ manager.AuthManager().delete_user(self.user)
+ super(QuotaTestCase, self).tearDown()
+
+ def _create_instance(self, cores=2):
+ """Create a test instance"""
+ inst = {}
+ inst['image_id'] = 'ami-test'
+ inst['reservation_id'] = 'r-fakeres'
+ inst['user_id'] = self.user.id
+ inst['project_id'] = self.project.id
+ inst['instance_type'] = 'm1.large'
+ inst['vcpus'] = cores
+ inst['mac_address'] = utils.generate_mac()
+ return db.instance_create(self.context, inst)['id']
+
+ def _create_volume(self, size=10):
+ """Create a test volume"""
+ vol = {}
+ vol['user_id'] = self.user.id
+ vol['project_id'] = self.project.id
+ vol['size'] = size
+ return db.volume_create(self.context, vol)['id']
+
+ def test_quota_overrides(self):
+ """Make sure overriding a projects quotas works"""
+ num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
+ self.assertEqual(num_instances, 2)
+ db.quota_create(self.context, {'project_id': self.project.id,
+ 'instances': 10})
+ num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
+ self.assertEqual(num_instances, 4)
+ db.quota_update(self.context, self.project.id, {'cores': 100})
+ num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
+ self.assertEqual(num_instances, 10)
+ db.quota_destroy(self.context, self.project.id)
+
+ def test_too_many_instances(self):
+ instance_ids = []
+ for i in range(FLAGS.quota_instances):
+ instance_id = self._create_instance()
+ instance_ids.append(instance_id)
+ self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
+ self.context,
+ min_count=1,
+ max_count=1,
+ instance_type='m1.small')
+ for instance_id in instance_ids:
+ db.instance_destroy(self.context, instance_id)
+
+ def test_too_many_cores(self):
+ instance_ids = []
+ instance_id = self._create_instance(cores=4)
+ instance_ids.append(instance_id)
+ self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
+ self.context,
+ min_count=1,
+ max_count=1,
+ instance_type='m1.small')
+ for instance_id in instance_ids:
+ db.instance_destroy(self.context, instance_id)
+
+ def test_too_many_volumes(self):
+ volume_ids = []
+ for i in range(FLAGS.quota_volumes):
+ volume_id = self._create_volume()
+ volume_ids.append(volume_id)
+ self.assertRaises(cloud.QuotaError, self.cloud.create_volume,
+ self.context,
+ size=10)
+ for volume_id in volume_ids:
+ db.volume_destroy(self.context, volume_id)
+
+ def test_too_many_gigabytes(self):
+ volume_ids = []
+ volume_id = self._create_volume(size=20)
+ volume_ids.append(volume_id)
+ self.assertRaises(cloud.QuotaError,
+ self.cloud.create_volume,
+ self.context,
+ size=10)
+ for volume_id in volume_ids:
+ db.volume_destroy(self.context, volume_id)
+
+ def test_too_many_addresses(self):
+ address = '192.168.0.100'
+ try:
+ db.floating_ip_get_by_address(None, address)
+ except exception.NotFound:
+ db.floating_ip_create(None, {'address': address,
+ 'host': FLAGS.host})
+ float_addr = self.network.allocate_floating_ip(self.context,
+ self.project.id)
+ # NOTE(vish): This assert never fails. When cloud attempts to
+ # make an rpc.call, the test just finishes with OK. It
+ # appears to be something in the magic inline callbacks
+ # that is breaking.
+ self.assertRaises(cloud.QuotaError, self.cloud.allocate_address, self.context)
diff --git a/nova/tests/real_flags.py b/nova/tests/real_flags.py
index 121f4eb41..71da04992 100644
--- a/nova/tests/real_flags.py
+++ b/nova/tests/real_flags.py
@@ -21,7 +21,6 @@ from nova import flags
FLAGS = flags.FLAGS
FLAGS.connection_type = 'libvirt'
-FLAGS.fake_storage = False
FLAGS.fake_rabbit = False
FLAGS.fake_network = False
FLAGS.verbose = False
diff --git a/nova/tests/rpc_unittest.py b/nova/tests/rpc_unittest.py
index e12a28fbc..9652841f2 100644
--- a/nova/tests/rpc_unittest.py
+++ b/nova/tests/rpc_unittest.py
@@ -30,7 +30,7 @@ from nova import test
FLAGS = flags.FLAGS
-class RpcTestCase(test.BaseTestCase):
+class RpcTestCase(test.TrialTestCase):
"""Test cases for rpc"""
def setUp(self): # pylint: disable-msg=C0103
super(RpcTestCase, self).setUp()
@@ -39,14 +39,13 @@ class RpcTestCase(test.BaseTestCase):
self.consumer = rpc.AdapterConsumer(connection=self.conn,
topic='test',
proxy=self.receiver)
-
- self.injected.append(self.consumer.attach_to_tornado(self.ioloop))
+ self.consumer.attach_to_twisted()
def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42
- result = yield rpc.call('test', {"method": "echo",
- "args": {"value": value}})
+ result = yield rpc.call_twisted('test', {"method": "echo",
+ "args": {"value": value}})
self.assertEqual(value, result)
def test_call_exception(self):
@@ -57,12 +56,12 @@ class RpcTestCase(test.BaseTestCase):
to an int in the test.
"""
value = 42
- self.assertFailure(rpc.call('test', {"method": "fail",
- "args": {"value": value}}),
+ self.assertFailure(rpc.call_twisted('test', {"method": "fail",
+ "args": {"value": value}}),
rpc.RemoteError)
try:
- yield rpc.call('test', {"method": "fail",
- "args": {"value": value}})
+ yield rpc.call_twisted('test', {"method": "fail",
+ "args": {"value": value}})
self.fail("should have thrown rpc.RemoteError")
except rpc.RemoteError as exc:
self.assertEqual(int(exc.value), value)
diff --git a/nova/tests/scheduler_unittest.py b/nova/tests/scheduler_unittest.py
new file mode 100644
index 000000000..80100fc2f
--- /dev/null
+++ b/nova/tests/scheduler_unittest.py
@@ -0,0 +1,242 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+Tests For Scheduler
+"""
+
+from nova import db
+from nova import flags
+from nova import service
+from nova import test
+from nova import rpc
+from nova import utils
+from nova.auth import manager as auth_manager
+from nova.scheduler import manager
+from nova.scheduler import driver
+
+
+FLAGS = flags.FLAGS
+flags.DECLARE('max_cores', 'nova.scheduler.simple')
+
+class TestDriver(driver.Scheduler):
+ """Scheduler Driver for Tests"""
+ def schedule(context, topic, *args, **kwargs):
+ return 'fallback_host'
+
+ def schedule_named_method(context, topic, num):
+ return 'named_host'
+
+class SchedulerTestCase(test.TrialTestCase):
+ """Test case for scheduler"""
+ def setUp(self): # pylint: disable=C0103
+ super(SchedulerTestCase, self).setUp()
+ self.flags(scheduler_driver='nova.tests.scheduler_unittest.TestDriver')
+
+ def test_fallback(self):
+ scheduler = manager.SchedulerManager()
+ self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
+ rpc.cast('topic.fallback_host',
+ {'method': 'noexist',
+ 'args': {'context': None,
+ 'num': 7}})
+ self.mox.ReplayAll()
+ scheduler.noexist(None, 'topic', num=7)
+
+ def test_named_method(self):
+ scheduler = manager.SchedulerManager()
+ self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
+ rpc.cast('topic.named_host',
+ {'method': 'named_method',
+ 'args': {'context': None,
+ 'num': 7}})
+ self.mox.ReplayAll()
+ scheduler.named_method(None, 'topic', num=7)
+
+
+class SimpleDriverTestCase(test.TrialTestCase):
+ """Test case for simple driver"""
+ def setUp(self): # pylint: disable-msg=C0103
+ super(SimpleDriverTestCase, self).setUp()
+ self.flags(connection_type='fake',
+ max_cores=4,
+ max_gigabytes=4,
+ network_manager='nova.network.manager.FlatManager',
+ volume_driver='nova.volume.driver.FakeAOEDriver',
+ scheduler_driver='nova.scheduler.simple.SimpleScheduler')
+ self.scheduler = manager.SchedulerManager()
+ self.context = None
+ self.manager = auth_manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake')
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.context = None
+
+ def tearDown(self): # pylint: disable-msg=C0103
+ self.manager.delete_user(self.user)
+ self.manager.delete_project(self.project)
+
+ def _create_instance(self):
+ """Create a test instance"""
+ inst = {}
+ inst['image_id'] = 'ami-test'
+ inst['reservation_id'] = 'r-fakeres'
+ inst['user_id'] = self.user.id
+ inst['project_id'] = self.project.id
+ inst['instance_type'] = 'm1.tiny'
+ inst['mac_address'] = utils.generate_mac()
+ inst['ami_launch_index'] = 0
+ inst['vcpus'] = 1
+ return db.instance_create(self.context, inst)['id']
+
+ def _create_volume(self):
+ """Create a test volume"""
+ vol = {}
+ vol['image_id'] = 'ami-test'
+ vol['reservation_id'] = 'r-fakeres'
+ vol['size'] = 1
+ return db.volume_create(self.context, vol)['id']
+
+ def test_hosts_are_up(self):
+ """Ensures driver can find the hosts that are up"""
+ # NOTE(vish): constructing service without create method
+ # because we are going to use it without queue
+ compute1 = service.Service('host1',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute1.startService()
+ compute2 = service.Service('host2',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute2.startService()
+ hosts = self.scheduler.driver.hosts_up(self.context, 'compute')
+ self.assertEqual(len(hosts), 2)
+ compute1.kill()
+ compute2.kill()
+
+ def test_least_busy_host_gets_instance(self):
+ """Ensures the host with less cores gets the next one"""
+ compute1 = service.Service('host1',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute1.startService()
+ compute2 = service.Service('host2',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute2.startService()
+ instance_id1 = self._create_instance()
+ compute1.run_instance(self.context, instance_id1)
+ instance_id2 = self._create_instance()
+ host = self.scheduler.driver.schedule_run_instance(self.context,
+ instance_id2)
+ self.assertEqual(host, 'host2')
+ compute1.terminate_instance(self.context, instance_id1)
+ db.instance_destroy(self.context, instance_id2)
+ compute1.kill()
+ compute2.kill()
+
+ def test_too_many_cores(self):
+ """Ensures we don't go over max cores"""
+ compute1 = service.Service('host1',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute1.startService()
+ compute2 = service.Service('host2',
+ 'nova-compute',
+ 'compute',
+ FLAGS.compute_manager)
+ compute2.startService()
+ instance_ids1 = []
+ instance_ids2 = []
+ for index in xrange(FLAGS.max_cores):
+ instance_id = self._create_instance()
+ compute1.run_instance(self.context, instance_id)
+ instance_ids1.append(instance_id)
+ instance_id = self._create_instance()
+ compute2.run_instance(self.context, instance_id)
+ instance_ids2.append(instance_id)
+ instance_id = self._create_instance()
+ self.assertRaises(driver.NoValidHost,
+ self.scheduler.driver.schedule_run_instance,
+ self.context,
+ instance_id)
+ for instance_id in instance_ids1:
+ compute1.terminate_instance(self.context, instance_id)
+ for instance_id in instance_ids2:
+ compute2.terminate_instance(self.context, instance_id)
+ compute1.kill()
+ compute2.kill()
+
+ def test_least_busy_host_gets_volume(self):
+ """Ensures the host with less gigabytes gets the next one"""
+ volume1 = service.Service('host1',
+ 'nova-volume',
+ 'volume',
+ FLAGS.volume_manager)
+ volume1.startService()
+ volume2 = service.Service('host2',
+ 'nova-volume',
+ 'volume',
+ FLAGS.volume_manager)
+ volume2.startService()
+ volume_id1 = self._create_volume()
+ volume1.create_volume(self.context, volume_id1)
+ volume_id2 = self._create_volume()
+ host = self.scheduler.driver.schedule_create_volume(self.context,
+ volume_id2)
+ self.assertEqual(host, 'host2')
+ volume1.delete_volume(self.context, volume_id1)
+ db.volume_destroy(self.context, volume_id2)
+ volume1.kill()
+ volume2.kill()
+
+ def test_too_many_gigabytes(self):
+ """Ensures we don't go over max gigabytes"""
+ volume1 = service.Service('host1',
+ 'nova-volume',
+ 'volume',
+ FLAGS.volume_manager)
+ volume1.startService()
+ volume2 = service.Service('host2',
+ 'nova-volume',
+ 'volume',
+ FLAGS.volume_manager)
+ volume2.startService()
+ volume_ids1 = []
+ volume_ids2 = []
+ for index in xrange(FLAGS.max_gigabytes):
+ volume_id = self._create_volume()
+ volume1.create_volume(self.context, volume_id)
+ volume_ids1.append(volume_id)
+ volume_id = self._create_volume()
+ volume2.create_volume(self.context, volume_id)
+ volume_ids2.append(volume_id)
+ volume_id = self._create_volume()
+ self.assertRaises(driver.NoValidHost,
+ self.scheduler.driver.schedule_create_volume,
+ self.context,
+ volume_id)
+ for volume_id in volume_ids1:
+ volume1.delete_volume(self.context, volume_id)
+ for volume_id in volume_ids2:
+ volume2.delete_volume(self.context, volume_id)
+ volume1.kill()
+ volume2.kill()
diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py
new file mode 100644
index 000000000..6afeec377
--- /dev/null
+++ b/nova/tests/service_unittest.py
@@ -0,0 +1,190 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit Tests for remote procedure calls using queue
+"""
+
+import mox
+
+from twisted.application.app import startApplication
+
+from nova import exception
+from nova import flags
+from nova import rpc
+from nova import test
+from nova import service
+from nova import manager
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("fake_manager", "nova.tests.service_unittest.FakeManager",
+ "Manager for testing")
+
+
+class FakeManager(manager.Manager):
+ """Fake manager for tests"""
+ pass
+
+
+class ServiceTestCase(test.BaseTestCase):
+ """Test cases for rpc"""
+
+ def setUp(self): # pylint: disable=C0103
+ super(ServiceTestCase, self).setUp()
+ self.mox.StubOutWithMock(service, 'db')
+
+ def test_create(self):
+ host = 'foo'
+ binary = 'nova-fake'
+ topic = 'fake'
+ self.mox.StubOutWithMock(rpc,
+ 'AdapterConsumer',
+ use_mock_anything=True)
+ self.mox.StubOutWithMock(
+ service.task, 'LoopingCall', use_mock_anything=True)
+ rpc.AdapterConsumer(connection=mox.IgnoreArg(),
+ topic=topic,
+ proxy=mox.IsA(service.Service)).AndReturn(
+ rpc.AdapterConsumer)
+
+ rpc.AdapterConsumer(connection=mox.IgnoreArg(),
+ topic='%s.%s' % (topic, host),
+ proxy=mox.IsA(service.Service)).AndReturn(
+ rpc.AdapterConsumer)
+
+ rpc.AdapterConsumer.attach_to_twisted()
+ rpc.AdapterConsumer.attach_to_twisted()
+
+ # Stub out looping call a bit needlessly since we don't have an easy
+ # way to cancel it (yet) when the tests finishes
+ service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
+ service.task.LoopingCall)
+ service.task.LoopingCall.start(interval=mox.IgnoreArg(),
+ now=mox.IgnoreArg())
+ service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
+ service.task.LoopingCall)
+ service.task.LoopingCall.start(interval=mox.IgnoreArg(),
+ now=mox.IgnoreArg())
+
+ service_create = {'host': host,
+ 'binary': binary,
+ 'topic': topic,
+ 'report_count': 0}
+ service_ref = {'host': host,
+ 'binary': binary,
+ 'report_count': 0,
+ 'id': 1}
+
+ service.db.service_get_by_args(None,
+ host,
+ binary).AndRaise(exception.NotFound())
+ service.db.service_create(None,
+ service_create).AndReturn(service_ref)
+ self.mox.ReplayAll()
+
+ app = service.Service.create(host=host, binary=binary)
+ startApplication(app, False)
+ self.assert_(app)
+
+ # We're testing sort of weird behavior in how report_state decides
+ # whether it is disconnected, it looks for a variable on itself called
+ # 'model_disconnected' and report_state doesn't really do much so this
+ # these are mostly just for coverage
+ def test_report_state(self):
+ host = 'foo'
+ binary = 'bar'
+ service_ref = {'host': host,
+ 'binary': binary,
+ 'report_count': 0,
+ 'id': 1}
+ service.db.__getattr__('report_state')
+ service.db.service_get_by_args(None,
+ host,
+ binary).AndReturn(service_ref)
+ service.db.service_update(None, service_ref['id'],
+ mox.ContainsKeyValue('report_count', 1))
+
+ self.mox.ReplayAll()
+ s = service.Service()
+ rv = yield s.report_state(host, binary)
+
+ def test_report_state_no_service(self):
+ host = 'foo'
+ binary = 'bar'
+ service_create = {'host': host,
+ 'binary': binary,
+ 'report_count': 0}
+ service_ref = {'host': host,
+ 'binary': binary,
+ 'report_count': 0,
+ 'id': 1}
+
+ service.db.__getattr__('report_state')
+ service.db.service_get_by_args(None,
+ host,
+ binary).AndRaise(exception.NotFound())
+ service.db.service_create(None,
+ service_create).AndReturn(service_ref)
+ service.db.service_get(None, service_ref['id']).AndReturn(service_ref)
+ service.db.service_update(None, service_ref['id'],
+ mox.ContainsKeyValue('report_count', 1))
+
+ self.mox.ReplayAll()
+ s = service.Service()
+ rv = yield s.report_state(host, binary)
+
+ def test_report_state_newly_disconnected(self):
+ host = 'foo'
+ binary = 'bar'
+ service_ref = {'host': host,
+ 'binary': binary,
+ 'report_count': 0,
+ 'id': 1}
+
+ service.db.__getattr__('report_state')
+ service.db.service_get_by_args(None,
+ host,
+ binary).AndRaise(Exception())
+
+ self.mox.ReplayAll()
+ s = service.Service()
+ rv = yield s.report_state(host, binary)
+
+ self.assert_(s.model_disconnected)
+
+ def test_report_state_newly_connected(self):
+ host = 'foo'
+ binary = 'bar'
+ service_ref = {'host': host,
+ 'binary': binary,
+ 'report_count': 0,
+ 'id': 1}
+
+ service.db.__getattr__('report_state')
+ service.db.service_get_by_args(None,
+ host,
+ binary).AndReturn(service_ref)
+ service.db.service_update(None, service_ref['id'],
+ mox.ContainsKeyValue('report_count', 1))
+
+ self.mox.ReplayAll()
+ s = service.Service()
+ s.model_disconnected = True
+ rv = yield s.report_state(host, binary)
+
+ self.assert_(not s.model_disconnected)
diff --git a/nova/tests/storage_unittest.py b/nova/tests/storage_unittest.py
deleted file mode 100644
index f400cd2fd..000000000
--- a/nova/tests/storage_unittest.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-import logging
-
-from nova import exception
-from nova import flags
-from nova import test
-from nova.compute import node
-from nova.volume import storage
-
-
-FLAGS = flags.FLAGS
-
-
-class StorageTestCase(test.TrialTestCase):
- def setUp(self):
- logging.getLogger().setLevel(logging.DEBUG)
- super(StorageTestCase, self).setUp()
- self.mynode = node.Node()
- self.mystorage = None
- self.flags(connection_type='fake',
- fake_storage=True)
- self.mystorage = storage.BlockStore()
-
- def test_run_create_volume(self):
- vol_size = '0'
- user_id = 'fake'
- project_id = 'fake'
- volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)
- # TODO(termie): get_volume returns differently than create_volume
- self.assertEqual(volume_id,
- storage.get_volume(volume_id)['volume_id'])
-
- rv = self.mystorage.delete_volume(volume_id)
- self.assertRaises(exception.Error,
- storage.get_volume,
- volume_id)
-
- def test_too_big_volume(self):
- vol_size = '1001'
- user_id = 'fake'
- project_id = 'fake'
- self.assertRaises(TypeError,
- self.mystorage.create_volume,
- vol_size, user_id, project_id)
-
- def test_too_many_volumes(self):
- vol_size = '1'
- user_id = 'fake'
- project_id = 'fake'
- num_shelves = FLAGS.last_shelf_id - FLAGS.first_shelf_id + 1
- total_slots = FLAGS.slots_per_shelf * num_shelves
- vols = []
- for i in xrange(total_slots):
- vid = self.mystorage.create_volume(vol_size, user_id, project_id)
- vols.append(vid)
- self.assertRaises(storage.NoMoreVolumes,
- self.mystorage.create_volume,
- vol_size, user_id, project_id)
- for id in vols:
- self.mystorage.delete_volume(id)
-
- def test_run_attach_detach_volume(self):
- # Create one volume and one node to test with
- instance_id = "storage-test"
- vol_size = "5"
- user_id = "fake"
- project_id = 'fake'
- mountpoint = "/dev/sdf"
- volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)
-
- volume_obj = storage.get_volume(volume_id)
- volume_obj.start_attach(instance_id, mountpoint)
- rv = yield self.mynode.attach_volume(volume_id,
- instance_id,
- mountpoint)
- self.assertEqual(volume_obj['status'], "in-use")
- self.assertEqual(volume_obj['attachStatus'], "attached")
- self.assertEqual(volume_obj['instance_id'], instance_id)
- self.assertEqual(volume_obj['mountpoint'], mountpoint)
-
- self.assertRaises(exception.Error,
- self.mystorage.delete_volume,
- volume_id)
-
- rv = yield self.mystorage.detach_volume(volume_id)
- volume_obj = storage.get_volume(volume_id)
- self.assertEqual(volume_obj['status'], "available")
-
- rv = self.mystorage.delete_volume(volume_id)
- self.assertRaises(exception.Error,
- storage.get_volume,
- volume_id)
-
- def test_multi_node(self):
- # TODO(termie): Figure out how to test with two nodes,
- # each of them having a different FLAG for storage_node
- # This will allow us to test cross-node interactions
- pass
diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py
index 2aab16809..edcdba425 100644
--- a/nova/tests/virt_unittest.py
+++ b/nova/tests/virt_unittest.py
@@ -14,36 +14,77 @@
# License for the specific language governing permissions and limitations
# under the License.
+from xml.etree.ElementTree import fromstring as xml_to_tree
+from xml.dom.minidom import parseString as xml_to_dom
+
+from nova import db
from nova import flags
from nova import test
+from nova import utils
+from nova.api import context
+from nova.api.ec2 import cloud
+from nova.auth import manager
from nova.virt import libvirt_conn
FLAGS = flags.FLAGS
-
+flags.DECLARE('instances_path', 'nova.compute.manager')
class LibvirtConnTestCase(test.TrialTestCase):
+ def setUp(self):
+ super(LibvirtConnTestCase, self).setUp()
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.network = utils.import_object(FLAGS.network_manager)
+ FLAGS.instances_path = ''
+
def test_get_uri_and_template(self):
- class MockDataModel(object):
- def __init__(self):
- self.datamodel = { 'name' : 'i-cafebabe',
- 'memory_kb' : '1024000',
- 'basepath' : '/some/path',
- 'bridge_name' : 'br100',
- 'mac_address' : '02:12:34:46:56:67',
- 'vcpus' : 2 }
+ ip = '10.11.12.13'
+
+ instance = { 'internal_id' : 1,
+ 'memory_kb' : '1024000',
+ 'basepath' : '/some/path',
+ 'bridge_name' : 'br100',
+ 'mac_address' : '02:12:34:46:56:67',
+ 'vcpus' : 2,
+ 'project_id' : 'fake',
+ 'bridge' : 'br101',
+ 'instance_type' : 'm1.small'}
+
+ instance_ref = db.instance_create(None, instance)
+ user_context = context.APIRequestContext(project=self.project,
+ user=self.user)
+ network_ref = self.network.get_network(user_context)
+ self.network.set_network_host(context.get_admin_context(),
+ network_ref['id'])
+
+ fixed_ip = { 'address' : ip,
+ 'network_id' : network_ref['id'] }
+
+ fixed_ip_ref = db.fixed_ip_create(None, fixed_ip)
+ db.fixed_ip_update(None, ip, { 'allocated' : True,
+ 'instance_id' : instance_ref['id'] })
type_uri_map = { 'qemu' : ('qemu:///system',
- [lambda s: '<domain type=\'qemu\'>' in s,
- lambda s: 'type>hvm</type' in s,
- lambda s: 'emulator>/usr/bin/kvm' not in s]),
+ [(lambda t: t.find('.').get('type'), 'qemu'),
+ (lambda t: t.find('./os/type').text, 'hvm'),
+ (lambda t: t.find('./devices/emulator'), None)]),
'kvm' : ('qemu:///system',
- [lambda s: '<domain type=\'kvm\'>' in s,
- lambda s: 'type>hvm</type' in s,
- lambda s: 'emulator>/usr/bin/qemu<' not in s]),
+ [(lambda t: t.find('.').get('type'), 'kvm'),
+ (lambda t: t.find('./os/type').text, 'hvm'),
+ (lambda t: t.find('./devices/emulator'), None)]),
'uml' : ('uml:///system',
- [lambda s: '<domain type=\'uml\'>' in s,
- lambda s: 'type>uml</type' in s]),
- }
+ [(lambda t: t.find('.').get('type'), 'uml'),
+ (lambda t: t.find('./os/type').text, 'uml')]),
+ }
+
+ common_checks = [(lambda t: t.find('.').tag, 'domain'),
+ (lambda t: \
+ t.find('./devices/interface/filterref/parameter') \
+ .get('name'), 'IP'),
+ (lambda t: \
+ t.find('./devices/interface/filterref/parameter') \
+ .get('value'), '10.11.12.13')]
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
@@ -52,9 +93,17 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, expected_uri)
- for i, check in enumerate(checks):
- xml = conn.toXml(MockDataModel())
- self.assertTrue(check(xml), '%s failed check %d' % (xml, i))
+ xml = conn.to_xml(instance_ref)
+ tree = xml_to_tree(xml)
+ for i, (check, expected_result) in enumerate(checks):
+ self.assertEqual(check(tree),
+ expected_result,
+ '%s failed check %d' % (xml, i))
+
+ for i, (check, expected_result) in enumerate(common_checks):
+ self.assertEqual(check(tree),
+ expected_result,
+ '%s failed common check %d' % (xml, i))
# Deliberately not just assigning this string to FLAGS.libvirt_uri and
# checking against that later on. This way we make sure the
@@ -67,3 +116,143 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, testuri)
+
+ def tearDown(self):
+ super(LibvirtConnTestCase, self).tearDown()
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+
+class NWFilterTestCase(test.TrialTestCase):
+ def setUp(self):
+ super(NWFilterTestCase, self).setUp()
+
+ class Mock(object):
+ pass
+
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
+ self.project = self.manager.create_project('fake', 'fake', 'fake')
+ self.context = context.APIRequestContext(self.user, self.project)
+
+ self.fake_libvirt_connection = Mock()
+
+ self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
+
+ def tearDown(self):
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+
+
+ def test_cidr_rule_nwfilter_xml(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+
+ security_group = db.security_group_get_by_name(self.context,
+ 'fake',
+ 'testgroup')
+
+ xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
+
+ dom = xml_to_dom(xml)
+ self.assertEqual(dom.firstChild.tagName, 'filter')
+
+ rules = dom.getElementsByTagName('rule')
+ self.assertEqual(len(rules), 1)
+
+ # It's supposed to allow inbound traffic.
+ self.assertEqual(rules[0].getAttribute('action'), 'accept')
+ self.assertEqual(rules[0].getAttribute('direction'), 'in')
+
+ # Must be lower priority than the base filter (which blocks everything)
+ self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)
+
+ ip_conditions = rules[0].getElementsByTagName('tcp')
+ self.assertEqual(len(ip_conditions), 1)
+ self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0')
+ self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
+ self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
+
+
+ self.teardown_security_group()
+
+ def teardown_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.delete_security_group(self.context, 'testgroup')
+
+
+ def setup_and_return_security_group(self):
+ cloud_controller = cloud.CloudController()
+ cloud_controller.create_security_group(self.context,
+ 'testgroup',
+ 'test group description')
+ cloud_controller.authorize_security_group_ingress(self.context,
+ 'testgroup',
+ from_port='80',
+ to_port='81',
+ ip_protocol='tcp',
+ cidr_ip='0.0.0.0/0')
+
+ return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
+
+ def test_creates_base_rule_first(self):
+ # These come pre-defined by libvirt
+ self.defined_filters = ['no-mac-spoofing',
+ 'no-ip-spoofing',
+ 'no-arp-spoofing',
+ 'allow-dhcp-server']
+
+ self.recursive_depends = {}
+ for f in self.defined_filters:
+ self.recursive_depends[f] = []
+
+ def _filterDefineXMLMock(xml):
+ dom = xml_to_dom(xml)
+ name = dom.firstChild.getAttribute('name')
+ self.recursive_depends[name] = []
+ for f in dom.getElementsByTagName('filterref'):
+ ref = f.getAttribute('filter')
+ self.assertTrue(ref in self.defined_filters,
+ ('%s referenced filter that does ' +
+ 'not yet exist: %s') % (name, ref))
+ dependencies = [ref] + self.recursive_depends[ref]
+ self.recursive_depends[name] += dependencies
+
+ self.defined_filters.append(name)
+ return True
+
+ self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
+
+ instance_ref = db.instance_create(self.context,
+ {'user_id': 'fake',
+ 'project_id': 'fake'})
+ inst_id = instance_ref['id']
+
+ def _ensure_all_called(_):
+ instance_filter = 'nova-instance-%s' % instance_ref['name']
+ secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
+ for required in [secgroup_filter, 'allow-dhcp-server',
+ 'no-arp-spoofing', 'no-ip-spoofing',
+ 'no-mac-spoofing']:
+ self.assertTrue(required in self.recursive_depends[instance_filter],
+ "Instance's filter does not include %s" % required)
+
+ self.security_group = self.setup_and_return_security_group()
+
+ db.instance_add_security_group(self.context, inst_id, self.security_group.id)
+ instance = db.instance_get(self.context, inst_id)
+
+ d = self.fw.setup_nwfilters_for_instance(instance)
+ d.addCallback(_ensure_all_called)
+ d.addCallback(lambda _:self.teardown_security_group())
+
+ return d
diff --git a/nova/tests/volume_unittest.py b/nova/tests/volume_unittest.py
index 2a07afe69..1d665b502 100644
--- a/nova/tests/volume_unittest.py
+++ b/nova/tests/volume_unittest.py
@@ -15,139 +15,159 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
-
+"""
+Tests for Volume Code
+"""
import logging
-import shutil
-import tempfile
from twisted.internet import defer
-from nova import compute
from nova import exception
+from nova import db
from nova import flags
from nova import test
-from nova.volume import service as volume_service
-
+from nova import utils
FLAGS = flags.FLAGS
class VolumeTestCase(test.TrialTestCase):
- def setUp(self):
+ """Test Case for volumes"""
+ def setUp(self): # pylint: disable-msg=C0103
logging.getLogger().setLevel(logging.DEBUG)
super(VolumeTestCase, self).setUp()
- self.compute = compute.service.ComputeService()
- self.volume = None
- self.tempdir = tempfile.mkdtemp()
- self.flags(connection_type='fake',
- fake_storage=True,
- aoe_export_dir=self.tempdir)
- self.volume = volume_service.VolumeService()
-
- def tearDown(self):
- shutil.rmtree(self.tempdir)
+ self.compute = utils.import_object(FLAGS.compute_manager)
+ self.flags(connection_type='fake')
+ self.volume = utils.import_object(FLAGS.volume_manager)
+ self.context = None
+
+ @staticmethod
+ def _create_volume(size='0'):
+ """Create a volume object"""
+ vol = {}
+ vol['size'] = size
+ vol['user_id'] = 'fake'
+ vol['project_id'] = 'fake'
+ vol['availability_zone'] = FLAGS.storage_availability_zone
+ vol['status'] = "creating"
+ vol['attach_status'] = "detached"
+ return db.volume_create(None, vol)['id']
@defer.inlineCallbacks
- def test_run_create_volume(self):
- vol_size = '0'
- user_id = 'fake'
- project_id = 'fake'
- volume_id = yield self.volume.create_volume(vol_size, user_id, project_id)
- # TODO(termie): get_volume returns differently than create_volume
- self.assertEqual(volume_id,
- volume_service.get_volume(volume_id)['volume_id'])
-
- rv = self.volume.delete_volume(volume_id)
- self.assertRaises(exception.Error, volume_service.get_volume, volume_id)
+ def test_create_delete_volume(self):
+ """Test volume can be created and deleted"""
+ volume_id = self._create_volume()
+ yield self.volume.create_volume(self.context, volume_id)
+ self.assertEqual(volume_id, db.volume_get(None, volume_id).id)
+
+ yield self.volume.delete_volume(self.context, volume_id)
+ self.assertRaises(exception.NotFound,
+ db.volume_get,
+ None,
+ volume_id)
@defer.inlineCallbacks
def test_too_big_volume(self):
- vol_size = '1001'
- user_id = 'fake'
- project_id = 'fake'
+ """Ensure failure if a too large of a volume is requested"""
+ # FIXME(vish): validation needs to move into the data layer in
+ # volume_create
+ defer.returnValue(True)
try:
- yield self.volume.create_volume(vol_size, user_id, project_id)
+ volume_id = self._create_volume('1001')
+ yield self.volume.create_volume(self.context, volume_id)
self.fail("Should have thrown TypeError")
except TypeError:
pass
@defer.inlineCallbacks
def test_too_many_volumes(self):
- vol_size = '1'
- user_id = 'fake'
- project_id = 'fake'
- num_shelves = FLAGS.last_shelf_id - FLAGS.first_shelf_id + 1
- total_slots = FLAGS.blades_per_shelf * num_shelves
+ """Ensure that NoMoreBlades is raised when we run out of volumes"""
vols = []
- from nova import datastore
- redis = datastore.Redis.instance()
- for i in xrange(total_slots):
- vid = yield self.volume.create_volume(vol_size, user_id, project_id)
- vols.append(vid)
- self.assertFailure(self.volume.create_volume(vol_size,
- user_id,
- project_id),
- volume_service.NoMoreBlades)
- for id in vols:
- yield self.volume.delete_volume(id)
+ total_slots = FLAGS.num_shelves * FLAGS.blades_per_shelf
+ for _index in xrange(total_slots):
+ volume_id = self._create_volume()
+ yield self.volume.create_volume(self.context, volume_id)
+ vols.append(volume_id)
+ volume_id = self._create_volume()
+ self.assertFailure(self.volume.create_volume(self.context,
+ volume_id),
+ db.NoMoreBlades)
+ db.volume_destroy(None, volume_id)
+ for volume_id in vols:
+ yield self.volume.delete_volume(self.context, volume_id)
@defer.inlineCallbacks
def test_run_attach_detach_volume(self):
- # Create one volume and one compute to test with
- instance_id = "storage-test"
- vol_size = "5"
- user_id = "fake"
- project_id = 'fake'
+ """Make sure volume can be attached and detached from instance"""
+ inst = {}
+ inst['image_id'] = 'ami-test'
+ inst['reservation_id'] = 'r-fakeres'
+ inst['launch_time'] = '10'
+ inst['user_id'] = 'fake'
+ inst['project_id'] = 'fake'
+ inst['instance_type'] = 'm1.tiny'
+ inst['mac_address'] = utils.generate_mac()
+ inst['ami_launch_index'] = 0
+ instance_id = db.instance_create(self.context, inst)['id']
mountpoint = "/dev/sdf"
- volume_id = yield self.volume.create_volume(vol_size, user_id, project_id)
- volume_obj = volume_service.get_volume(volume_id)
- volume_obj.start_attach(instance_id, mountpoint)
+ volume_id = self._create_volume()
+ yield self.volume.create_volume(self.context, volume_id)
if FLAGS.fake_tests:
- volume_obj.finish_attach()
+ db.volume_attached(None, volume_id, instance_id, mountpoint)
else:
- rv = yield self.compute.attach_volume(instance_id,
- volume_id,
- mountpoint)
- self.assertEqual(volume_obj['status'], "in-use")
- self.assertEqual(volume_obj['attach_status'], "attached")
- self.assertEqual(volume_obj['instance_id'], instance_id)
- self.assertEqual(volume_obj['mountpoint'], mountpoint)
-
- self.assertFailure(self.volume.delete_volume(volume_id), exception.Error)
- volume_obj.start_detach()
+ yield self.compute.attach_volume(instance_id,
+ volume_id,
+ mountpoint)
+ vol = db.volume_get(None, volume_id)
+ self.assertEqual(vol['status'], "in-use")
+ self.assertEqual(vol['attach_status'], "attached")
+ self.assertEqual(vol['mountpoint'], mountpoint)
+ instance_ref = db.volume_get_instance(self.context, volume_id)
+ self.assertEqual(instance_ref['id'], instance_id)
+
+ self.assertFailure(self.volume.delete_volume(self.context, volume_id),
+ exception.Error)
if FLAGS.fake_tests:
- volume_obj.finish_detach()
+ db.volume_detached(None, volume_id)
else:
- rv = yield self.volume.detach_volume(instance_id,
- volume_id)
- volume_obj = volume_service.get_volume(volume_id)
- self.assertEqual(volume_obj['status'], "available")
+ yield self.compute.detach_volume(instance_id,
+ volume_id)
+ vol = db.volume_get(None, volume_id)
+ self.assertEqual(vol['status'], "available")
- rv = self.volume.delete_volume(volume_id)
+ yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.Error,
- volume_service.get_volume,
+ db.volume_get,
+ None,
volume_id)
+ db.instance_destroy(self.context, instance_id)
@defer.inlineCallbacks
- def test_multiple_volume_race_condition(self):
- vol_size = "5"
- user_id = "fake"
- project_id = 'fake'
+ def test_concurrent_volumes_get_different_blades(self):
+ """Ensure multiple concurrent volumes get different blades"""
+ volume_ids = []
shelf_blades = []
+
def _check(volume_id):
- vol = volume_service.get_volume(volume_id)
- shelf_blade = '%s.%s' % (vol['shelf_id'], vol['blade_id'])
+ """Make sure blades aren't duplicated"""
+ volume_ids.append(volume_id)
+ (shelf_id, blade_id) = db.volume_get_shelf_and_blade(None,
+ volume_id)
+ shelf_blade = '%s.%s' % (shelf_id, blade_id)
self.assert_(shelf_blade not in shelf_blades)
shelf_blades.append(shelf_blade)
- logging.debug("got %s" % shelf_blade)
- vol.destroy()
+ logging.debug("Blade %s allocated", shelf_blade)
deferreds = []
- for i in range(5):
- d = self.volume.create_volume(vol_size, user_id, project_id)
+ total_slots = FLAGS.num_shelves * FLAGS.blades_per_shelf
+ for _index in xrange(total_slots):
+ volume_id = self._create_volume()
+ d = self.volume.create_volume(self.context, volume_id)
d.addCallback(_check)
d.addErrback(self.fail)
deferreds.append(d)
yield defer.DeferredList(deferreds)
+ for volume_id in volume_ids:
+ self.volume.delete_volume(self.context, volume_id)
def test_multi_node(self):
# TODO(termie): Figure out how to test with two nodes,
diff --git a/nova/twistd.py b/nova/twistd.py
index 8de322aa5..9511c231c 100644
--- a/nova/twistd.py
+++ b/nova/twistd.py
@@ -21,6 +21,7 @@ Twisted daemon helpers, specifically to parse out gFlags from twisted flags,
manage pid files and support syslogging.
"""
+import gflags
import logging
import os
import signal
@@ -49,6 +50,14 @@ class TwistdServerOptions(ServerOptions):
return
+class FlagParser(object):
+ def __init__(self, parser):
+ self.parser = parser
+
+ def Parse(self, s):
+ return self.parser(s)
+
+
def WrapTwistedOptions(wrapped):
class TwistedOptionsToFlags(wrapped):
subCommands = None
@@ -79,7 +88,12 @@ def WrapTwistedOptions(wrapped):
reflect.accumulateClassList(self.__class__, 'optParameters', twistd_params)
for param in twistd_params:
key = param[0].replace('-', '_')
- flags.DEFINE_string(key, param[2], str(param[-1]))
+ if len(param) > 4:
+ flags.DEFINE(FlagParser(param[4]),
+ key, param[2], str(param[3]),
+ serializer=gflags.ArgumentSerializer())
+ else:
+ flags.DEFINE_string(key, param[2], str(param[3]))
def _absorbHandlers(self):
twistd_handlers = {}
diff --git a/nova/utils.py b/nova/utils.py
index dc3c626ec..10b27ffec 100644
--- a/nova/utils.py
+++ b/nova/utils.py
@@ -29,14 +29,16 @@ import subprocess
import socket
import sys
+from twisted.internet.threads import deferToThread
+
from nova import exception
from nova import flags
+from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
-
def import_class(import_str):
"""Returns a class from a string including module and class"""
mod_str, _sep, class_str = import_str.rpartition('.')
@@ -46,6 +48,14 @@ def import_class(import_str):
except (ImportError, ValueError, AttributeError):
raise exception.NotFound('Class %s cannot be found' % class_str)
+def import_object(import_str):
+ """Returns an object including a module or module and class"""
+ try:
+ __import__(import_str)
+ return sys.modules[import_str]
+ except ImportError:
+ cls = import_class(import_str)
+ return cls()
def fetchfile(url, target):
logging.debug("Fetching %s" % url)
@@ -59,6 +69,7 @@ def fetchfile(url, target):
execute("curl --fail %s -o %s" % (url, target))
def execute(cmd, process_input=None, addl_env=None, check_exit_code=True):
+ logging.debug("Running cmd: %s", cmd)
env = os.environ.copy()
if addl_env:
env.update(addl_env)
@@ -73,8 +84,11 @@ def execute(cmd, process_input=None, addl_env=None, check_exit_code=True):
if obj.returncode:
logging.debug("Result was %s" % (obj.returncode))
if check_exit_code and obj.returncode <> 0:
- raise Exception( "Unexpected exit code: %s. result=%s"
- % (obj.returncode, result))
+ (stdout, stderr) = result
+ raise ProcessExecutionError(exit_code=obj.returncode,
+ stdout=stdout,
+ stderr=stderr,
+ cmd=cmd)
return result
@@ -105,12 +119,20 @@ def runthis(prompt, cmd, check_exit_code = True):
exit_code = subprocess.call(cmd.split(" "))
logging.debug(prompt % (exit_code))
if check_exit_code and exit_code <> 0:
- raise Exception( "Unexpected exit code: %s from cmd: %s"
- % (exit_code, cmd))
+ raise ProcessExecutionError(exit_code=exit_code,
+ stdout=None,
+ stderr=None,
+ cmd=cmd)
def generate_uid(topic, size=8):
- return '%s-%s' % (topic, ''.join([random.choice('01234567890abcdefghijklmnopqrstuvwxyz') for x in xrange(size)]))
+ if topic == "i":
+ # Instances have integer internal ids.
+ return random.randint(0, 2**32-1)
+ else:
+ characters = '01234567890abcdefghijklmnopqrstuvwxyz'
+ choices = [random.choice(characters) for x in xrange(size)]
+ return '%s-%s' % (topic, ''.join(choices))
def generate_mac():
@@ -125,8 +147,7 @@ def last_octet(address):
def get_my_ip():
- ''' returns the actual ip of the local machine.
- '''
+ """Returns the actual ip of the local machine."""
if getattr(FLAGS, 'fake_tests', None):
return '127.0.0.1'
try:
@@ -148,3 +169,39 @@ def isotime(at=None):
def parse_isotime(timestr):
return datetime.datetime.strptime(timestr, TIME_FORMAT)
+
+
+class LazyPluggable(object):
+ """A pluggable backend loaded lazily based on some value."""
+
+ def __init__(self, pivot, **backends):
+ self.__backends = backends
+ self.__pivot = pivot
+ self.__backend = None
+
+ def __get_backend(self):
+ if not self.__backend:
+ backend_name = self.__pivot.value
+ if backend_name not in self.__backends:
+ raise exception.Error('Invalid backend: %s' % backend_name)
+
+ backend = self.__backends[backend_name]
+ if type(backend) == type(tuple()):
+ name = backend[0]
+ fromlist = backend[1]
+ else:
+ name = backend
+ fromlist = backend
+
+ self.__backend = __import__(name, None, None, fromlist)
+ logging.info('backend %s', self.__backend)
+ return self.__backend
+
+ def __getattr__(self, key):
+ backend = self.__get_backend()
+ return getattr(backend, key)
+
+def deferredToThread(f):
+ def g(*args, **kwargs):
+ return deferToThread(f, *args, **kwargs)
+ return g
diff --git a/nova/virt/connection.py b/nova/virt/connection.py
index 90bc7fa0a..34e37adf7 100644
--- a/nova/virt/connection.py
+++ b/nova/virt/connection.py
@@ -17,6 +17,11 @@
# License for the specific language governing permissions and limitations
# under the License.
+"""Abstraction of the underlying virtualization API"""
+
+import logging
+import sys
+
from nova import flags
from nova.virt import fake
from nova.virt import libvirt_conn
@@ -35,7 +40,6 @@ def get_connection(read_only=False):
Any object returned here must conform to the interface documented by
FakeConnection.
"""
-
# TODO(termie): maybe lazy load after initial check for permissions
# TODO(termie): check whether we can be disconnected
t = FLAGS.connection_type
diff --git a/nova/virt/fake.py b/nova/virt/fake.py
index 155833f3f..dc6112f20 100644
--- a/nova/virt/fake.py
+++ b/nova/virt/fake.py
@@ -39,12 +39,12 @@ class FakeConnection(object):
The interface to this class talks in terms of 'instances' (Amazon EC2 and
internal Nova terminology), by which we mean 'running virtual machine'
(XenAPI terminology) or domain (Xen or libvirt terminology).
-
+
An instance has an ID, which is the identifier chosen by Nova to represent
the instance further up the stack. This is unfortunately also called a
'name' elsewhere. As far as this layer is concerned, 'instance ID' and
'instance name' are synonyms.
-
+
Note that the instance ID or name is not human-readable or
customer-controlled -- it's an internal ID chosen by Nova. At the
nova.virt layer, instances do not have human-readable names at all -- such
@@ -101,7 +101,7 @@ class FakeConnection(object):
cleaned up, and the virtualization platform should be in the state
that it was before this call began.
"""
-
+
fake_instance = FakeInstance()
self.instances[instance.name] = fake_instance
fake_instance._state = power_state.RUNNING
@@ -132,7 +132,15 @@ class FakeConnection(object):
del self.instances[instance.name]
return defer.succeed(None)
- def get_info(self, instance_id):
+ def attach_volume(self, instance_name, device_path, mountpoint):
+ """Attach the disk at device_path to the instance at mountpoint"""
+ return True
+
+ def detach_volume(self, instance_name, mountpoint):
+ """Detach the disk attached to the instance at mountpoint"""
+ return True
+
+ def get_info(self, instance_name):
"""
Get a block of information about the given instance. This is returned
as a dictionary containing 'state': The power_state of the instance,
@@ -141,42 +149,42 @@ class FakeConnection(object):
of virtual CPUs the instance has, 'cpu_time': The total CPU time used
by the instance, in nanoseconds.
"""
- i = self.instances[instance_id]
+ i = self.instances[instance_name]
return {'state': i._state,
'max_mem': 0,
'mem': 0,
'num_cpu': 2,
'cpu_time': 0}
- def list_disks(self, instance_id):
+ def list_disks(self, instance_name):
"""
Return the IDs of all the virtual disks attached to the specified
instance, as a list. These IDs are opaque to the caller (they are
only useful for giving back to this layer as a parameter to
disk_stats). These IDs only need to be unique for a given instance.
-
+
Note that this function takes an instance ID, not a
compute.service.Instance, so that it can be called by compute.monitor.
"""
return ['A_DISK']
- def list_interfaces(self, instance_id):
+ def list_interfaces(self, instance_name):
"""
Return the IDs of all the virtual network interfaces attached to the
specified instance, as a list. These IDs are opaque to the caller
(they are only useful for giving back to this layer as a parameter to
interface_stats). These IDs only need to be unique for a given
instance.
-
+
Note that this function takes an instance ID, not a
compute.service.Instance, so that it can be called by compute.monitor.
"""
return ['A_VIF']
- def block_stats(self, instance_id, disk_id):
+ def block_stats(self, instance_name, disk_id):
"""
Return performance counters associated with the given disk_id on the
- given instance_id. These are returned as [rd_req, rd_bytes, wr_req,
+ given instance_name. These are returned as [rd_req, rd_bytes, wr_req,
wr_bytes, errs], where rd indicates read, wr indicates write, req is
the total number of I/O requests made, bytes is the total number of
bytes transferred, and errs is the number of requests held up due to a
@@ -188,13 +196,13 @@ class FakeConnection(object):
statistics can be retrieved directly in aggregate form, without Nova
having to do the aggregation. On those platforms, this method is
unused.
-
+
Note that this function takes an instance ID, not a
compute.service.Instance, so that it can be called by compute.monitor.
"""
return [0L, 0L, 0L, 0L, null]
- def interface_stats(self, instance_id, iface_id):
+ def interface_stats(self, instance_name, iface_id):
"""
Return performance counters associated with the given iface_id on the
given instance_id. These are returned as [rx_bytes, rx_packets,
@@ -209,12 +217,14 @@ class FakeConnection(object):
statistics can be retrieved directly in aggregate form, without Nova
having to do the aggregation. On those platforms, this method is
unused.
-
+
Note that this function takes an instance ID, not a
compute.service.Instance, so that it can be called by compute.monitor.
"""
return [0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L]
+ def get_console_output(self, instance):
+ return 'FAKE CONSOLE OUTPUT'
class FakeInstance(object):
def __init__(self):
diff --git a/nova/virt/images.py b/nova/virt/images.py
index a60bcc4c1..dc50764d9 100644
--- a/nova/virt/images.py
+++ b/nova/virt/images.py
@@ -29,6 +29,7 @@ from nova import flags
from nova import process
from nova.auth import manager
from nova.auth import signer
+from nova.objectstore import image
FLAGS = flags.FLAGS
diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template
index 11df301f6..87b92b84a 100644
--- a/nova/virt/interfaces.template
+++ b/nova/virt/interfaces.template
@@ -10,7 +10,6 @@ auto eth0
iface eth0 inet static
address %(address)s
netmask %(netmask)s
- network %(network)s
broadcast %(broadcast)s
gateway %(gateway)s
dns-nameservers %(dns)s
diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template
index 8a7c64f88..739eceaaa 100644
--- a/nova/virt/libvirt.qemu.xml.template
+++ b/nova/virt/libvirt.qemu.xml.template
@@ -24,11 +24,14 @@
<source bridge='${bridge_name}'/>
<mac address='${mac_address}'/>
<!-- <model type='virtio'/> CANT RUN virtio network right now -->
+ <filterref filter="nova-instance-%(name)s">
+ <parameter name="IP" value="%(ip_address)s" />
+ <parameter name="DHCPSERVER" value="%(dhcp_server)s" />
+ </filterref>
</interface>
<serial type="file">
<source path='${basepath}/console.log'/>
<target port='1'/>
</serial>
</devices>
- <nova>${nova}</nova>
</domain>
diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template
index 6f4290f98..f6e5fad69 100644
--- a/nova/virt/libvirt.uml.xml.template
+++ b/nova/virt/libvirt.uml.xml.template
@@ -1,25 +1,26 @@
-<domain type='%(type)s'>
- <name>%(name)s</name>
- <memory>%(memory_kb)s</memory>
+<domain type='${type}'>
+ <name>${name}</name>
+ <memory>${memory_kb}</memory>
<os>
- <type>%(type)s</type>
+ <type>${type}</type>
<kernel>/usr/bin/linux</kernel>
<root>/dev/ubda1</root>
</os>
<devices>
<disk type='file'>
- <source file='%(basepath)s/disk'/>
+ <source file='${disk}'/>
<target dev='ubd0' bus='uml'/>
</disk>
<interface type='bridge'>
- <source bridge='%(bridge_name)s'/>
- <mac address='%(mac_address)s'/>
+ <source bridge='${bridge_name}'/>
+ <mac address='${mac_address}'/>
+ <filterref filter="nova-instance-${name}">
+ <parameter name="IP" value="${ip_address}" />
+ <parameter name="DHCPSERVER" value="${dhcp_server}" />
+ </filterref>
</interface>
- <console type="pty" />
- <serial type="file">
- <source path='%(basepath)s/console.log'/>
- <target port='1'/>
- </serial>
+ <console type="file">
+ <source path='${basepath}/console.log'/>
+ </console>
</devices>
- <nova>%(nova)s</nova>
</domain>
diff --git a/nova/virt/libvirt.xen.xml.template b/nova/virt/libvirt.xen.xml.template
new file mode 100644
index 000000000..9508ad3b7
--- /dev/null
+++ b/nova/virt/libvirt.xen.xml.template
@@ -0,0 +1,35 @@
+<domain type='${type}'>
+ <name>${name}</name>
+ <os>
+ <type>linux</type>
+#if $getVar('kernel', None)
+ <kernel>${kernel}</kernel>
+ #if $getVar('ramdisk', None)
+ <initrd>${ramdisk}</initrd>
+ #end if
+ <cmdline>root=/dev/vda1 console=ttyS0</cmdline>
+#end if
+ <root>/dev/xvda1</root>
+ <cmdline>ro</cmdline>
+ </os>
+ <features>
+ <acpi/>
+ </features>
+ <memory>${memory_kb}</memory>
+ <vcpu>${vcpus}</vcpu>
+ <devices>
+ <disk type='file'>
+ <source file='${disk}'/>
+ <target dev='sda' />
+ </disk>
+ <interface type='bridge'>
+ <source bridge='${bridge_name}'/>
+ <mac address='${mac_address}'/>
+ </interface>
+ <console type="file">
+ <source path='${basepath}/console.log'/>
+ <target port='1'/>
+ </console>
+ </devices>
+</domain>
+
diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py
index 7f7c4a131..ba5d6dbac 100644
--- a/nova/virt/libvirt_conn.py
+++ b/nova/virt/libvirt_conn.py
@@ -21,18 +21,21 @@
A connection to a hypervisor (e.g. KVM) through libvirt.
"""
-import json
import logging
-import os.path
+import os
import shutil
+import IPy
from twisted.internet import defer
from twisted.internet import task
+from twisted.internet import threads
+from nova import db
from nova import exception
from nova import flags
from nova import process
from nova import utils
+#from nova.api import context
from nova.auth import manager
from nova.compute import disk
from nova.compute import instance_types
@@ -49,6 +52,9 @@ FLAGS = flags.FLAGS
flags.DEFINE_string('libvirt_xml_template',
utils.abspath('virt/libvirt.qemu.xml.template'),
'Libvirt XML Template for QEmu/KVM')
+flags.DEFINE_string('libvirt_xen_xml_template',
+ utils.abspath('virt/libvirt.xen.xml.template'),
+ 'Libvirt XML Template for Xen')
flags.DEFINE_string('libvirt_uml_xml_template',
utils.abspath('virt/libvirt.uml.xml.template'),
'Libvirt XML Template for user-mode-linux')
@@ -57,11 +63,14 @@ flags.DEFINE_string('injected_network_template',
'Template file for injected network')
flags.DEFINE_string('libvirt_type',
'kvm',
- 'Libvirt domain type (valid options are: kvm, qemu, uml)')
+ 'Libvirt domain type (valid options are: kvm, qemu, uml, xen)')
flags.DEFINE_string('libvirt_uri',
'',
'Override the default libvirt URI (which is dependent'
' on libvirt_type)')
+flags.DEFINE_bool('allow_project_net_traffic',
+ True,
+ 'Whether to allow in project network traffic')
def get_connection(read_only):
# These are loaded late so that there's no need to install these
@@ -85,14 +94,29 @@ class LibvirtConnection(object):
@property
def _conn(self):
- if not self._wrapped_conn:
+ if not self._wrapped_conn or not self._test_connection():
+ logging.debug('Connecting to libvirt: %s' % self.libvirt_uri)
self._wrapped_conn = self._connect(self.libvirt_uri, self.read_only)
return self._wrapped_conn
+ def _test_connection(self):
+ try:
+ self._wrapped_conn.getInfo()
+ return True
+ except libvirt.libvirtError as e:
+ if e.get_error_code() == libvirt.VIR_ERR_SYSTEM_ERROR and \
+ e.get_error_domain() == libvirt.VIR_FROM_REMOTE:
+ logging.debug('Connection to libvirt broke')
+ return False
+ raise
+
def get_uri_and_template(self):
if FLAGS.libvirt_type == 'uml':
uri = FLAGS.libvirt_uri or 'uml:///system'
template_file = FLAGS.libvirt_uml_xml_template
+ elif FLAGS.libvirt_type == 'xen':
+ uri = FLAGS.libvirt_uri or 'xen:///'
+ template_file = FLAGS.libvirt_xen_xml_template
else:
uri = FLAGS.libvirt_uri or 'qemu:///system'
template_file = FLAGS.libvirt_xml_template
@@ -113,26 +137,29 @@ class LibvirtConnection(object):
def destroy(self, instance):
try:
- virt_dom = self._conn.lookupByName(instance.name)
+ virt_dom = self._conn.lookupByName(instance['name'])
virt_dom.destroy()
- except Exception, _err:
+ except Exception as _err:
pass
# If the instance is already terminated, we're still happy
d = defer.Deferred()
d.addCallback(lambda _: self._cleanup(instance))
# FIXME: What does this comment mean?
# TODO(termie): short-circuit me for tests
- # WE'LL save this for when we do shutdown,
+ # WE'LL save this for when we do shutdown,
# instead of destroy - but destroy returns immediately
timer = task.LoopingCall(f=None)
def _wait_for_shutdown():
try:
- instance.update_state()
- if instance.state == power_state.SHUTDOWN:
+ state = self.get_info(instance['name'])['state']
+ db.instance_set_state(None, instance['id'], state)
+ if state == power_state.SHUTDOWN:
timer.stop()
d.callback(None)
except Exception:
- instance.set_state(power_state.SHUTDOWN)
+ db.instance_set_state(None,
+ instance['id'],
+ power_state.SHUTDOWN)
timer.stop()
d.callback(None)
timer.f = _wait_for_shutdown
@@ -140,30 +167,51 @@ class LibvirtConnection(object):
return d
def _cleanup(self, instance):
- target = os.path.abspath(instance.datamodel['basepath'])
- logging.info("Deleting instance files at %s", target)
+ target = os.path.join(FLAGS.instances_path, instance['name'])
+ logging.info('instance %s: deleting instance files %s',
+ instance['name'], target)
if os.path.exists(target):
shutil.rmtree(target)
@defer.inlineCallbacks
@exception.wrap_exception
+ def attach_volume(self, instance_name, device_path, mountpoint):
+ yield process.simple_execute("sudo virsh attach-disk %s %s %s" %
+ (instance_name,
+ device_path,
+ mountpoint.rpartition('/dev/')[2]))
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
+ def detach_volume(self, instance_name, mountpoint):
+ # NOTE(vish): despite the documentation, virsh detach-disk just
+ # wants the device name without the leading /dev/
+ yield process.simple_execute("sudo virsh detach-disk %s %s" %
+ (instance_name,
+ mountpoint.rpartition('/dev/')[2]))
+
+ @defer.inlineCallbacks
+ @exception.wrap_exception
def reboot(self, instance):
- xml = self.toXml(instance)
- yield self._conn.lookupByName(instance.name).destroy()
+ xml = self.to_xml(instance)
+ yield self._conn.lookupByName(instance['name']).destroy()
yield self._conn.createXML(xml, 0)
d = defer.Deferred()
timer = task.LoopingCall(f=None)
def _wait_for_reboot():
try:
- instance.update_state()
- if instance.is_running():
- logging.debug('rebooted instance %s' % instance.name)
+ state = self.get_info(instance['name'])['state']
+ db.instance_set_state(None, instance['id'], state)
+ if state == power_state.RUNNING:
+ logging.debug('instance %s: rebooted', instance['name'])
timer.stop()
d.callback(None)
except Exception, exn:
- logging.error('_wait_for_reboot failed: %s' % exn)
- instance.set_state(power_state.SHUTDOWN)
+ logging.error('_wait_for_reboot failed: %s', exn)
+ db.instance_set_state(None,
+ instance['id'],
+ power_state.SHUTDOWN)
timer.stop()
d.callback(None)
timer.f = _wait_for_reboot
@@ -173,38 +221,86 @@ class LibvirtConnection(object):
@defer.inlineCallbacks
@exception.wrap_exception
def spawn(self, instance):
- xml = self.toXml(instance)
- instance.set_state(power_state.NOSTATE, 'launching')
+ xml = self.to_xml(instance)
+ db.instance_set_state(None,
+ instance['id'],
+ power_state.NOSTATE,
+ 'launching')
+ yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance)
yield self._create_image(instance, xml)
yield self._conn.createXML(xml, 0)
# TODO(termie): this should actually register
# a callback to check for successful boot
- logging.debug("Instance is running")
+ logging.debug("instance %s: is running", instance['name'])
local_d = defer.Deferred()
timer = task.LoopingCall(f=None)
def _wait_for_boot():
try:
- instance.update_state()
- if instance.is_running():
- logging.debug('booted instance %s' % instance.name)
+ state = self.get_info(instance['name'])['state']
+ db.instance_set_state(None, instance['id'], state)
+ if state == power_state.RUNNING:
+ logging.debug('instance %s: booted', instance['name'])
timer.stop()
local_d.callback(None)
- except Exception, exn:
- logging.error("_wait_for_boot exception %s" % exn)
- self.set_state(power_state.SHUTDOWN)
- logging.error('Failed to boot instance %s' % instance.name)
+ except:
+ logging.exception('instance %s: failed to boot',
+ instance['name'])
+ db.instance_set_state(None,
+ instance['id'],
+ power_state.SHUTDOWN)
timer.stop()
local_d.callback(None)
timer.f = _wait_for_boot
timer.start(interval=0.5, now=True)
yield local_d
+ def _flush_xen_console(self, virsh_output):
+ logging.info('virsh said: %r' % (virsh_output,))
+ virsh_output = virsh_output[0].strip()
+
+ if virsh_output.startswith('/dev/'):
+ logging.info('cool, it\'s a device')
+ d = process.simple_execute("sudo dd if=%s iflag=nonblock" % virsh_output, check_exit_code=False)
+ d.addCallback(lambda r:r[0])
+ return d
+ else:
+ return ''
+
+ def _append_to_file(self, data, fpath):
+ logging.info('data: %r, fpath: %r' % (data, fpath))
+ fp = open(fpath, 'a+')
+ fp.write(data)
+ return fpath
+
+ def _dump_file(self, fpath):
+ fp = open(fpath, 'r+')
+ contents = fp.read()
+ logging.info('Contents: %r' % (contents,))
+ return contents
+
+ @exception.wrap_exception
+ def get_console_output(self, instance):
+ console_log = os.path.join(FLAGS.instances_path, instance['internal_id'], 'console.log')
+ logging.info('console_log: %s' % console_log)
+ logging.info('FLAGS.libvirt_type: %s' % FLAGS.libvirt_type)
+ if FLAGS.libvirt_type == 'xen':
+ # Xen is spethial
+ d = process.simple_execute("virsh ttyconsole %s" % instance['name'])
+ d.addCallback(self._flush_xen_console)
+ d.addCallback(self._append_to_file, console_log)
+ else:
+ d = defer.succeed(console_log)
+ d.addCallback(self._dump_file)
+ return d
+
+
@defer.inlineCallbacks
- def _create_image(self, instance, libvirt_xml):
+ def _create_image(self, inst, libvirt_xml):
# syntactic nicety
- data = instance.datamodel
- basepath = lambda x='': self.basepath(instance, x)
+ basepath = lambda fname='': os.path.join(FLAGS.instances_path,
+ inst['name'],
+ fname)
# ensure directories exist and are writable
yield process.simple_execute('mkdir -p %s' % basepath())
@@ -213,75 +309,100 @@ class LibvirtConnection(object):
# TODO(termie): these are blocking calls, it would be great
# if they weren't.
- logging.info('Creating image for: %s', data['instance_id'])
+ logging.info('instance %s: Creating image', inst['name'])
f = open(basepath('libvirt.xml'), 'w')
f.write(libvirt_xml)
f.close()
- user = manager.AuthManager().get_user(data['user_id'])
- project = manager.AuthManager().get_project(data['project_id'])
+ os.close(os.open(basepath('console.log'), os.O_CREAT | os.O_WRONLY, 0660))
+
+ user = manager.AuthManager().get_user(inst['user_id'])
+ project = manager.AuthManager().get_project(inst['project_id'])
+
if not os.path.exists(basepath('disk')):
- yield images.fetch(data['image_id'], basepath('disk-raw'), user, project)
-
- using_kernel = data['kernel_id'] and True
+ yield images.fetch(inst.image_id, basepath('disk-raw'), user, project)
+ using_kernel = inst.kernel_id and True
if using_kernel:
if not os.path.exists(basepath('kernel')):
- yield images.fetch(data['kernel_id'], basepath('kernel'), user, project)
+ yield images.fetch(inst.kernel_id, basepath('kernel'), user, project)
if not os.path.exists(basepath('ramdisk')):
- yield images.fetch(data['ramdisk_id'], basepath('ramdisk'), user, project)
+ yield images.fetch(inst.ramdisk_id, basepath('ramdisk'), user, project)
execute = lambda cmd, process_input=None: \
process.simple_execute(cmd=cmd,
process_input=process_input,
check_exit_code=True)
- # For now, we assume that if we're not using a kernel, we're using a partitioned disk image
- # where the target partition is the first partition
+ # For now, we assume that if we're not using a kernel, we're using a
+ # partitioned disk image where the target partition is the first
+ # partition
target_partition = None
if not using_kernel:
target_partition = "1"
- key = data['key_data']
+ key = str(inst['key_data'])
net = None
- if data.get('inject_network', False):
+ network_ref = db.network_get_by_instance(None, inst['id'])
+ if network_ref['injected']:
+ address = db.instance_get_fixed_address(None, inst['id'])
with open(FLAGS.injected_network_template) as f:
- net = f.read() % {'address': data['private_dns_name'],
- 'network': data['network_network'],
- 'netmask': data['network_netmask'],
- 'gateway': data['network_gateway'],
- 'broadcast': data['network_broadcast'],
- 'dns': data['network_dns']}
+ net = f.read() % {'address': address,
+ 'netmask': network_ref['netmask'],
+ 'gateway': network_ref['gateway'],
+ 'broadcast': network_ref['broadcast'],
+ 'dns': network_ref['dns']}
if key or net:
- logging.info('Injecting data into image %s', data['image_id'])
+ if key:
+ logging.info('instance %s: injecting key into image %s',
+ inst['name'], inst.image_id)
+ if net:
+ logging.info('instance %s: injecting net into image %s',
+ inst['name'], inst.image_id)
try:
- yield disk.inject_data(basepath('disk-raw'), key=key, net=net, dns=dns, remove_network_udev=True, partition=target_partition, execute=execute)
+ yield disk.inject_data(basepath('disk-raw'), key, net,
+ partition=target_partition,
+ execute=execute)
except Exception as e:
# This could be a windows image, or a vmdk format disk
- logging.warn('Could not inject data; ignoring. (%s)' % e)
+ logging.warn('instance %s: ignoring error injecting data'
+ ' into image %s (%s)',
+ inst['name'], inst.image_id, e)
if using_kernel:
if os.path.exists(basepath('disk')):
yield process.simple_execute('rm -f %s' % basepath('disk'))
- bytes = (instance_types.INSTANCE_TYPES[data['instance_type']]['local_gb']
+ bytes = (instance_types.INSTANCE_TYPES[inst.instance_type]['local_gb']
* 1024 * 1024 * 1024)
yield disk.partition(
basepath('disk-raw'), basepath('disk'), bytes, execute=execute)
- def basepath(self, instance, path=''):
- return os.path.abspath(os.path.join(instance.datamodel['basepath'], path))
+ if FLAGS.libvirt_type == 'uml':
+ yield process.simple_execute('sudo chown root %s' %
+ basepath('disk'))
- def toXml(self, instance):
+ def to_xml(self, instance):
# TODO(termie): cache?
- logging.debug("Starting the toXML method")
- template_contents = open(FLAGS.libvirt_xml_template).read()
- xml_info = instance.datamodel.copy()
- # TODO(joshua): Make this xml express the attached disks as well
-
- # TODO(termie): lazy lazy hack because xml is annoying
- xml_info['nova'] = json.dumps(instance.datamodel.copy())
-
+ logging.debug('instance %s: starting toXML method', instance['name'])
+ network = db.project_get_network(None,
+ instance['project_id'])
+ # FIXME(vish): stick this in db
+ instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']]
+ ip_address = db.instance_get_fixed_address({}, instance['id'])
+ # Assume that the gateway also acts as the dhcp server.
+ dhcp_server = network['gateway']
+ xml_info = {'type': FLAGS.libvirt_type,
+ 'name': instance['name'],
+ 'basepath': os.path.join(FLAGS.instances_path,
+ instance['name']),
+ 'memory_kb': instance_type['memory_mb'] * 1024,
+ 'vcpus': instance_type['vcpus'],
+ 'bridge_name': network['bridge'],
+ 'mac_address': instance['mac_address'],
+ 'ip_address': ip_address,
+ 'dhcp_server': dhcp_server }
+
if xml_info['kernel_id']:
xml_info['kernel'] = xml_info['basepath'] + "/kernel"
@@ -293,16 +414,13 @@ class LibvirtConnection(object):
else:
xml_info['disk'] = xml_info['basepath'] + "/disk-raw"
- xml_info['type'] = FLAGS.libvirt_type
-
- libvirt_xml = str(Template(template_contents, searchList=[ xml_info ] ))
-
- logging.debug("Finished the toXML method")
+ xml = str(Template(self.libvirt_xml, searchList=[ xml_info ] ))
+ logging.debug('instance %s: finished toXML method', instance['name'])
- return libvirt_xml
+ return xml
- def get_info(self, instance_id):
- virt_dom = self._conn.lookupByName(instance_id)
+ def get_info(self, instance_name):
+ virt_dom = self._conn.lookupByName(instance_name)
(state, max_mem, mem, num_cpu, cpu_time) = virt_dom.info()
return {'state': state,
'max_mem': max_mem,
@@ -310,8 +428,14 @@ class LibvirtConnection(object):
'num_cpu': num_cpu,
'cpu_time': cpu_time}
- def get_disks(self, instance_id):
- domain = self._conn.lookupByName(instance_id)
+ def get_disks(self, instance_name):
+ """
+ Note that this function takes an instance name, not an Instance, so
+ that it can be called by monitor.
+
+ Returns a list of all block devices for this domain.
+ """
+ domain = self._conn.lookupByName(instance_name)
# TODO(devcamcar): Replace libxml2 with etree.
xml = domain.XMLDesc(0)
doc = None
@@ -346,8 +470,14 @@ class LibvirtConnection(object):
return disks
- def get_interfaces(self, instance_id):
- domain = self._conn.lookupByName(instance_id)
+ def get_interfaces(self, instance_name):
+ """
+ Note that this function takes an instance name, not an Instance, so
+ that it can be called by monitor.
+
+ Returns a list of all network interfaces for this instance.
+ """
+ domain = self._conn.lookupByName(instance_name)
# TODO(devcamcar): Replace libxml2 with etree.
xml = domain.XMLDesc(0)
doc = None
@@ -382,10 +512,210 @@ class LibvirtConnection(object):
return interfaces
- def block_stats(self, instance_id, disk):
- domain = self._conn.lookupByName(instance_id)
+ def block_stats(self, instance_name, disk):
+ """
+ Note that this function takes an instance name, not an Instance, so
+ that it can be called by monitor.
+ """
+ domain = self._conn.lookupByName(instance_name)
return domain.blockStats(disk)
- def interface_stats(self, instance_id, interface):
- domain = self._conn.lookupByName(instance_id)
+ def interface_stats(self, instance_name, interface):
+ """
+ Note that this function takes an instance name, not an Instance, so
+ that it can be called by monitor.
+ """
+ domain = self._conn.lookupByName(instance_name)
return domain.interfaceStats(interface)
+
+
+ def refresh_security_group(self, security_group_id):
+ fw = NWFilterFirewall(self._conn)
+ fw.ensure_security_group_filter(security_group_id)
+
+
+class NWFilterFirewall(object):
+ """
+ This class implements a network filtering mechanism versatile
+ enough for EC2 style Security Group filtering by leveraging
+ libvirt's nwfilter.
+
+ First, all instances get a filter ("nova-base-filter") applied.
+ This filter drops all incoming ipv4 and ipv6 connections.
+ Outgoing connections are never blocked.
+
+ Second, every security group maps to a nwfilter filter(*).
+ NWFilters can be updated at runtime and changes are applied
+ immediately, so changes to security groups can be applied at
+ runtime (as mandated by the spec).
+
+ Security group rules are named "nova-secgroup-<id>" where <id>
+ is the internal id of the security group. They're applied only on
+ hosts that have instances in the security group in question.
+
+ Updates to security groups are done by updating the data model
+ (in response to API calls) followed by a request sent to all
+ the nodes with instances in the security group to refresh the
+ security group.
+
+ Each instance has its own NWFilter, which references the above
+ mentioned security group NWFilters. This was done because
+ interfaces can only reference one filter while filters can
+ reference multiple other filters. This has the added benefit of
+ actually being able to add and remove security groups from an
+ instance at run time. This functionality is not exposed anywhere,
+ though.
+
+ Outstanding questions:
+
+ The name is unique, so would there be any good reason to sync
+ the uuid across the nodes (by assigning it from the datamodel)?
+
+
+ (*) This sentence brought to you by the redundancy department of
+ redundancy.
+ """
+
+ def __init__(self, get_connection):
+ self._conn = get_connection
+
+
+ nova_base_filter = '''<filter name='nova-base' chain='root'>
+ <uuid>26717364-50cf-42d1-8185-29bf893ab110</uuid>
+ <filterref filter='no-mac-spoofing'/>
+ <filterref filter='no-ip-spoofing'/>
+ <filterref filter='no-arp-spoofing'/>
+ <filterref filter='allow-dhcp-server'/>
+ <filterref filter='nova-allow-dhcp-server'/>
+ <filterref filter='nova-base-ipv4'/>
+ <filterref filter='nova-base-ipv6'/>
+ </filter>'''
+
+ nova_dhcp_filter = '''<filter name='nova-allow-dhcp-server' chain='ipv4'>
+ <uuid>891e4787-e5c0-d59b-cbd6-41bc3c6b36fc</uuid>
+ <rule action='accept' direction='out'
+ priority='100'>
+ <udp srcipaddr='0.0.0.0'
+ dstipaddr='255.255.255.255'
+ srcportstart='68'
+ dstportstart='67'/>
+ </rule>
+ <rule action='accept' direction='in' priority='100'>
+ <udp srcipaddr='$DHCPSERVER'
+ srcportstart='67'
+ dstportstart='68'/>
+ </rule>
+ </filter>'''
+
+ def nova_base_ipv4_filter(self):
+ retval = "<filter name='nova-base-ipv4' chain='ipv4'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept', 399),
+ ('inout','drop', 400)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def nova_base_ipv6_filter(self):
+ retval = "<filter name='nova-base-ipv6' chain='ipv6'>"
+ for protocol in ['tcp', 'udp', 'icmp']:
+ for direction,action,priority in [('out','accept',399),
+ ('inout','drop',400)]:
+ retval += """<rule action='%s' direction='%s' priority='%d'>
+ <%s-ipv6 />
+ </rule>""" % (action, direction,
+ priority, protocol)
+ retval += '</filter>'
+ return retval
+
+
+ def nova_project_filter(self, project, net, mask):
+ retval = "<filter name='nova-project-%s' chain='ipv4'>" % project
+ for protocol in ['tcp', 'udp', 'icmp']:
+ retval += """<rule action='accept' direction='in' priority='200'>
+ <%s srcipaddr='%s' srcipmask='%s' />
+ </rule>""" % (protocol, net, mask)
+ retval += '</filter>'
+ return retval
+
+
+ def _define_filter(self, xml):
+ if callable(xml):
+ xml = xml()
+ d = threads.deferToThread(self._conn.nwfilterDefineXML, xml)
+ return d
+
+
+ @staticmethod
+ def _get_net_and_mask(cidr):
+ net = IPy.IP(cidr)
+ return str(net.net()), str(net.netmask())
+
+ @defer.inlineCallbacks
+ def setup_nwfilters_for_instance(self, instance):
+ """
+ Creates an NWFilter for the given instance. In the process,
+ it makes sure the filters for the security groups as well as
+ the base filter are all in place.
+ """
+
+ yield self._define_filter(self.nova_base_ipv4_filter)
+ yield self._define_filter(self.nova_base_ipv6_filter)
+ yield self._define_filter(self.nova_dhcp_filter)
+ yield self._define_filter(self.nova_base_filter)
+
+ nwfilter_xml = ("<filter name='nova-instance-%s' chain='root'>\n" +
+ " <filterref filter='nova-base' />\n"
+ ) % instance['name']
+
+ if FLAGS.allow_project_net_traffic:
+ network_ref = db.project_get_network({}, instance['project_id'])
+ net, mask = self._get_net_and_mask(network_ref['cidr'])
+ project_filter = self.nova_project_filter(instance['project_id'],
+ net, mask)
+ yield self._define_filter(project_filter)
+
+ nwfilter_xml += (" <filterref filter='nova-project-%s' />\n"
+ ) % instance['project_id']
+
+ for security_group in instance.security_groups:
+ yield self.ensure_security_group_filter(security_group['id'])
+
+ nwfilter_xml += (" <filterref filter='nova-secgroup-%d' />\n"
+ ) % security_group['id']
+ nwfilter_xml += "</filter>"
+
+ yield self._define_filter(nwfilter_xml)
+ return
+
+ def ensure_security_group_filter(self, security_group_id):
+ return self._define_filter(
+ self.security_group_to_nwfilter_xml(security_group_id))
+
+
+ def security_group_to_nwfilter_xml(self, security_group_id):
+ security_group = db.security_group_get({}, security_group_id)
+ rule_xml = ""
+ for rule in security_group.rules:
+ rule_xml += "<rule action='accept' direction='in' priority='300'>"
+ if rule.cidr:
+ net, mask = self._get_net_and_mask(rule.cidr)
+ rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask)
+ if rule.protocol in ['tcp', 'udp']:
+ rule_xml += "dstportstart='%s' dstportend='%s' " % \
+ (rule.from_port, rule.to_port)
+ elif rule.protocol == 'icmp':
+ logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port))
+ if rule.from_port != -1:
+ rule_xml += "type='%s' " % rule.from_port
+ if rule.to_port != -1:
+ rule_xml += "code='%s' " % rule.to_port
+
+ rule_xml += '/>\n'
+ rule_xml += "</rule>\n"
+ xml = '''<filter name='nova-secgroup-%s' chain='ipv4'>%s</filter>''' % (security_group_id, rule_xml,)
+ return xml
diff --git a/nova/virt/xenapi.py b/nova/virt/xenapi.py
index 2f5994983..04e830b64 100644
--- a/nova/virt/xenapi.py
+++ b/nova/virt/xenapi.py
@@ -16,18 +16,38 @@
"""
A connection to XenServer or Xen Cloud Platform.
+
+The concurrency model for this class is as follows:
+
+All XenAPI calls are on a thread (using t.i.t.deferToThread, via the decorator
+deferredToThread). They are remote calls, and so may hang for the usual
+reasons. They should not be allowed to block the reactor thread.
+
+All long-running XenAPI calls (VM.start, VM.reboot, etc) are called async
+(using XenAPI.VM.async_start etc). These return a task, which can then be
+polled for completion. Polling is handled using reactor.callLater.
+
+This combination of techniques means that we don't block the reactor thread at
+all, and at the same time we don't hold lots of threads waiting for
+long-running operations.
+
+FIXME: get_info currently doesn't conform to these rules, and will block the
+reactor thread if the VM.get_by_name_label or VM.get_record calls block.
"""
import logging
import xmlrpclib
from twisted.internet import defer
+from twisted.internet import reactor
from twisted.internet import task
-from nova import exception
+from nova import db
from nova import flags
from nova import process
+from nova import utils
from nova.auth.manager import AuthManager
+from nova.compute import instance_types
from nova.compute import power_state
from nova.virt import images
@@ -47,6 +67,11 @@ flags.DEFINE_string('xenapi_connection_password',
None,
'Password for connection to XenServer/Xen Cloud Platform.'
' Used only if connection_type=xenapi.')
+flags.DEFINE_float('xenapi_task_poll_interval',
+ 0.5,
+ 'The interval used for polling of remote tasks '
+ '(Async.VM.start, etc). Used only if '
+ 'connection_type=xenapi.')
XENAPI_POWER_STATE = {
@@ -80,48 +105,46 @@ class XenAPIConnection(object):
self._conn.login_with_password(user, pw)
def list_instances(self):
- result = [self._conn.xenapi.VM.get_name_label(vm) \
- for vm in self._conn.xenapi.VM.get_all()]
+ return [self._conn.xenapi.VM.get_name_label(vm) \
+ for vm in self._conn.xenapi.VM.get_all()]
@defer.inlineCallbacks
- @exception.wrap_exception
def spawn(self, instance):
- vm = yield self.lookup(instance.name)
+ vm = yield self._lookup(instance.name)
if vm is not None:
raise Exception('Attempted to create non-unique name %s' %
instance.name)
- if 'bridge_name' in instance.datamodel:
- network_ref = \
- yield self._find_network_with_bridge(
- instance.datamodel['bridge_name'])
- else:
- network_ref = None
-
- if 'mac_address' in instance.datamodel:
- mac_address = instance.datamodel['mac_address']
- else:
- mac_address = ''
-
- user = AuthManager().get_user(instance.datamodel['user_id'])
- project = AuthManager().get_project(instance.datamodel['project_id'])
- vdi_uuid = yield self.fetch_image(
- instance.datamodel['image_id'], user, project, True)
- kernel = yield self.fetch_image(
- instance.datamodel['kernel_id'], user, project, False)
- ramdisk = yield self.fetch_image(
- instance.datamodel['ramdisk_id'], user, project, False)
- vdi_ref = yield self._conn.xenapi.VDI.get_by_uuid(vdi_uuid)
-
- vm_ref = yield self.create_vm(instance, kernel, ramdisk)
- yield self.create_vbd(vm_ref, vdi_ref, 0, True)
+ network = db.project_get_network(None, instance.project_id)
+ network_ref = \
+ yield self._find_network_with_bridge(network.bridge)
+
+ user = AuthManager().get_user(instance.user_id)
+ project = AuthManager().get_project(instance.project_id)
+ vdi_uuid = yield self._fetch_image(
+ instance.image_id, user, project, True)
+ kernel = yield self._fetch_image(
+ instance.kernel_id, user, project, False)
+ ramdisk = yield self._fetch_image(
+ instance.ramdisk_id, user, project, False)
+ vdi_ref = yield self._call_xenapi('VDI.get_by_uuid', vdi_uuid)
+
+ vm_ref = yield self._create_vm(instance, kernel, ramdisk)
+ yield self._create_vbd(vm_ref, vdi_ref, 0, True)
if network_ref:
- yield self._create_vif(vm_ref, network_ref, mac_address)
- yield self._conn.xenapi.VM.start(vm_ref, False, False)
+ yield self._create_vif(vm_ref, network_ref, instance.mac_address)
+ logging.debug('Starting VM %s...', vm_ref)
+ yield self._call_xenapi('VM.start', vm_ref, False, False)
+ logging.info('Spawning VM %s created %s.', instance.name, vm_ref)
- def create_vm(self, instance, kernel, ramdisk):
- mem = str(long(instance.datamodel['memory_kb']) * 1024)
- vcpus = str(instance.datamodel['vcpus'])
+ @defer.inlineCallbacks
+ def _create_vm(self, instance, kernel, ramdisk):
+ """Create a VM record. Returns a Deferred that gives the new
+ VM reference."""
+
+ instance_type = instance_types.INSTANCE_TYPES[instance.instance_type]
+ mem = str(long(instance_type['memory_mb']) * 1024 * 1024)
+ vcpus = str(instance_type['vcpus'])
rec = {
'name_label': instance.name,
'name_description': '',
@@ -152,11 +175,15 @@ class XenAPIConnection(object):
'other_config': {},
}
logging.debug('Created VM %s...', instance.name)
- vm_ref = self._conn.xenapi.VM.create(rec)
+ vm_ref = yield self._call_xenapi('VM.create', rec)
logging.debug('Created VM %s as %s.', instance.name, vm_ref)
- return vm_ref
+ defer.returnValue(vm_ref)
- def create_vbd(self, vm_ref, vdi_ref, userdevice, bootable):
+ @defer.inlineCallbacks
+ def _create_vbd(self, vm_ref, vdi_ref, userdevice, bootable):
+ """Create a VBD record. Returns a Deferred that gives the new
+ VBD reference."""
+
vbd_rec = {}
vbd_rec['VM'] = vm_ref
vbd_rec['VDI'] = vdi_ref
@@ -171,12 +198,16 @@ class XenAPIConnection(object):
vbd_rec['qos_algorithm_params'] = {}
vbd_rec['qos_supported_algorithms'] = []
logging.debug('Creating VBD for VM %s, VDI %s ... ', vm_ref, vdi_ref)
- vbd_ref = self._conn.xenapi.VBD.create(vbd_rec)
+ vbd_ref = yield self._call_xenapi('VBD.create', vbd_rec)
logging.debug('Created VBD %s for VM %s, VDI %s.', vbd_ref, vm_ref,
vdi_ref)
- return vbd_ref
+ defer.returnValue(vbd_ref)
+ @defer.inlineCallbacks
def _create_vif(self, vm_ref, network_ref, mac_address):
+ """Create a VIF record. Returns a Deferred that gives the new
+ VIF reference."""
+
vif_rec = {}
vif_rec['device'] = '0'
vif_rec['network']= network_ref
@@ -188,25 +219,29 @@ class XenAPIConnection(object):
vif_rec['qos_algorithm_params'] = {}
logging.debug('Creating VIF for VM %s, network %s ... ', vm_ref,
network_ref)
- vif_ref = self._conn.xenapi.VIF.create(vif_rec)
+ vif_ref = yield self._call_xenapi('VIF.create', vif_rec)
logging.debug('Created VIF %s for VM %s, network %s.', vif_ref,
vm_ref, network_ref)
- return vif_ref
+ defer.returnValue(vif_ref)
+ @defer.inlineCallbacks
def _find_network_with_bridge(self, bridge):
expr = 'field "bridge" = "%s"' % bridge
- networks = self._conn.xenapi.network.get_all_records_where(expr)
+ networks = yield self._call_xenapi('network.get_all_records_where',
+ expr)
if len(networks) == 1:
- return networks.keys()[0]
+ defer.returnValue(networks.keys()[0])
elif len(networks) > 1:
raise Exception('Found non-unique network for bridge %s' % bridge)
else:
raise Exception('Found no network for bridge %s' % bridge)
- def fetch_image(self, image, user, project, use_sr):
+ @defer.inlineCallbacks
+ def _fetch_image(self, image, user, project, use_sr):
"""use_sr: True to put the image as a VDI in an SR, False to place
it on dom0's filesystem. The former is for VM disks, the latter for
- its kernel and ramdisk (if external kernels are being used)."""
+ its kernel and ramdisk (if external kernels are being used).
+ Returns a Deferred that gives the new VDI UUID."""
url = images.image_url(image)
access = AuthManager().get_access_key(user, project)
@@ -218,22 +253,38 @@ class XenAPIConnection(object):
args['password'] = user.secret
if use_sr:
args['add_partition'] = 'true'
- return self._call_plugin('objectstore', fn, args)
+ task = yield self._async_call_plugin('objectstore', fn, args)
+ uuid = yield self._wait_for_task(task)
+ defer.returnValue(uuid)
+ @defer.inlineCallbacks
def reboot(self, instance):
- vm = self.lookup(instance.name)
+ vm = yield self._lookup(instance.name)
if vm is None:
raise Exception('instance not present %s' % instance.name)
- yield self._conn.xenapi.VM.clean_reboot(vm)
+ task = yield self._call_xenapi('Async.VM.clean_reboot', vm)
+ yield self._wait_for_task(task)
+ @defer.inlineCallbacks
def destroy(self, instance):
- vm = self.lookup(instance.name)
+ vm = yield self._lookup(instance.name)
if vm is None:
- raise Exception('instance not present %s' % instance.name)
- yield self._conn.xenapi.VM.destroy(vm)
+ # Don't complain, just return. This lets us clean up instances
+ # that have already disappeared from the underlying platform.
+ defer.returnValue(None)
+ try:
+ task = yield self._call_xenapi('Async.VM.hard_shutdown', vm)
+ yield self._wait_for_task(task)
+ except Exception, exc:
+ logging.warn(exc)
+ try:
+ task = yield self._call_xenapi('Async.VM.destroy', vm)
+ yield self._wait_for_task(task)
+ except Exception, exc:
+ logging.warn(exc)
def get_info(self, instance_id):
- vm = self.lookup(instance_id)
+ vm = self._lookup_blocking(instance_id)
if vm is None:
raise Exception('instance not present %s' % instance_id)
rec = self._conn.xenapi.VM.get_record(vm)
@@ -243,7 +294,14 @@ class XenAPIConnection(object):
'num_cpu': rec['VCPUs_max'],
'cpu_time': 0}
- def lookup(self, i):
+ def get_console_output(self, instance):
+ return 'FAKE CONSOLE OUTPUT'
+
+ @utils.deferredToThread
+ def _lookup(self, i):
+ return self._lookup_blocking(i)
+
+ def _lookup_blocking(self, i):
vms = self._conn.xenapi.VM.get_by_name_label(i)
n = len(vms)
if n == 0:
@@ -253,9 +311,52 @@ class XenAPIConnection(object):
else:
return vms[0]
- def _call_plugin(self, plugin, fn, args):
+ def _wait_for_task(self, task):
+ """Return a Deferred that will give the result of the given task.
+ The task is polled until it completes."""
+ d = defer.Deferred()
+ reactor.callLater(0, self._poll_task, task, d)
+ return d
+
+ @utils.deferredToThread
+ def _poll_task(self, task, deferred):
+ """Poll the given XenAPI task, and fire the given Deferred if we
+ get a result."""
+ try:
+ #logging.debug('Polling task %s...', task)
+ status = self._conn.xenapi.task.get_status(task)
+ if status == 'pending':
+ reactor.callLater(FLAGS.xenapi_task_poll_interval,
+ self._poll_task, task, deferred)
+ elif status == 'success':
+ result = self._conn.xenapi.task.get_result(task)
+ logging.info('Task %s status: success. %s', task, result)
+ deferred.callback(_parse_xmlrpc_value(result))
+ else:
+ error_info = self._conn.xenapi.task.get_error_info(task)
+ logging.warn('Task %s status: %s. %s', task, status,
+ error_info)
+ deferred.errback(XenAPI.Failure(error_info))
+ #logging.debug('Polling task %s done.', task)
+ except Exception, exc:
+ logging.warn(exc)
+ deferred.errback(exc)
+
+ @utils.deferredToThread
+ def _call_xenapi(self, method, *args):
+ """Call the specified XenAPI method on a background thread. Returns
+ a Deferred for the result."""
+ f = self._conn.xenapi
+ for m in method.split('.'):
+ f = f.__getattr__(m)
+ return f(*args)
+
+ @utils.deferredToThread
+ def _async_call_plugin(self, plugin, fn, args):
+ """Call Async.host.call_plugin on a background thread. Returns a
+ Deferred with the task reference."""
return _unwrap_plugin_exceptions(
- self._conn.xenapi.host.call_plugin,
+ self._conn.xenapi.Async.host.call_plugin,
self._get_xenapi_host(), plugin, fn, args)
def _get_xenapi_host(self):
@@ -265,19 +366,31 @@ class XenAPIConnection(object):
def _unwrap_plugin_exceptions(func, *args, **kwargs):
try:
return func(*args, **kwargs)
- except XenAPI.Failure, exn:
- logging.debug("Got exception: %s", exn)
- if (len(exn.details) == 4 and
- exn.details[0] == 'XENAPI_PLUGIN_EXCEPTION' and
- exn.details[2] == 'Failure'):
+ except XenAPI.Failure, exc:
+ logging.debug("Got exception: %s", exc)
+ if (len(exc.details) == 4 and
+ exc.details[0] == 'XENAPI_PLUGIN_EXCEPTION' and
+ exc.details[2] == 'Failure'):
params = None
try:
- params = eval(exn.details[3])
+ params = eval(exc.details[3])
except:
- raise exn
+ raise exc
raise XenAPI.Failure(params)
else:
raise
- except xmlrpclib.ProtocolError, exn:
- logging.debug("Got exception: %s", exn)
+ except xmlrpclib.ProtocolError, exc:
+ logging.debug("Got exception: %s", exc)
raise
+
+
+def _parse_xmlrpc_value(val):
+ """Parse the given value as if it were an XML-RPC value. This is
+ sometimes used as the format for the task.result field."""
+ if not val:
+ return val
+ x = xmlrpclib.loads(
+ '<?xml version="1.0"?><methodResponse><params><param>' +
+ val +
+ '</param></params></methodResponse>')
+ return x[0][0]
diff --git a/nova/volume/driver.py b/nova/volume/driver.py
new file mode 100644
index 000000000..cca619550
--- /dev/null
+++ b/nova/volume/driver.py
@@ -0,0 +1,136 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Drivers for volumes
+"""
+
+import logging
+
+from twisted.internet import defer
+
+from nova import exception
+from nova import flags
+from nova import process
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('volume_group', 'nova-volumes',
+ 'Name for the VG that will contain exported volumes')
+flags.DEFINE_string('aoe_eth_dev', 'eth0',
+ 'Which device to export the volumes on')
+flags.DEFINE_string('num_shell_tries', 3,
+ 'number of times to attempt to run flakey shell commands')
+
+
+class AOEDriver(object):
+ """Executes commands relating to AOE volumes"""
+ def __init__(self, execute=process.simple_execute, *args, **kwargs):
+ self._execute = execute
+
+ @defer.inlineCallbacks
+ def _try_execute(self, command):
+ # NOTE(vish): Volume commands can partially fail due to timing, but
+ # running them a second time on failure will usually
+ # recover nicely.
+ tries = 0
+ while True:
+ try:
+ yield self._execute(command)
+ defer.returnValue(True)
+ except exception.ProcessExecutionError:
+ tries = tries + 1
+ if tries >= FLAGS.num_shell_tries:
+ raise
+ logging.exception("Recovering from a failed execute."
+ "Try number %s", tries)
+ yield self._execute("sleep %s" % tries ** 2)
+
+
+ @defer.inlineCallbacks
+ def create_volume(self, volume_name, size):
+ """Creates a logical volume"""
+ # NOTE(vish): makes sure that the volume group exists
+ yield self._execute("vgs %s" % FLAGS.volume_group)
+ if int(size) == 0:
+ sizestr = '100M'
+ else:
+ sizestr = '%sG' % size
+ yield self._try_execute("sudo lvcreate -L %s -n %s %s" %
+ (sizestr,
+ volume_name,
+ FLAGS.volume_group))
+
+ @defer.inlineCallbacks
+ def delete_volume(self, volume_name):
+ """Deletes a logical volume"""
+ yield self._try_execute("sudo lvremove -f %s/%s" %
+ (FLAGS.volume_group,
+ volume_name))
+
+ @defer.inlineCallbacks
+ def create_export(self, volume_name, shelf_id, blade_id):
+ """Creates an export for a logical volume"""
+ yield self._try_execute(
+ "sudo vblade-persist setup %s %s %s /dev/%s/%s" %
+ (shelf_id,
+ blade_id,
+ FLAGS.aoe_eth_dev,
+ FLAGS.volume_group,
+ volume_name))
+
+ @defer.inlineCallbacks
+ def discover_volume(self, _volume_name):
+ """Discover volume on a remote host"""
+ yield self._execute("sudo aoe-discover")
+ yield self._execute("sudo aoe-stat")
+
+ @defer.inlineCallbacks
+ def remove_export(self, _volume_name, shelf_id, blade_id):
+ """Removes an export for a logical volume"""
+ yield self._try_execute("sudo vblade-persist stop %s %s" %
+ (shelf_id, blade_id))
+ yield self._try_execute("sudo vblade-persist destroy %s %s" %
+ (shelf_id, blade_id))
+
+ @defer.inlineCallbacks
+ def ensure_exports(self):
+ """Runs all existing exports"""
+ # NOTE(vish): The standard _try_execute does not work here
+ # because these methods throw errors if other
+ # volumes on this host are in the process of
+ # being created. The good news is the command
+ # still works for the other volumes, so we
+ # just wait a bit for the current volume to
+ # be ready and ignore any errors.
+ yield self._execute("sleep 2")
+ yield self._execute("sudo vblade-persist auto all",
+ check_exit_code=False)
+ yield self._execute("sudo vblade-persist start all",
+ check_exit_code=False)
+
+
+class FakeAOEDriver(AOEDriver):
+ """Logs calls instead of executing"""
+ def __init__(self, *args, **kwargs):
+ super(FakeAOEDriver, self).__init__(self.fake_execute)
+
+ @staticmethod
+ def fake_execute(cmd, *_args, **_kwargs):
+ """Execute that simply logs the command"""
+ logging.debug("FAKE AOE: %s", cmd)
diff --git a/nova/volume/manager.py b/nova/volume/manager.py
new file mode 100644
index 000000000..081a2d695
--- /dev/null
+++ b/nova/volume/manager.py
@@ -0,0 +1,132 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Volume manager manages creating, attaching, detaching, and
+destroying persistent storage volumes, ala EBS.
+"""
+
+import logging
+import datetime
+
+from twisted.internet import defer
+
+from nova import exception
+from nova import flags
+from nova import manager
+from nova import utils
+
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('storage_availability_zone',
+ 'nova',
+ 'availability zone of this service')
+flags.DEFINE_string('volume_driver', 'nova.volume.driver.AOEDriver',
+ 'Driver to use for volume creation')
+flags.DEFINE_integer('num_shelves',
+ 100,
+ 'Number of vblade shelves')
+flags.DEFINE_integer('blades_per_shelf',
+ 16,
+ 'Number of vblade blades per shelf')
+
+
+class AOEManager(manager.Manager):
+ """Manages Ata-Over_Ethernet volumes"""
+ def __init__(self, volume_driver=None, *args, **kwargs):
+ if not volume_driver:
+ volume_driver = FLAGS.volume_driver
+ self.driver = utils.import_object(volume_driver)
+ super(AOEManager, self).__init__(*args, **kwargs)
+
+ def _ensure_blades(self, context):
+ """Ensure that blades have been created in datastore"""
+ total_blades = FLAGS.num_shelves * FLAGS.blades_per_shelf
+ if self.db.export_device_count(context) >= total_blades:
+ return
+ for shelf_id in xrange(FLAGS.num_shelves):
+ for blade_id in xrange(FLAGS.blades_per_shelf):
+ dev = {'shelf_id': shelf_id, 'blade_id': blade_id}
+ self.db.export_device_create_safe(context, dev)
+
+ @defer.inlineCallbacks
+ def create_volume(self, context, volume_id):
+ """Creates and exports the volume"""
+ logging.info("volume %s: creating", volume_id)
+
+ volume_ref = self.db.volume_get(context, volume_id)
+
+ self.db.volume_update(context,
+ volume_id,
+ {'host': self.host})
+
+ size = volume_ref['size']
+ logging.debug("volume %s: creating lv of size %sG", volume_id, size)
+ yield self.driver.create_volume(volume_ref['ec2_id'], size)
+
+ logging.debug("volume %s: allocating shelf & blade", volume_id)
+ self._ensure_blades(context)
+ rval = self.db.volume_allocate_shelf_and_blade(context, volume_id)
+ (shelf_id, blade_id) = rval
+
+ logging.debug("volume %s: exporting shelf %s & blade %s", volume_id,
+ shelf_id, blade_id)
+
+ yield self.driver.create_export(volume_ref['ec2_id'],
+ shelf_id,
+ blade_id)
+
+ logging.debug("volume %s: re-exporting all values", volume_id)
+ yield self.driver.ensure_exports()
+
+ now = datetime.datetime.utcnow()
+ self.db.volume_update(context,
+ volume_ref['id'], {'status': 'available',
+ 'launched_at': now})
+ logging.debug("volume %s: created successfully", volume_id)
+ defer.returnValue(volume_id)
+
+ @defer.inlineCallbacks
+ def delete_volume(self, context, volume_id):
+ """Deletes and unexports volume"""
+ volume_ref = self.db.volume_get(context, volume_id)
+ if volume_ref['attach_status'] == "attached":
+ raise exception.Error("Volume is still attached")
+ if volume_ref['host'] != self.host:
+ raise exception.Error("Volume is not local to this node")
+ logging.debug("Deleting volume with id of: %s", volume_id)
+ shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context,
+ volume_id)
+ yield self.driver.remove_export(volume_ref['ec2_id'],
+ shelf_id,
+ blade_id)
+ yield self.driver.delete_volume(volume_ref['ec2_id'])
+ self.db.volume_destroy(context, volume_id)
+ defer.returnValue(True)
+
+ @defer.inlineCallbacks
+ def setup_compute_volume(self, context, volume_id):
+ """Setup remote volume on compute host
+
+ Returns path to device.
+ """
+ volume_ref = self.db.volume_get(context, volume_id)
+ yield self.driver.discover_volume(volume_ref['ec2_id'])
+ shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context,
+ volume_id)
+ defer.returnValue("/dev/etherd/e%s.%s" % (shelf_id, blade_id))
diff --git a/nova/volume/service.py b/nova/volume/service.py
deleted file mode 100644
index be62f621d..000000000
--- a/nova/volume/service.py
+++ /dev/null
@@ -1,322 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Nova Storage manages creating, attaching, detaching, and
-destroying persistent storage volumes, ala EBS.
-Currently uses Ata-over-Ethernet.
-"""
-
-import logging
-import os
-
-from twisted.internet import defer
-
-from nova import datastore
-from nova import exception
-from nova import flags
-from nova import process
-from nova import service
-from nova import utils
-from nova import validate
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string('storage_dev', '/dev/sdb',
- 'Physical device to use for volumes')
-flags.DEFINE_string('volume_group', 'nova-volumes',
- 'Name for the VG that will contain exported volumes')
-flags.DEFINE_string('aoe_eth_dev', 'eth0',
- 'Which device to export the volumes on')
-flags.DEFINE_integer('first_shelf_id',
- utils.last_octet(utils.get_my_ip()) * 10,
- 'AoE starting shelf_id for this service')
-flags.DEFINE_integer('last_shelf_id',
- utils.last_octet(utils.get_my_ip()) * 10 + 9,
- 'AoE starting shelf_id for this service')
-flags.DEFINE_string('aoe_export_dir',
- '/var/lib/vblade-persist/vblades',
- 'AoE directory where exports are created')
-flags.DEFINE_integer('blades_per_shelf',
- 16,
- 'Number of AoE blades per shelf')
-flags.DEFINE_string('storage_availability_zone',
- 'nova',
- 'availability zone of this service')
-flags.DEFINE_boolean('fake_storage', False,
- 'Should we make real storage volumes to attach?')
-
-
-class NoMoreBlades(exception.Error):
- pass
-
-
-def get_volume(volume_id):
- """ Returns a redis-backed volume object """
- volume_class = Volume
- if FLAGS.fake_storage:
- volume_class = FakeVolume
- vol = volume_class.lookup(volume_id)
- if vol:
- return vol
- raise exception.Error("Volume does not exist")
-
-
-class VolumeService(service.Service):
- """
- There is one VolumeNode running on each host.
- However, each VolumeNode can report on the state of
- *all* volumes in the cluster.
- """
- def __init__(self):
- super(VolumeService, self).__init__()
- self.volume_class = Volume
- if FLAGS.fake_storage:
- self.volume_class = FakeVolume
- self._init_volume_group()
-
- @defer.inlineCallbacks
- @validate.rangetest(size=(0, 1000))
- def create_volume(self, size, user_id, project_id):
- """
- Creates an exported volume (fake or real),
- restarts exports to make it available.
- Volume at this point has size, owner, and zone.
- """
- logging.debug("Creating volume of size: %s" % (size))
- vol = yield self.volume_class.create(size, user_id, project_id)
- logging.debug("restarting exports")
- yield self._restart_exports()
- defer.returnValue(vol['volume_id'])
-
- def by_node(self, node_id):
- """ returns a list of volumes for a node """
- for volume_id in datastore.Redis.instance().smembers('volumes:%s' % (node_id)):
- yield self.volume_class(volume_id=volume_id)
-
- @property
- def all(self):
- """ returns a list of all volumes """
- for volume_id in datastore.Redis.instance().smembers('volumes'):
- yield self.volume_class(volume_id=volume_id)
-
- @defer.inlineCallbacks
- def delete_volume(self, volume_id):
- logging.debug("Deleting volume with id of: %s" % (volume_id))
- vol = get_volume(volume_id)
- if vol['attach_status'] == "attached":
- raise exception.Error("Volume is still attached")
- if vol['node_name'] != FLAGS.node_name:
- raise exception.Error("Volume is not local to this node")
- yield vol.destroy()
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- def _restart_exports(self):
- if FLAGS.fake_storage:
- return
- # NOTE(vish): these commands sometimes sends output to stderr for warnings
- yield process.simple_execute( "sudo vblade-persist auto all",
- terminate_on_stderr=False)
- yield process.simple_execute( "sudo vblade-persist start all",
- terminate_on_stderr=False)
-
- @defer.inlineCallbacks
- def _init_volume_group(self):
- if FLAGS.fake_storage:
- return
- yield process.simple_execute(
- "sudo pvcreate %s" % (FLAGS.storage_dev))
- yield process.simple_execute(
- "sudo vgcreate %s %s" % (FLAGS.volume_group,
- FLAGS.storage_dev))
-
-
-class Volume(datastore.BasicModel):
-
- def __init__(self, volume_id=None):
- self.volume_id = volume_id
- super(Volume, self).__init__()
-
- @property
- def identifier(self):
- return self.volume_id
-
- def default_state(self):
- return {"volume_id": self.volume_id,
- "node_name": "unassigned"}
-
- @classmethod
- @defer.inlineCallbacks
- def create(cls, size, user_id, project_id):
- volume_id = utils.generate_uid('vol')
- vol = cls(volume_id)
- vol['node_name'] = FLAGS.node_name
- vol['size'] = size
- vol['user_id'] = user_id
- vol['project_id'] = project_id
- vol['availability_zone'] = FLAGS.storage_availability_zone
- vol["instance_id"] = 'none'
- vol["mountpoint"] = 'none'
- vol['attach_time'] = 'none'
- vol['status'] = "creating" # creating | available | in-use
- vol['attach_status'] = "detached" # attaching | attached | detaching | detached
- vol['delete_on_termination'] = 'False'
- vol.save()
- yield vol._create_lv()
- yield vol._setup_export()
- # TODO(joshua) - We need to trigger a fanout message for aoe-discover on all the nodes
- vol['status'] = "available"
- vol.save()
- defer.returnValue(vol)
-
- def start_attach(self, instance_id, mountpoint):
- """ """
- self['instance_id'] = instance_id
- self['mountpoint'] = mountpoint
- self['status'] = "in-use"
- self['attach_status'] = "attaching"
- self['attach_time'] = utils.isotime()
- self['delete_on_termination'] = 'False'
- self.save()
-
- def finish_attach(self):
- """ """
- self['attach_status'] = "attached"
- self.save()
-
- def start_detach(self):
- """ """
- self['attach_status'] = "detaching"
- self.save()
-
- def finish_detach(self):
- self['instance_id'] = None
- self['mountpoint'] = None
- self['status'] = "available"
- self['attach_status'] = "detached"
- self.save()
-
- def save(self):
- is_new = self.is_new_record()
- super(Volume, self).save()
- if is_new:
- redis = datastore.Redis.instance()
- key = self.__devices_key
- # TODO(vish): these should be added by admin commands
- more = redis.scard(self._redis_association_name("node",
- self['node_name']))
- if (not redis.exists(key) and not more):
- for shelf_id in range(FLAGS.first_shelf_id,
- FLAGS.last_shelf_id + 1):
- for blade_id in range(FLAGS.blades_per_shelf):
- redis.sadd(key, "%s.%s" % (shelf_id, blade_id))
- self.associate_with("node", self['node_name'])
-
- @defer.inlineCallbacks
- def destroy(self):
- yield self._remove_export()
- yield self._delete_lv()
- self.unassociate_with("node", self['node_name'])
- if self.get('shelf_id', None) and self.get('blade_id', None):
- redis = datastore.Redis.instance()
- key = self.__devices_key
- redis.sadd(key, "%s.%s" % (self['shelf_id'], self['blade_id']))
- super(Volume, self).destroy()
-
- @defer.inlineCallbacks
- def _create_lv(self):
- if str(self['size']) == '0':
- sizestr = '100M'
- else:
- sizestr = '%sG' % self['size']
- yield process.simple_execute(
- "sudo lvcreate -L %s -n %s %s" % (sizestr,
- self['volume_id'],
- FLAGS.volume_group),
- terminate_on_stderr=False)
-
- @defer.inlineCallbacks
- def _delete_lv(self):
- yield process.simple_execute(
- "sudo lvremove -f %s/%s" % (FLAGS.volume_group,
- self['volume_id']),
- terminate_on_stderr=False)
-
- @property
- def __devices_key(self):
- return 'volume_devices:%s' % FLAGS.node_name
-
- @defer.inlineCallbacks
- def _setup_export(self):
- redis = datastore.Redis.instance()
- key = self.__devices_key
- device = redis.spop(key)
- if not device:
- raise NoMoreBlades()
- (shelf_id, blade_id) = device.split('.')
- self['aoe_device'] = "e%s.%s" % (shelf_id, blade_id)
- self['shelf_id'] = shelf_id
- self['blade_id'] = blade_id
- self.save()
- yield self._exec_setup_export()
-
- @defer.inlineCallbacks
- def _exec_setup_export(self):
- yield process.simple_execute(
- "sudo vblade-persist setup %s %s %s /dev/%s/%s" %
- (self['shelf_id'],
- self['blade_id'],
- FLAGS.aoe_eth_dev,
- FLAGS.volume_group,
- self['volume_id']),
- terminate_on_stderr=False)
-
- @defer.inlineCallbacks
- def _remove_export(self):
- if not self.get('shelf_id', None) or not self.get('blade_id', None):
- defer.returnValue(False)
- yield self._exec_remove_export()
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- def _exec_remove_export(self):
- yield process.simple_execute(
- "sudo vblade-persist stop %s %s" % (self['shelf_id'],
- self['blade_id']),
- terminate_on_stderr=False)
- yield process.simple_execute(
- "sudo vblade-persist destroy %s %s" % (self['shelf_id'],
- self['blade_id']),
- terminate_on_stderr=False)
-
-
-class FakeVolume(Volume):
- def _create_lv(self):
- pass
-
- def _exec_setup_export(self):
- fname = os.path.join(FLAGS.aoe_export_dir, self['aoe_device'])
- f = file(fname, "w")
- f.close()
-
- def _exec_remove_export(self):
- os.unlink(os.path.join(FLAGS.aoe_export_dir, self['aoe_device']))
-
- def _delete_lv(self):
- pass
diff --git a/nova/wsgi.py b/nova/wsgi.py
index fd87afe6e..b91d91121 100644
--- a/nova/wsgi.py
+++ b/nova/wsgi.py
@@ -21,14 +21,17 @@
Utility methods for working with WSGI servers
"""
+import json
import logging
import sys
+from xml.dom import minidom
import eventlet
import eventlet.wsgi
eventlet.patcher.monkey_patch(all=False, socket=True)
import routes
import routes.middleware
+import webob
import webob.dec
import webob.exc
@@ -83,7 +86,7 @@ class Application(object):
raise NotImplementedError("You must implement __call__")
-class Middleware(Application): # pylint: disable-msg=W0223
+class Middleware(Application):
"""
Base WSGI middleware wrapper. These classes require an application to be
initialized that will be called next. By default the middleware will
@@ -95,7 +98,7 @@ class Middleware(Application): # pylint: disable-msg=W0223
self.application = application
@webob.dec.wsgify
- def __call__(self, req):
+ def __call__(self, req): # pylint: disable-msg=W0221
"""Override to implement middleware behavior."""
return self.application
@@ -113,7 +116,7 @@ class Debug(Middleware):
resp = req.get_response(self.application)
print ("*" * 40) + " RESPONSE HEADERS"
- for (key, value) in resp.headers:
+ for (key, value) in resp.headers.iteritems():
print key, "=", value
print
@@ -127,7 +130,7 @@ class Debug(Middleware):
Iterator that prints the contents of a wrapper string iterator
when iterated.
"""
- print ("*" * 40) + "BODY"
+ print ("*" * 40) + " BODY"
for part in app_iter:
sys.stdout.write(part)
sys.stdout.flush()
@@ -176,8 +179,9 @@ class Router(object):
"""
return self._router
+ @staticmethod
@webob.dec.wsgify
- def _dispatch(self, req):
+ def _dispatch(req):
"""
Called by self._router after matching the incoming request to a route
and putting the information into req.environ. Either returns 404
@@ -195,8 +199,10 @@ class Controller(object):
WSGI app that reads routing information supplied by RoutesMiddleware
and calls the requested action method upon itself. All action methods
must, in addition to their normal parameters, accept a 'req' argument
- which is the incoming webob.Request.
+ which is the incoming webob.Request. They raise a webob.exc exception,
+ or return a dict which will be serialized by requested content type.
"""
+
@webob.dec.wsgify
def __call__(self, req):
"""
@@ -208,12 +214,35 @@ class Controller(object):
del arg_dict['controller']
del arg_dict['action']
arg_dict['req'] = req
- return method(**arg_dict)
+ result = method(**arg_dict)
+ if type(result) is dict:
+ return self._serialize(result, req)
+ else:
+ return result
+ def _serialize(self, data, request):
+ """
+ Serialize the given dict to the response type requested in request.
+ Uses self._serialization_metadata if it exists, which is a dict mapping
+ MIME types to information needed to serialize to that type.
+ """
+ _metadata = getattr(type(self), "_serialization_metadata", {})
+ serializer = Serializer(request.environ, _metadata)
+ return serializer.to_content_type(data)
+
+ def _deserialize(self, data, request):
+ """
+ Deserialize the request body to the response type requested in request.
+ Uses self._serialization_metadata if it exists, which is a dict mapping
+ MIME types to information needed to serialize to that type.
+ """
+ _metadata = getattr(type(self), "_serialization_metadata", {})
+ serializer = Serializer(request.environ, _metadata)
+ return serializer.deserialize(data)
class Serializer(object):
"""
- Serializes a dictionary to a Content Type specified by a WSGI environment.
+ Serializes and deserializes dictionaries to certain MIME types.
"""
def __init__(self, environ, metadata=None):
@@ -222,33 +251,83 @@ class Serializer(object):
'metadata' is an optional dict mapping MIME types to information
needed to serialize a dictionary to that type.
"""
- self.environ = environ
self.metadata = metadata or {}
+ req = webob.Request(environ)
+ suffix = req.path_info.split('.')[-1].lower()
+ if suffix == 'json':
+ self.handler = self._to_json
+ elif suffix == 'xml':
+ self.handler = self._to_xml
+ elif 'application/json' in req.accept:
+ self.handler = self._to_json
+ elif 'application/xml' in req.accept:
+ self.handler = self._to_xml
+ else:
+ self.handler = self._to_json # default
def to_content_type(self, data):
"""
- Serialize a dictionary into a string. The format of the string
- will be decided based on the Content Type requested in self.environ:
- by Accept: header, or by URL suffix.
+ Serialize a dictionary into a string.
+
+ The format of the string will be decided based on the Content Type
+ requested in self.environ: by Accept: header, or by URL suffix.
+ """
+ return self.handler(data)
+
+ def deserialize(self, datastring):
+ """
+ Deserialize a string to a dictionary.
+
+ The string must be in the format of a supported MIME type.
+ """
+ datastring = datastring.strip()
+ try:
+ is_xml = (datastring[0] == '<')
+ if not is_xml:
+ return json.loads(datastring)
+ return self._from_xml(datastring)
+ except:
+ return None
+
+ def _from_xml(self, datastring):
+ xmldata = self.metadata.get('application/xml', {})
+ plurals = set(xmldata.get('plurals', {}))
+ node = minidom.parseString(datastring).childNodes[0]
+ return {node.nodeName: self._from_xml_node(node, plurals)}
+
+ def _from_xml_node(self, node, listnames):
+ """
+ Convert a minidom node to a simple Python type.
+
+ listnames is a collection of names of XML nodes whose subnodes should
+ be considered list items.
"""
- mimetype = 'application/xml'
- # TODO(gundlach): determine mimetype from request
-
- if mimetype == 'application/json':
- import json
- return json.dumps(data)
- elif mimetype == 'application/xml':
- metadata = self.metadata.get('application/xml', {})
- # We expect data to contain a single key which is the XML root.
- root_key = data.keys()[0]
- from xml.dom import minidom
- doc = minidom.Document()
- node = self._to_xml_node(doc, metadata, root_key, data[root_key])
- return node.toprettyxml(indent=' ')
+ if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3:
+ return node.childNodes[0].nodeValue
+ elif node.nodeName in listnames:
+ return [self._from_xml_node(n, listnames) for n in node.childNodes]
else:
- return repr(data)
+ result = dict()
+ for attr in node.attributes.keys():
+ result[attr] = node.attributes[attr].nodeValue
+ for child in node.childNodes:
+ if child.nodeType != node.TEXT_NODE:
+ result[child.nodeName] = self._from_xml_node(child, listnames)
+ return result
+
+ def _to_json(self, data):
+ return json.dumps(data)
+
+ def _to_xml(self, data):
+ metadata = self.metadata.get('application/xml', {})
+ # We expect data to contain a single key which is the XML root.
+ root_key = data.keys()[0]
+ doc = minidom.Document()
+ node = self._to_xml_node(doc, metadata, root_key, data[root_key])
+ return node.toprettyxml(indent=' ')
def _to_xml_node(self, doc, metadata, nodename, data):
+ """Recursive method to convert data members to XML nodes."""
result = doc.createElement(nodename)
if type(data) is list:
singular = metadata.get('plurals', {}).get(nodename, None)
@@ -262,7 +341,7 @@ class Serializer(object):
result.appendChild(node)
elif type(data) is dict:
attrs = metadata.get('attributes', {}).get(nodename, {})
- for k,v in data.items():
+ for k, v in data.items():
if k in attrs:
result.setAttribute(k, str(v))
else: