summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xbin/nova-api47
-rwxr-xr-xbin/nova-api-new45
-rwxr-xr-xbin/nova-manage119
-rw-r--r--doc/source/auth.rst8
-rw-r--r--nova/adminclient.py73
-rw-r--r--nova/api/__init__.py78
-rw-r--r--nova/api/cloud.py42
-rw-r--r--nova/api/cloudpipe/__init__.py69
-rw-r--r--nova/api/context.py46
-rw-r--r--nova/api/ec2/__init__.py228
-rw-r--r--nova/api/ec2/admin.py (renamed from nova/endpoint/admin.py)31
-rw-r--r--nova/api/ec2/apirequest.py131
-rw-r--r--nova/api/ec2/cloud.py (renamed from nova/endpoint/cloud.py)317
-rw-r--r--nova/api/ec2/images.py (renamed from nova/endpoint/images.py)8
-rw-r--r--nova/api/ec2/metadatarequesthandler.py73
-rw-r--r--nova/api/rackspace/__init__.py137
-rw-r--r--nova/api/rackspace/_id_translator.py2
-rw-r--r--nova/api/rackspace/auth.py101
-rw-r--r--nova/api/rackspace/backup_schedules.py39
-rw-r--r--nova/api/rackspace/context.py (renamed from nova/api/rackspace/base.py)25
-rw-r--r--nova/api/rackspace/faults.py62
-rw-r--r--nova/api/rackspace/flavors.py12
-rw-r--r--nova/api/rackspace/images.py18
-rw-r--r--nova/api/rackspace/ratelimiting/__init__.py122
-rw-r--r--nova/api/rackspace/ratelimiting/tests.py237
-rw-r--r--nova/api/rackspace/servers.py299
-rw-r--r--nova/api/rackspace/sharedipgroups.py4
-rw-r--r--nova/auth/dbdriver.py236
-rw-r--r--nova/auth/fakeldap.py5
-rw-r--r--nova/auth/ldapdriver.py94
-rw-r--r--nova/auth/manager.py133
-rw-r--r--nova/auth/rbac.py69
-rw-r--r--nova/cloudpipe/api.py59
-rw-r--r--nova/cloudpipe/pipelib.py14
-rw-r--r--nova/compute/manager.py10
-rw-r--r--nova/crypto.py2
-rw-r--r--nova/db/api.py215
-rw-r--r--nova/db/sqlalchemy/api.py1055
-rw-r--r--nova/db/sqlalchemy/models.py252
-rw-r--r--nova/endpoint/__init__.py0
-rwxr-xr-xnova/endpoint/api.py347
-rw-r--r--nova/flags.py5
-rw-r--r--nova/manager.py13
-rw-r--r--nova/network/linux_net.py121
-rw-r--r--nova/network/manager.py68
-rw-r--r--nova/objectstore/__init__.py2
-rw-r--r--nova/objectstore/handler.py25
-rw-r--r--nova/objectstore/image.py10
-rw-r--r--nova/quota.py9
-rw-r--r--nova/rpc.py62
-rw-r--r--nova/scheduler/driver.py3
-rw-r--r--nova/service.py37
-rw-r--r--nova/test.py54
-rw-r--r--nova/tests/access_unittest.py113
-rw-r--r--nova/tests/api/__init__.py33
-rw-r--r--nova/tests/api/rackspace/__init__.py108
-rw-r--r--nova/tests/api/rackspace/auth.py108
-rw-r--r--nova/tests/api/rackspace/flavors.py14
-rw-r--r--nova/tests/api/rackspace/images.py1
-rw-r--r--nova/tests/api/rackspace/servers.py215
-rw-r--r--nova/tests/api/rackspace/sharedipgroups.py1
-rw-r--r--nova/tests/api/rackspace/test_helper.py134
-rw-r--r--nova/tests/api/rackspace/testfaults.py40
-rw-r--r--nova/tests/api/test_helper.py1
-rw-r--r--nova/tests/api/wsgi_test.py57
-rw-r--r--nova/tests/api_unittest.py143
-rw-r--r--nova/tests/auth_unittest.py443
-rw-r--r--nova/tests/cloud_unittest.py141
-rw-r--r--nova/tests/compute_unittest.py6
-rw-r--r--nova/tests/fake_flags.py2
-rw-r--r--nova/tests/network_unittest.py54
-rw-r--r--nova/tests/objectstore_unittest.py8
-rw-r--r--nova/tests/quota_unittest.py39
-rw-r--r--nova/tests/rpc_unittest.py17
-rw-r--r--nova/tests/scheduler_unittest.py10
-rw-r--r--nova/tests/service_unittest.py12
-rw-r--r--nova/utils.py11
-rw-r--r--nova/virt/xenapi.py37
-rw-r--r--nova/volume/manager.py10
-rw-r--r--nova/wsgi.py84
-rw-r--r--pylintrc3
-rw-r--r--setup.py2
-rw-r--r--tools/install_venv.py4
-rw-r--r--tools/pip-requires5
-rwxr-xr-xtools/setup_iptables.sh158
85 files changed, 5310 insertions, 1977 deletions
diff --git a/bin/nova-api b/bin/nova-api
index ede09d38c..a5027700b 100755
--- a/bin/nova-api
+++ b/bin/nova-api
@@ -1,31 +1,28 @@
#!/usr/bin/env python
+# pylint: disable-msg=C0103
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
-Tornado daemon for the main API endpoint.
+Nova API daemon.
"""
-import logging
import os
import sys
-from tornado import httpserver
-from tornado import ioloop
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
@@ -36,28 +33,16 @@ if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
-from nova import server
from nova import utils
-from nova.endpoint import admin
-from nova.endpoint import api
-from nova.endpoint import cloud
+from nova import server
FLAGS = flags.FLAGS
+flags.DEFINE_integer('api_port', 8773, 'API port')
-
-def main(_argv):
- """Load the controllers and start the tornado I/O loop."""
- controllers = {
- 'Cloud': cloud.CloudController(),
- 'Admin': admin.AdminController()}
- _app = api.APIServerApplication(controllers)
-
- io_inst = ioloop.IOLoop.instance()
- http_server = httpserver.HTTPServer(_app)
- http_server.listen(FLAGS.cc_port)
- logging.debug('Started HTTP server on %s', FLAGS.cc_port)
- io_inst.start()
-
+def main(_args):
+ from nova import api
+ from nova import wsgi
+ wsgi.run_server(api.API(), FLAGS.api_port)
if __name__ == '__main__':
utils.default_flagfile()
diff --git a/bin/nova-api-new b/bin/nova-api-new
deleted file mode 100755
index 8625c487f..000000000
--- a/bin/nova-api-new
+++ /dev/null
@@ -1,45 +0,0 @@
-#!/usr/bin/env python
-# pylint: disable-msg=C0103
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Nova API daemon.
-"""
-
-import os
-import sys
-
-# If ../nova/__init__.py exists, add ../ to Python search path, so that
-# it will override what happens to be installed in /usr/(local/)lib/python...
-possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
- os.pardir,
- os.pardir))
-if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
- sys.path.insert(0, possible_topdir)
-
-from nova import api
-from nova import flags
-from nova import utils
-from nova import wsgi
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer('api_port', 8773, 'API port')
-
-if __name__ == '__main__':
- utils.default_flagfile()
- wsgi.run_server(api.API(), FLAGS.api_port)
diff --git a/bin/nova-manage b/bin/nova-manage
index 325245ac4..5b72c170f 100755
--- a/bin/nova-manage
+++ b/bin/nova-manage
@@ -50,9 +50,9 @@
"""
CLI interface for nova management.
- Connects to the running ADMIN api in the api daemon.
"""
+import logging
import os
import sys
import time
@@ -68,11 +68,12 @@ if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import db
+from nova import exception
from nova import flags
+from nova import quota
from nova import utils
from nova.auth import manager
from nova.cloudpipe import pipelib
-from nova.endpoint import cloud
FLAGS = flags.FLAGS
@@ -83,7 +84,7 @@ class VpnCommands(object):
def __init__(self):
self.manager = manager.AuthManager()
- self.pipe = pipelib.CloudPipe(cloud.CloudController())
+ self.pipe = pipelib.CloudPipe()
def list(self):
"""Print a listing of the VPNs for all projects."""
@@ -114,7 +115,7 @@ class VpnCommands(object):
def _vpn_for(self, project_id):
"""Get the VPN instance for a project ID."""
- for instance in db.instance_get_all():
+ for instance in db.instance_get_all(None):
if (instance['image_id'] == FLAGS.vpn_image_id
and not instance['state_description'] in
['shutting_down', 'shutdown']
@@ -135,15 +136,48 @@ class VpnCommands(object):
class ShellCommands(object):
- def run(self):
- "Runs a Python interactive interpreter. Tries to use IPython, if it's available."
- try:
- import IPython
- # Explicitly pass an empty list as arguments, because otherwise IPython
- # would use sys.argv from this script.
- shell = IPython.Shell.IPShell(argv=[])
- shell.mainloop()
- except ImportError:
+ def bpython(self):
+ """Runs a bpython shell.
+
+ Falls back to Ipython/python shell if unavailable"""
+ self.run('bpython')
+
+ def ipython(self):
+ """Runs an Ipython shell.
+
+ Falls back to Python shell if unavailable"""
+ self.run('ipython')
+
+ def python(self):
+ """Runs a python shell.
+
+ Falls back to Python shell if unavailable"""
+ self.run('python')
+
+ def run(self, shell=None):
+ """Runs a Python interactive interpreter.
+
+ args: [shell=bpython]"""
+ if not shell:
+ shell = 'bpython'
+
+ if shell == 'bpython':
+ try:
+ import bpython
+ bpython.embed()
+ except ImportError:
+ shell = 'ipython'
+ if shell == 'ipython':
+ try:
+ import IPython
+ # Explicitly pass an empty list as arguments, because otherwise IPython
+ # would use sys.argv from this script.
+ shell = IPython.Shell.IPShell(argv=[])
+ shell.mainloop()
+ except ImportError:
+ shell = 'python'
+
+ if shell == 'python':
import code
try: # Try activating rlcompleter, because it's handy.
import readline
@@ -156,6 +190,11 @@ class ShellCommands(object):
readline.parse_and_bind("tab:complete")
code.interact()
+ def script(self, path):
+ """Runs the script from the specifed path with flags set properly.
+ arguments: path"""
+ exec(compile(open(path).read(), path, 'exec'), locals(), globals())
+
class RoleCommands(object):
"""Class for managing roles."""
@@ -186,6 +225,13 @@ class RoleCommands(object):
class UserCommands(object):
"""Class for managing users."""
+ @staticmethod
+ def _print_export(user):
+ """Print export variables to use with API."""
+ print 'export EC2_ACCESS_KEY=%s' % user.access
+ print 'export EC2_SECRET_KEY=%s' % user.secret
+
+
def __init__(self):
self.manager = manager.AuthManager()
@@ -193,13 +239,13 @@ class UserCommands(object):
"""creates a new admin and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, True)
- print_export(user)
+ self._print_export(user)
def create(self, name, access=None, secret=None):
"""creates a new user and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, False)
- print_export(user)
+ self._print_export(user)
def delete(self, name):
"""deletes an existing user
@@ -211,7 +257,7 @@ class UserCommands(object):
arguments: name"""
user = self.manager.get_user(name)
if user:
- print_export(user)
+ self._print_export(user)
else:
print "User %s doesn't exist" % name
@@ -221,12 +267,18 @@ class UserCommands(object):
for user in self.manager.get_users():
print user.name
-
-def print_export(user):
- """Print export variables to use with API."""
- print 'export EC2_ACCESS_KEY=%s' % user.access
- print 'export EC2_SECRET_KEY=%s' % user.secret
-
+ def modify(self, name, access_key, secret_key, is_admin):
+ """update a users keys & admin flag
+ arguments: accesskey secretkey admin
+ leave any field blank to ignore it, admin should be 'T', 'F', or blank
+ """
+ if not is_admin:
+ is_admin = None
+ elif is_admin.upper()[0] == 'T':
+ is_admin = True
+ else:
+ is_admin = False
+ self.manager.modify_user(name, access_key, secret_key, is_admin)
class ProjectCommands(object):
"""Class for managing projects."""
@@ -252,7 +304,7 @@ class ProjectCommands(object):
def environment(self, project_id, user_id, filename='novarc'):
"""Exports environment variables to an sourcable file
arguments: project_id user_id [filename='novarc]"""
- rc = self.manager.get_environment_rc(project_id, user_id)
+ rc = self.manager.get_environment_rc(user_id, project_id)
with open(filename, 'w') as f:
f.write(rc)
@@ -262,6 +314,19 @@ class ProjectCommands(object):
for project in self.manager.get_projects():
print project.name
+ def quota(self, project_id, key=None, value=None):
+ """Set or display quotas for project
+ arguments: project_id [key] [value]"""
+ if key:
+ quo = {'project_id': project_id, key: value}
+ try:
+ db.quota_update(None, project_id, quo)
+ except exception.NotFound:
+ db.quota_create(None, quo)
+ project_quota = quota.get_quota(None, project_id)
+ for key, value in project_quota.iteritems():
+ print '%s: %s' % (key, value)
+
def remove(self, project, user):
"""Removes user from project
arguments: project user"""
@@ -274,6 +339,7 @@ class ProjectCommands(object):
with open(filename, 'w') as f:
f.write(zip_file)
+
class FloatingIpCommands(object):
"""Class for managing floating ip."""
@@ -301,11 +367,12 @@ class FloatingIpCommands(object):
for floating_ip in floating_ips:
instance = None
if floating_ip['fixed_ip']:
- instance = floating_ip['fixed_ip']['instance']['str_id']
+ instance = floating_ip['fixed_ip']['instance']['ec2_id']
print "%s\t%s\t%s" % (floating_ip['host'],
floating_ip['address'],
instance)
+
CATEGORIES = [
('user', UserCommands),
('project', ProjectCommands),
@@ -351,6 +418,10 @@ def main():
"""Parse options and call the appropriate class/method."""
utils.default_flagfile('/etc/nova/nova-manage.conf')
argv = FLAGS(sys.argv)
+
+ if FLAGS.verbose:
+ logging.getLogger().setLevel(logging.DEBUG)
+
script_name = argv.pop(0)
if len(argv) < 1:
print script_name + " category action [<args>]"
diff --git a/doc/source/auth.rst b/doc/source/auth.rst
index 70aca704a..3fcb309cd 100644
--- a/doc/source/auth.rst
+++ b/doc/source/auth.rst
@@ -172,14 +172,6 @@ Further Challenges
-The :mod:`rbac` Module
---------------------------
-
-.. automodule:: nova.auth.rbac
- :members:
- :undoc-members:
- :show-inheritance:
-
The :mod:`signer` Module
------------------------
diff --git a/nova/adminclient.py b/nova/adminclient.py
index 0ca32b1e5..fc9fcfde0 100644
--- a/nova/adminclient.py
+++ b/nova/adminclient.py
@@ -20,11 +20,17 @@ Nova User API client library.
"""
import base64
-
import boto
+import httplib
from boto.ec2.regioninfo import RegionInfo
+DEFAULT_CLC_URL='http://127.0.0.1:8773'
+DEFAULT_REGION='nova'
+DEFAULT_ACCESS_KEY='admin'
+DEFAULT_SECRET_KEY='admin'
+
+
class UserInfo(object):
"""
Information about a Nova user, as parsed through SAX
@@ -68,13 +74,13 @@ class UserRole(object):
def __init__(self, connection=None):
self.connection = connection
self.role = None
-
+
def __repr__(self):
return 'UserRole:%s' % self.role
def startElement(self, name, attrs, connection):
return None
-
+
def endElement(self, name, value, connection):
if name == 'role':
self.role = value
@@ -128,20 +134,20 @@ class ProjectMember(object):
def __init__(self, connection=None):
self.connection = connection
self.memberId = None
-
+
def __repr__(self):
return 'ProjectMember:%s' % self.memberId
def startElement(self, name, attrs, connection):
return None
-
+
def endElement(self, name, value, connection):
if name == 'member':
self.memberId = value
else:
setattr(self, name, str(value))
-
+
class HostInfo(object):
"""
Information about a Nova Host, as parsed through SAX:
@@ -171,35 +177,56 @@ class HostInfo(object):
class NovaAdminClient(object):
- def __init__(self, clc_ip='127.0.0.1', region='nova', access_key='admin',
- secret_key='admin', **kwargs):
- self.clc_ip = clc_ip
+ def __init__(self, clc_url=DEFAULT_CLC_URL, region=DEFAULT_REGION,
+ access_key=DEFAULT_ACCESS_KEY, secret_key=DEFAULT_SECRET_KEY,
+ **kwargs):
+ parts = self.split_clc_url(clc_url)
+
+ self.clc_url = clc_url
self.region = region
self.access = access_key
self.secret = secret_key
self.apiconn = boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
- is_secure=False,
- region=RegionInfo(None, region, clc_ip),
- port=8773,
+ is_secure=parts['is_secure'],
+ region=RegionInfo(None,
+ region,
+ parts['ip']),
+ port=parts['port'],
path='/services/Admin',
**kwargs)
self.apiconn.APIVersion = 'nova'
- def connection_for(self, username, project, **kwargs):
+ def connection_for(self, username, project, clc_url=None, region=None,
+ **kwargs):
"""
Returns a boto ec2 connection for the given username.
"""
+ if not clc_url:
+ clc_url = self.clc_url
+ if not region:
+ region = self.region
+ parts = self.split_clc_url(clc_url)
user = self.get_user(username)
access_key = '%s:%s' % (user.accesskey, project)
- return boto.connect_ec2(
- aws_access_key_id=access_key,
- aws_secret_access_key=user.secretkey,
- is_secure=False,
- region=RegionInfo(None, self.region, self.clc_ip),
- port=8773,
- path='/services/Cloud',
- **kwargs)
+ return boto.connect_ec2(aws_access_key_id=access_key,
+ aws_secret_access_key=user.secretkey,
+ is_secure=parts['is_secure'],
+ region=RegionInfo(None,
+ self.region,
+ parts['ip']),
+ port=parts['port'],
+ path='/services/Cloud',
+ **kwargs)
+
+ def split_clc_url(self, clc_url):
+ """
+ Splits a cloud controller endpoint url.
+ """
+ parts = httplib.urlsplit(clc_url)
+ is_secure = parts.scheme == 'https'
+ ip, port = parts.netloc.split(':')
+ return {'ip': ip, 'port': int(port), 'is_secure': is_secure}
def get_users(self):
""" grabs the list of all users """
@@ -289,7 +316,7 @@ class NovaAdminClient(object):
if project.projectname != None:
return project
-
+
def create_project(self, projectname, manager_user, description=None,
member_users=None):
"""
@@ -322,7 +349,7 @@ class NovaAdminClient(object):
Adds a user to a project.
"""
return self.modify_project_member(user, project, operation='add')
-
+
def remove_project_member(self, user, project):
"""
Removes a user from a project.
diff --git a/nova/api/__init__.py b/nova/api/__init__.py
index b9b9e3988..744abd621 100644
--- a/nova/api/__init__.py
+++ b/nova/api/__init__.py
@@ -21,17 +21,91 @@ Root WSGI middleware for all API controllers.
"""
import routes
+import webob.dec
+from nova import flags
from nova import wsgi
+from nova.api import cloudpipe
from nova.api import ec2
from nova.api import rackspace
+from nova.api.ec2 import metadatarequesthandler
+
+
+flags.DEFINE_string('rsapi_subdomain', 'rs',
+ 'subdomain running the RS API')
+flags.DEFINE_string('ec2api_subdomain', 'ec2',
+ 'subdomain running the EC2 API')
+flags.DEFINE_string('FAKE_subdomain', None,
+ 'set to rs or ec2 to fake the subdomain of the host for testing')
+FLAGS = flags.FLAGS
class API(wsgi.Router):
"""Routes top-level requests to the appropriate controller."""
def __init__(self):
+ rsdomain = {'sub_domain': [FLAGS.rsapi_subdomain]}
+ ec2domain = {'sub_domain': [FLAGS.ec2api_subdomain]}
+ # If someone wants to pretend they're hitting the RS subdomain
+ # on their local box, they can set FAKE_subdomain to 'rs', which
+ # removes subdomain restrictions from the RS routes below.
+ if FLAGS.FAKE_subdomain == 'rs':
+ rsdomain = {}
+ elif FLAGS.FAKE_subdomain == 'ec2':
+ ec2domain = {}
mapper = routes.Mapper()
- mapper.connect("/v1.0/{path_info:.*}", controller=rackspace.API())
- mapper.connect("/ec2/{path_info:.*}", controller=ec2.API())
+ mapper.sub_domains = True
+ mapper.connect("/", controller=self.rsapi_versions,
+ conditions=rsdomain)
+ mapper.connect("/v1.0/{path_info:.*}", controller=rackspace.API(),
+ conditions=rsdomain)
+
+ mapper.connect("/", controller=self.ec2api_versions,
+ conditions=ec2domain)
+ mapper.connect("/services/{path_info:.*}", controller=ec2.API(),
+ conditions=ec2domain)
+ mapper.connect("/cloudpipe/{path_info:.*}", controller=cloudpipe.API())
+ mrh = metadatarequesthandler.MetadataRequestHandler()
+ for s in ['/latest',
+ '/2009-04-04',
+ '/2008-09-01',
+ '/2008-02-01',
+ '/2007-12-15',
+ '/2007-10-10',
+ '/2007-08-29',
+ '/2007-03-01',
+ '/2007-01-19',
+ '/1.0']:
+ mapper.connect('%s/{path_info:.*}' % s, controller=mrh,
+ conditions=ec2domain)
super(API, self).__init__(mapper)
+
+ @webob.dec.wsgify
+ def rsapi_versions(self, req):
+ """Respond to a request for all OpenStack API versions."""
+ response = {
+ "versions": [
+ dict(status="CURRENT", id="v1.0")]}
+ metadata = {
+ "application/xml": {
+ "attributes": dict(version=["status", "id"])}}
+ return wsgi.Serializer(req.environ, metadata).to_content_type(response)
+
+ @webob.dec.wsgify
+ def ec2api_versions(self, req):
+ """Respond to a request for all EC2 versions."""
+ # available api versions
+ versions = [
+ '1.0',
+ '2007-01-19',
+ '2007-03-01',
+ '2007-08-29',
+ '2007-10-10',
+ '2007-12-15',
+ '2008-02-01',
+ '2008-09-01',
+ '2009-04-04',
+ ]
+ return ''.join('%s\n' % v for v in versions)
+
+
diff --git a/nova/api/cloud.py b/nova/api/cloud.py
new file mode 100644
index 000000000..345677d4f
--- /dev/null
+++ b/nova/api/cloud.py
@@ -0,0 +1,42 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Methods for API calls to control instances via AMQP.
+"""
+
+
+from nova import db
+from nova import flags
+from nova import rpc
+
+FLAGS = flags.FLAGS
+
+
+def reboot(instance_id, context=None):
+ """Reboot the given instance.
+
+ #TODO(gundlach) not actually sure what context is used for by ec2 here
+ -- I think we can just remove it and use None all the time.
+ """
+ instance_ref = db.instance_get_by_ec2_id(None, instance_id)
+ host = instance_ref['host']
+ rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
+ {"method": "reboot_instance",
+ "args": {"context": None,
+ "instance_id": instance_ref['id']}})
diff --git a/nova/api/cloudpipe/__init__.py b/nova/api/cloudpipe/__init__.py
new file mode 100644
index 000000000..6d40990a8
--- /dev/null
+++ b/nova/api/cloudpipe/__init__.py
@@ -0,0 +1,69 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+REST API Request Handlers for CloudPipe
+"""
+
+import logging
+import urllib
+import webob
+import webob.dec
+import webob.exc
+
+from nova import crypto
+from nova import wsgi
+from nova.auth import manager
+from nova.api.ec2 import cloud
+
+
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+class API(wsgi.Application):
+
+ def __init__(self):
+ self.controller = cloud.CloudController()
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ if req.method == 'POST':
+ return self.sign_csr(req)
+ _log.debug("Cloudpipe path is %s" % req.path_info)
+ if req.path_info.endswith("/getca/"):
+ return self.send_root_ca(req)
+ return webob.exc.HTTPNotFound()
+
+ def get_project_id_from_ip(self, ip):
+ # TODO(eday): This was removed with the ORM branch, fix!
+ instance = self.controller.get_instance_by_ip(ip)
+ return instance['project_id']
+
+ def send_root_ca(self, req):
+ _log.debug("Getting root ca")
+ project_id = self.get_project_id_from_ip(req.remote_addr)
+ res = webob.Response()
+ res.headers["Content-Type"] = "text/plain"
+ res.body = crypto.fetch_ca(project_id)
+ return res
+
+ def sign_csr(self, req):
+ project_id = self.get_project_id_from_ip(req.remote_addr)
+ cert = self.str_params['cert']
+ return crypto.sign_csr(urllib.unquote(cert), project_id)
diff --git a/nova/api/context.py b/nova/api/context.py
new file mode 100644
index 000000000..b66cfe468
--- /dev/null
+++ b/nova/api/context.py
@@ -0,0 +1,46 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+APIRequestContext
+"""
+
+import random
+
+
+class APIRequestContext(object):
+ def __init__(self, user, project):
+ self.user = user
+ self.project = project
+ self.request_id = ''.join(
+ [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
+ for x in xrange(20)]
+ )
+ if user:
+ self.is_admin = user.is_admin()
+ else:
+ self.is_admin = False
+ self.read_deleted = False
+
+
+def get_admin_context(user=None, read_deleted=False):
+ context_ref = APIRequestContext(user=user, project=None)
+ context_ref.is_admin = True
+ context_ref.read_deleted = read_deleted
+ return context_ref
+
diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py
index 6eec0abf7..6b538a7f1 100644
--- a/nova/api/ec2/__init__.py
+++ b/nova/api/ec2/__init__.py
@@ -1,6 +1,7 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
-# Copyright 2010 OpenStack LLC.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@@ -15,28 +16,225 @@
# License for the specific language governing permissions and limitations
# under the License.
-"""
-WSGI middleware for EC2 API controllers.
-"""
+"""Starting point for routing EC2 requests"""
+import logging
import routes
+import webob
import webob.dec
+import webob.exc
+from nova import exception
+from nova import flags
from nova import wsgi
+from nova.api import context
+from nova.api.ec2 import apirequest
+from nova.api.ec2 import admin
+from nova.api.ec2 import cloud
+from nova.auth import manager
-class API(wsgi.Router):
- """Routes EC2 requests to the appropriate controller."""
+FLAGS = flags.FLAGS
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+class API(wsgi.Middleware):
+
+ """Routing for all EC2 API requests."""
def __init__(self):
- mapper = routes.Mapper()
- mapper.connect(None, "{all:.*}", controller=self.dummy)
- super(API, self).__init__(mapper)
+ self.application = Authenticate(Router(Authorizer(Executor())))
+
+
+class Authenticate(wsgi.Middleware):
+
+ """Authenticate an EC2 request and add 'ec2.context' to WSGI environ."""
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ # Read request signature and access id.
+ try:
+ signature = req.params['Signature']
+ access = req.params['AWSAccessKeyId']
+ except:
+ raise webob.exc.HTTPBadRequest()
+
+ # Make a copy of args for authentication and signature verification.
+ auth_params = dict(req.params)
+ auth_params.pop('Signature') # not part of authentication args
+
+ # Authenticate the request.
+ try:
+ (user, project) = manager.AuthManager().authenticate(
+ access,
+ signature,
+ auth_params,
+ req.method,
+ req.host,
+ req.path)
+ except exception.Error, ex:
+ logging.debug("Authentication Failure: %s" % ex)
+ raise webob.exc.HTTPForbidden()
+
+ # Authenticated!
+ req.environ['ec2.context'] = context.APIRequestContext(user, project)
+ return self.application
+
+
+class Router(wsgi.Middleware):
+
+ """Add ec2.'controller', .'action', and .'action_args' to WSGI environ."""
+
+ def __init__(self, application):
+ super(Router, self).__init__(application)
+ self.map = routes.Mapper()
+ self.map.connect("/{controller_name}/")
+ self.controllers = dict(Cloud=cloud.CloudController(),
+ Admin=admin.AdminController())
- @staticmethod
@webob.dec.wsgify
- def dummy(req):
- """Temporary dummy controller."""
- msg = "dummy response -- please hook up __init__() to cloud.py instead"
- return repr({'dummy': msg,
- 'kwargs': repr(req.environ['wsgiorg.routing_args'][1])})
+ def __call__(self, req):
+ # Obtain the appropriate controller and action for this request.
+ try:
+ match = self.map.match(req.path_info)
+ controller_name = match['controller_name']
+ controller = self.controllers[controller_name]
+ except:
+ raise webob.exc.HTTPNotFound()
+ non_args = ['Action', 'Signature', 'AWSAccessKeyId', 'SignatureMethod',
+ 'SignatureVersion', 'Version', 'Timestamp']
+ args = dict(req.params)
+ try:
+ action = req.params['Action'] # raise KeyError if omitted
+ for non_arg in non_args:
+ args.pop(non_arg) # remove, but raise KeyError if omitted
+ except:
+ raise webob.exc.HTTPBadRequest()
+
+ _log.debug('action: %s' % action)
+ for key, value in args.items():
+ _log.debug('arg: %s\t\tval: %s' % (key, value))
+
+ # Success!
+ req.environ['ec2.controller'] = controller
+ req.environ['ec2.action'] = action
+ req.environ['ec2.action_args'] = args
+ return self.application
+
+
+class Authorizer(wsgi.Middleware):
+
+ """Authorize an EC2 API request.
+
+ Return a 401 if ec2.controller and ec2.action in WSGI environ may not be
+ executed in ec2.context.
+ """
+
+ def __init__(self, application):
+ super(Authorizer, self).__init__(application)
+ self.action_roles = {
+ 'CloudController': {
+ 'DescribeAvailabilityzones': ['all'],
+ 'DescribeRegions': ['all'],
+ 'DescribeSnapshots': ['all'],
+ 'DescribeKeyPairs': ['all'],
+ 'CreateKeyPair': ['all'],
+ 'DeleteKeyPair': ['all'],
+ 'DescribeSecurityGroups': ['all'],
+ 'CreateSecurityGroup': ['netadmin'],
+ 'DeleteSecurityGroup': ['netadmin'],
+ 'GetConsoleOutput': ['projectmanager', 'sysadmin'],
+ 'DescribeVolumes': ['projectmanager', 'sysadmin'],
+ 'CreateVolume': ['projectmanager', 'sysadmin'],
+ 'AttachVolume': ['projectmanager', 'sysadmin'],
+ 'DetachVolume': ['projectmanager', 'sysadmin'],
+ 'DescribeInstances': ['all'],
+ 'DescribeAddresses': ['all'],
+ 'AllocateAddress': ['netadmin'],
+ 'ReleaseAddress': ['netadmin'],
+ 'AssociateAddress': ['netadmin'],
+ 'DisassociateAddress': ['netadmin'],
+ 'RunInstances': ['projectmanager', 'sysadmin'],
+ 'TerminateInstances': ['projectmanager', 'sysadmin'],
+ 'RebootInstances': ['projectmanager', 'sysadmin'],
+ 'UpdateInstance': ['projectmanager', 'sysadmin'],
+ 'DeleteVolume': ['projectmanager', 'sysadmin'],
+ 'DescribeImages': ['all'],
+ 'DeregisterImage': ['projectmanager', 'sysadmin'],
+ 'RegisterImage': ['projectmanager', 'sysadmin'],
+ 'DescribeImageAttribute': ['all'],
+ 'ModifyImageAttribute': ['projectmanager', 'sysadmin'],
+ 'UpdateImage': ['projectmanager', 'sysadmin'],
+ },
+ 'AdminController': {
+ # All actions have the same permission: ['none'] (the default)
+ # superusers will be allowed to run them
+ # all others will get HTTPUnauthorized.
+ },
+ }
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ context = req.environ['ec2.context']
+ controller_name = req.environ['ec2.controller'].__class__.__name__
+ action = req.environ['ec2.action']
+ allowed_roles = self.action_roles[controller_name].get(action, ['none'])
+ if self._matches_any_role(context, allowed_roles):
+ return self.application
+ else:
+ raise webob.exc.HTTPUnauthorized()
+
+ def _matches_any_role(self, context, roles):
+ """Return True if any role in roles is allowed in context."""
+ if context.user.is_superuser():
+ return True
+ if 'all' in roles:
+ return True
+ if 'none' in roles:
+ return False
+ return any(context.project.has_role(context.user.id, role)
+ for role in roles)
+
+
+class Executor(wsgi.Application):
+
+ """Execute an EC2 API request.
+
+ Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and
+ 'ec2.action_args' (all variables in WSGI environ.) Returns an XML
+ response, or a 400 upon failure.
+ """
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ context = req.environ['ec2.context']
+ controller = req.environ['ec2.controller']
+ action = req.environ['ec2.action']
+ args = req.environ['ec2.action_args']
+
+ api_request = apirequest.APIRequest(controller, action)
+ try:
+ result = api_request.send(context, **args)
+ req.headers['Content-Type'] = 'text/xml'
+ return result
+ except exception.ApiError as ex:
+
+ if ex.code:
+ return self._error(req, ex.code, ex.message)
+ else:
+ return self._error(req, type(ex).__name__, ex.message)
+ # TODO(vish): do something more useful with unknown exceptions
+ except Exception as ex:
+ return self._error(req, type(ex).__name__, str(ex))
+
+ def _error(self, req, code, message):
+ resp = webob.Response()
+ resp.status = 400
+ resp.headers['Content-Type'] = 'text/xml'
+ resp.body = ('<?xml version="1.0"?>\n'
+ '<Response><Errors><Error><Code>%s</Code>'
+ '<Message>%s</Message></Error></Errors>'
+ '<RequestID>?</RequestID></Response>') % (code, message)
+ return resp
+
diff --git a/nova/endpoint/admin.py b/nova/api/ec2/admin.py
index c6dcb5320..36feae451 100644
--- a/nova/endpoint/admin.py
+++ b/nova/api/ec2/admin.py
@@ -58,46 +58,27 @@ def host_dict(host):
return {}
-def admin_only(target):
- """Decorator for admin-only API calls"""
- def wrapper(*args, **kwargs):
- """Internal wrapper method for admin-only API calls"""
- context = args[1]
- if context.user.is_admin():
- return target(*args, **kwargs)
- else:
- return {}
-
- return wrapper
-
-
class AdminController(object):
"""
API Controller for users, hosts, nodes, and workers.
- Trivial admin_only wrapper will be replaced with RBAC,
- allowing project managers to administer project users.
"""
def __str__(self):
return 'AdminController'
- @admin_only
def describe_user(self, _context, name, **_kwargs):
"""Returns user data, including access and secret keys."""
return user_dict(manager.AuthManager().get_user(name))
- @admin_only
def describe_users(self, _context, **_kwargs):
"""Returns all users - should be changed to deal with a list."""
return {'userSet':
[user_dict(u) for u in manager.AuthManager().get_users()] }
- @admin_only
def register_user(self, _context, name, **_kwargs):
"""Creates a new user, and returns generated credentials."""
return user_dict(manager.AuthManager().create_user(name))
- @admin_only
def deregister_user(self, _context, name, **_kwargs):
"""Deletes a single user (NOT undoable.)
Should throw an exception if the user has instances,
@@ -107,13 +88,11 @@ class AdminController(object):
return True
- @admin_only
def describe_roles(self, context, project_roles=True, **kwargs):
"""Returns a list of allowed roles."""
roles = manager.AuthManager().get_roles(project_roles)
return { 'roles': [{'role': r} for r in roles]}
- @admin_only
def describe_user_roles(self, context, user, project=None, **kwargs):
"""Returns a list of roles for the given user.
Omitting project will return any global roles that the user has.
@@ -122,7 +101,6 @@ class AdminController(object):
roles = manager.AuthManager().get_user_roles(user, project=project)
return { 'roles': [{'role': r} for r in roles]}
- @admin_only
def modify_user_role(self, context, user, role, project=None,
operation='add', **kwargs):
"""Add or remove a role for a user and project."""
@@ -135,7 +113,6 @@ class AdminController(object):
return True
- @admin_only
def generate_x509_for_user(self, _context, name, project=None, **kwargs):
"""Generates and returns an x509 certificate for a single user.
Is usually called from a client that will wrap this with
@@ -147,19 +124,16 @@ class AdminController(object):
user = manager.AuthManager().get_user(name)
return user_dict(user, base64.b64encode(project.get_credentials(user)))
- @admin_only
def describe_project(self, context, name, **kwargs):
"""Returns project data, including member ids."""
return project_dict(manager.AuthManager().get_project(name))
- @admin_only
def describe_projects(self, context, user=None, **kwargs):
"""Returns all projects - should be changed to deal with a list."""
return {'projectSet':
[project_dict(u) for u in
manager.AuthManager().get_projects(user=user)]}
- @admin_only
def register_project(self, context, name, manager_user, description=None,
member_users=None, **kwargs):
"""Creates a new project"""
@@ -170,20 +144,17 @@ class AdminController(object):
description=None,
member_users=None))
- @admin_only
def deregister_project(self, context, name):
"""Permanently deletes a project."""
manager.AuthManager().delete_project(name)
return True
- @admin_only
def describe_project_members(self, context, name, **kwargs):
project = manager.AuthManager().get_project(name)
result = {
'members': [{'member': m} for m in project.member_ids]}
return result
- @admin_only
def modify_project_member(self, context, user, project, operation, **kwargs):
"""Add or remove a user from a project."""
if operation =='add':
@@ -196,7 +167,6 @@ class AdminController(object):
# FIXME(vish): these host commands don't work yet, perhaps some of the
# required data can be retrieved from service objects?
- @admin_only
def describe_hosts(self, _context, **_kwargs):
"""Returns status info for all nodes. Includes:
* Disk Space
@@ -208,7 +178,6 @@ class AdminController(object):
"""
return {'hostSet': [host_dict(h) for h in db.host_get_all()]}
- @admin_only
def describe_host(self, _context, name, **_kwargs):
"""Returns status info for single node."""
return host_dict(db.host_get(name))
diff --git a/nova/api/ec2/apirequest.py b/nova/api/ec2/apirequest.py
new file mode 100644
index 000000000..a87c21fb3
--- /dev/null
+++ b/nova/api/ec2/apirequest.py
@@ -0,0 +1,131 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+APIRequest class
+"""
+
+import logging
+import re
+# TODO(termie): replace minidom with etree
+from xml.dom import minidom
+
+_log = logging.getLogger("api")
+_log.setLevel(logging.DEBUG)
+
+
+_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
+
+
+def _camelcase_to_underscore(str):
+ return _c2u.sub(r'_\1', str).lower().strip('_')
+
+
+def _underscore_to_camelcase(str):
+ return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
+
+
+def _underscore_to_xmlcase(str):
+ res = _underscore_to_camelcase(str)
+ return res[:1].lower() + res[1:]
+
+
+class APIRequest(object):
+ def __init__(self, controller, action):
+ self.controller = controller
+ self.action = action
+
+ def send(self, context, **kwargs):
+ try:
+ method = getattr(self.controller,
+ _camelcase_to_underscore(self.action))
+ except AttributeError:
+ _error = ('Unsupported API request: controller = %s,'
+ 'action = %s') % (self.controller, self.action)
+ _log.warning(_error)
+ # TODO: Raise custom exception, trap in apiserver,
+ # and reraise as 400 error.
+ raise Exception(_error)
+
+ args = {}
+ for key, value in kwargs.items():
+ parts = key.split(".")
+ key = _camelcase_to_underscore(parts[0])
+ if len(parts) > 1:
+ d = args.get(key, {})
+ d[parts[1]] = value
+ value = d
+ args[key] = value
+
+ for key in args.keys():
+ if isinstance(args[key], dict):
+ if args[key] != {} and args[key].keys()[0].isdigit():
+ s = args[key].items()
+ s.sort()
+ args[key] = [v for k, v in s]
+
+ result = method(context, **args)
+ return self._render_response(result, context.request_id)
+
+ def _render_response(self, response_data, request_id):
+ xml = minidom.Document()
+
+ response_el = xml.createElement(self.action + 'Response')
+ response_el.setAttribute('xmlns',
+ 'http://ec2.amazonaws.com/doc/2009-11-30/')
+ request_id_el = xml.createElement('requestId')
+ request_id_el.appendChild(xml.createTextNode(request_id))
+ response_el.appendChild(request_id_el)
+ if(response_data == True):
+ self._render_dict(xml, response_el, {'return': 'true'})
+ else:
+ self._render_dict(xml, response_el, response_data)
+
+ xml.appendChild(response_el)
+
+ response = xml.toxml()
+ xml.unlink()
+ _log.debug(response)
+ return response
+
+ def _render_dict(self, xml, el, data):
+ try:
+ for key in data.keys():
+ val = data[key]
+ el.appendChild(self._render_data(xml, key, val))
+ except:
+ _log.debug(data)
+ raise
+
+ def _render_data(self, xml, el_name, data):
+ el_name = _underscore_to_xmlcase(el_name)
+ data_el = xml.createElement(el_name)
+
+ if isinstance(data, list):
+ for item in data:
+ data_el.appendChild(self._render_data(xml, 'item', item))
+ elif isinstance(data, dict):
+ self._render_dict(xml, data_el, data)
+ elif hasattr(data, '__dict__'):
+ self._render_dict(xml, data_el, data.__dict__)
+ elif isinstance(data, bool):
+ data_el.appendChild(xml.createTextNode(str(data).lower()))
+ elif data != None:
+ data_el.appendChild(xml.createTextNode(str(data)))
+
+ return data_el
diff --git a/nova/endpoint/cloud.py b/nova/api/ec2/cloud.py
index 339a67b0e..1f01731ae 100644
--- a/nova/endpoint/cloud.py
+++ b/nova/api/ec2/cloud.py
@@ -28,18 +28,16 @@ import logging
import os
import time
-from twisted.internet import defer
-
+from nova import crypto
from nova import db
from nova import exception
from nova import flags
from nova import quota
from nova import rpc
from nova import utils
-from nova.auth import rbac
-from nova.auth import manager
from nova.compute.instance_types import INSTANCE_TYPES
-from nova.endpoint import images
+from nova.api import cloud
+from nova.api.ec2 import images
FLAGS = flags.FLAGS
@@ -51,13 +49,26 @@ class QuotaError(exception.ApiError):
pass
-def _gen_key(user_id, key_name):
- """ Tuck this into AuthManager """
+def _gen_key(context, user_id, key_name):
+ """Generate a key
+
+ This is a module level method because it is slow and we need to defer
+ it into a process pool."""
+ # NOTE(vish): generating key pair is slow so check for legal
+ # creation before creating key_pair
try:
- mgr = manager.AuthManager()
- private_key, fingerprint = mgr.generate_key_pair(user_id, key_name)
- except Exception as ex:
- return {'exception': ex}
+ db.key_pair_get(context, user_id, key_name)
+ raise exception.Duplicate("The key_pair %s already exists"
+ % key_name)
+ except exception.NotFound:
+ pass
+ private_key, public_key, fingerprint = crypto.generate_key_pair()
+ key = {}
+ key['user_id'] = user_id
+ key['name'] = key_name
+ key['public_key'] = public_key
+ key['fingerprint'] = fingerprint
+ db.key_pair_create(context, key)
return {'private_key': private_key, 'fingerprint': fingerprint}
@@ -91,14 +102,15 @@ class CloudController(object):
def _get_mpi_data(self, project_id):
result = {}
- for instance in db.instance_get_by_project(None, project_id):
+ for instance in db.instance_get_all_by_project(None, project_id):
if instance['fixed_ip']:
- line = '%s slots=%d' % (instance['fixed_ip']['str_id'],
+ line = '%s slots=%d' % (instance['fixed_ip']['address'],
INSTANCE_TYPES[instance['instance_type']]['vcpus'])
- if instance['key_name'] in result:
- result[instance['key_name']].append(line)
+ key = str(instance['key_name'])
+ if key in result:
+ result[key].append(line)
else:
- result[instance['key_name']] = [line]
+ result[key] = [line]
return result
def get_metadata(self, address):
@@ -132,7 +144,7 @@ class CloudController(object):
},
'hostname': hostname,
'instance-action': 'none',
- 'instance-id': instance_ref['str_id'],
+ 'instance-id': instance_ref['ec2_id'],
'instance-type': instance_ref['instance_type'],
'local-hostname': hostname,
'local-ipv4': address,
@@ -155,18 +167,24 @@ class CloudController(object):
data['product-codes'] = []
return data
- @rbac.allow('all')
def describe_availability_zones(self, context, **kwargs):
return {'availabilityZoneInfo': [{'zoneName': 'nova',
'zoneState': 'available'}]}
- @rbac.allow('all')
def describe_regions(self, context, region_name=None, **kwargs):
- # TODO(vish): region_name is an array. Support filtering
- return {'regionInfo': [{'regionName': 'nova',
- 'regionUrl': FLAGS.ec2_url}]}
+ if FLAGS.region_list:
+ regions = []
+ for region in FLAGS.region_list:
+ name, _sep, url = region.partition('=')
+ regions.append({'regionName': name,
+ 'regionEndpoint': url})
+ else:
+ regions = [{'regionName': 'nova',
+ 'regionEndpoint': FLAGS.ec2_url}]
+ if region_name:
+ regions = [r for r in regions if r['regionName'] in region_name]
+ return {'regionInfo': regions }
- @rbac.allow('all')
def describe_snapshots(self,
context,
snapshot_id=None,
@@ -182,64 +200,53 @@ class CloudController(object):
'volumeSize': 0,
'description': 'fixme'}]}
- @rbac.allow('all')
def describe_key_pairs(self, context, key_name=None, **kwargs):
- key_pairs = context.user.get_key_pairs()
+ key_pairs = db.key_pair_get_all_by_user(context, context.user.id)
if not key_name is None:
- key_pairs = [x for x in key_pairs if x.name in key_name]
+ key_pairs = [x for x in key_pairs if x['name'] in key_name]
result = []
for key_pair in key_pairs:
# filter out the vpn keys
suffix = FLAGS.vpn_key_suffix
- if context.user.is_admin() or not key_pair.name.endswith(suffix):
+ if context.user.is_admin() or not key_pair['name'].endswith(suffix):
result.append({
- 'keyName': key_pair.name,
- 'keyFingerprint': key_pair.fingerprint,
+ 'keyName': key_pair['name'],
+ 'keyFingerprint': key_pair['fingerprint'],
})
return {'keypairsSet': result}
- @rbac.allow('all')
def create_key_pair(self, context, key_name, **kwargs):
- dcall = defer.Deferred()
- pool = context.handler.application.settings.get('pool')
- def _complete(kwargs):
- if 'exception' in kwargs:
- dcall.errback(kwargs['exception'])
- return
- dcall.callback({'keyName': key_name,
- 'keyFingerprint': kwargs['fingerprint'],
- 'keyMaterial': kwargs['private_key']})
- pool.apply_async(_gen_key, [context.user.id, key_name],
- callback=_complete)
- return dcall
-
- @rbac.allow('all')
+ data = _gen_key(None, context.user.id, key_name)
+ return {'keyName': key_name,
+ 'keyFingerprint': data['fingerprint'],
+ 'keyMaterial': data['private_key']}
+ # TODO(vish): when context is no longer an object, pass it here
+
def delete_key_pair(self, context, key_name, **kwargs):
- context.user.delete_key_pair(key_name)
- # aws returns true even if the key doens't exist
+ try:
+ db.key_pair_destroy(context, context.user.id, key_name)
+ except exception.NotFound:
+ # aws returns true even if the key doesn't exist
+ pass
return True
- @rbac.allow('all')
def describe_security_groups(self, context, group_names, **kwargs):
groups = {'securityGroupSet': []}
# Stubbed for now to unblock other things.
return groups
- @rbac.allow('netadmin')
def create_security_group(self, context, group_name, **kwargs):
return True
- @rbac.allow('netadmin')
def delete_security_group(self, context, group_name, **kwargs):
return True
- @rbac.allow('projectmanager', 'sysadmin')
def get_console_output(self, context, instance_id, **kwargs):
# instance_id is passed in as a list of instances
- instance_ref = db.instance_get_by_str(context, instance_id[0])
+ instance_ref = db.instance_get_by_ec2_id(context, instance_id[0])
d = rpc.call('%s.%s' % (FLAGS.compute_topic,
instance_ref['host']),
{ "method" : "get_console_output",
@@ -251,12 +258,11 @@ class CloudController(object):
"output": base64.b64encode(output)})
return d
- @rbac.allow('projectmanager', 'sysadmin')
def describe_volumes(self, context, **kwargs):
if context.user.is_admin():
volumes = db.volume_get_all(context)
else:
- volumes = db.volume_get_by_project(context, context.project.id)
+ volumes = db.volume_get_all_by_project(context, context.project.id)
volumes = [self._format_volume(context, v) for v in volumes]
@@ -264,7 +270,7 @@ class CloudController(object):
def _format_volume(self, context, volume):
v = {}
- v['volumeId'] = volume['str_id']
+ v['volumeId'] = volume['ec2_id']
v['status'] = volume['status']
v['size'] = volume['size']
v['availabilityZone'] = volume['availability_zone']
@@ -282,15 +288,16 @@ class CloudController(object):
'device': volume['mountpoint'],
'instanceId': volume['instance_id'],
'status': 'attached',
- 'volume_id': volume['str_id']}]
+ 'volume_id': volume['ec2_id']}]
else:
v['attachmentSet'] = [{}]
+
+ v['display_name'] = volume['display_name']
+ v['display_description'] = volume['display_description']
return v
- @rbac.allow('projectmanager', 'sysadmin')
def create_volume(self, context, size, **kwargs):
# check quota
- size = int(size)
if quota.allowed_volumes(context, 1, size) < 1:
logging.warn("Quota exceeeded for %s, tried to create %sG volume",
context.project.id, size)
@@ -304,6 +311,8 @@ class CloudController(object):
vol['availability_zone'] = FLAGS.storage_availability_zone
vol['status'] = "creating"
vol['attach_status'] = "detached"
+ vol['display_name'] = kwargs.get('display_name')
+ vol['display_description'] = kwargs.get('display_description')
volume_ref = db.volume_create(context, vol)
rpc.cast(FLAGS.scheduler_topic,
@@ -315,15 +324,14 @@ class CloudController(object):
return {'volumeSet': [self._format_volume(context, volume_ref)]}
- @rbac.allow('projectmanager', 'sysadmin')
def attach_volume(self, context, volume_id, instance_id, device, **kwargs):
- volume_ref = db.volume_get_by_str(context, volume_id)
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
# TODO(vish): abstract status checking?
if volume_ref['status'] != "available":
raise exception.ApiError("Volume status must be available")
if volume_ref['attach_status'] == "attached":
raise exception.ApiError("Volume is already attached")
- instance_ref = db.instance_get_by_str(context, instance_id)
+ instance_ref = db.instance_get_by_ec2_id(context, instance_id)
host = instance_ref['host']
rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
{"method": "attach_volume",
@@ -331,16 +339,15 @@ class CloudController(object):
"volume_id": volume_ref['id'],
"instance_id": instance_ref['id'],
"mountpoint": device}})
- return defer.succeed({'attachTime': volume_ref['attach_time'],
- 'device': volume_ref['mountpoint'],
- 'instanceId': instance_ref['id'],
- 'requestId': context.request_id,
- 'status': volume_ref['attach_status'],
- 'volumeId': volume_ref['id']})
-
- @rbac.allow('projectmanager', 'sysadmin')
+ return {'attachTime': volume_ref['attach_time'],
+ 'device': volume_ref['mountpoint'],
+ 'instanceId': instance_ref['id'],
+ 'requestId': context.request_id,
+ 'status': volume_ref['attach_status'],
+ 'volumeId': volume_ref['id']}
+
def detach_volume(self, context, volume_id, **kwargs):
- volume_ref = db.volume_get_by_str(context, volume_id)
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
instance_ref = db.volume_get_instance(context, volume_ref['id'])
if not instance_ref:
raise exception.ApiError("Volume isn't attached to anything!")
@@ -358,12 +365,12 @@ class CloudController(object):
# If the instance doesn't exist anymore,
# then we need to call detach blind
db.volume_detached(context)
- return defer.succeed({'attachTime': volume_ref['attach_time'],
- 'device': volume_ref['mountpoint'],
- 'instanceId': instance_ref['str_id'],
- 'requestId': context.request_id,
- 'status': volume_ref['attach_status'],
- 'volumeId': volume_ref['id']})
+ return {'attachTime': volume_ref['attach_time'],
+ 'device': volume_ref['mountpoint'],
+ 'instanceId': instance_ref['ec2_id'],
+ 'requestId': context.request_id,
+ 'status': volume_ref['attach_status'],
+ 'volumeId': volume_ref['id']}
def _convert_to_set(self, lst, label):
if lst == None or lst == []:
@@ -372,9 +379,18 @@ class CloudController(object):
lst = [lst]
return [{label: x} for x in lst]
- @rbac.allow('all')
+ def update_volume(self, context, volume_id, **kwargs):
+ updatable_fields = ['display_name', 'display_description']
+ changes = {}
+ for field in updatable_fields:
+ if field in kwargs:
+ changes[field] = kwargs[field]
+ if changes:
+ db.volume_update(context, volume_id, kwargs)
+ return True
+
def describe_instances(self, context, **kwargs):
- return defer.succeed(self._format_describe_instances(context))
+ return self._format_describe_instances(context)
def _format_describe_instances(self, context):
return { 'reservationSet': self._format_instances(context) }
@@ -387,20 +403,20 @@ class CloudController(object):
def _format_instances(self, context, reservation_id=None):
reservations = {}
if reservation_id:
- instances = db.instance_get_by_reservation(context,
- reservation_id)
+ instances = db.instance_get_all_by_reservation(context,
+ reservation_id)
else:
if context.user.is_admin():
instances = db.instance_get_all(context)
else:
- instances = db.instance_get_by_project(context,
- context.project.id)
+ instances = db.instance_get_all_by_project(context,
+ context.project.id)
for instance in instances:
if not context.user.is_admin():
if instance['image_id'] == FLAGS.vpn_image_id:
continue
i = {}
- i['instanceId'] = instance['str_id']
+ i['instanceId'] = instance['ec2_id']
i['imageId'] = instance['image_id']
i['instanceState'] = {
'code': instance['state'],
@@ -409,10 +425,10 @@ class CloudController(object):
fixed_addr = None
floating_addr = None
if instance['fixed_ip']:
- fixed_addr = instance['fixed_ip']['str_id']
+ fixed_addr = instance['fixed_ip']['address']
if instance['fixed_ip']['floating_ips']:
fixed = instance['fixed_ip']
- floating_addr = fixed['floating_ips'][0]['str_id']
+ floating_addr = fixed['floating_ips'][0]['address']
i['privateDnsName'] = fixed_addr
i['publicDnsName'] = floating_addr
i['dnsName'] = i['publicDnsName'] or i['privateDnsName']
@@ -425,6 +441,8 @@ class CloudController(object):
i['instanceType'] = instance['instance_type']
i['launchTime'] = instance['created_at']
i['amiLaunchIndex'] = instance['launch_index']
+ i['displayName'] = instance['display_name']
+ i['displayDescription'] = instance['display_description']
if not reservations.has_key(instance['reservation_id']):
r = {}
r['reservationId'] = instance['reservation_id']
@@ -436,7 +454,6 @@ class CloudController(object):
return list(reservations.values())
- @rbac.allow('all')
def describe_addresses(self, context, **kwargs):
return self.format_addresses(context)
@@ -445,14 +462,14 @@ class CloudController(object):
if context.user.is_admin():
iterator = db.floating_ip_get_all(context)
else:
- iterator = db.floating_ip_get_by_project(context,
- context.project.id)
+ iterator = db.floating_ip_get_all_by_project(context,
+ context.project.id)
for floating_ip_ref in iterator:
- address = floating_ip_ref['str_id']
+ address = floating_ip_ref['address']
instance_id = None
if (floating_ip_ref['fixed_ip']
and floating_ip_ref['fixed_ip']['instance']):
- instance_id = floating_ip_ref['fixed_ip']['instance']['str_id']
+ instance_id = floating_ip_ref['fixed_ip']['instance']['ec2_id']
address_rv = {'public_ip': address,
'instance_id': instance_id}
if context.user.is_admin():
@@ -462,8 +479,6 @@ class CloudController(object):
addresses.append(address_rv)
return {'addressesSet': addresses}
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
def allocate_address(self, context, **kwargs):
# check quota
if quota.allowed_floating_ips(context, 1) < 1:
@@ -471,64 +486,56 @@ class CloudController(object):
context.project.id)
raise QuotaError("Address quota exceeded. You cannot "
"allocate any more addresses")
- network_topic = yield self._get_network_topic(context)
- public_ip = yield rpc.call(network_topic,
+ network_topic = self._get_network_topic(context)
+ public_ip = rpc.call(network_topic,
{"method": "allocate_floating_ip",
"args": {"context": None,
"project_id": context.project.id}})
- defer.returnValue({'addressSet': [{'publicIp': public_ip}]})
+ return {'addressSet': [{'publicIp': public_ip}]}
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
def release_address(self, context, public_ip, **kwargs):
# NOTE(vish): Should we make sure this works?
floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
- network_topic = yield self._get_network_topic(context)
+ network_topic = self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "deallocate_floating_ip",
"args": {"context": None,
- "floating_address": floating_ip_ref['str_id']}})
- defer.returnValue({'releaseResponse': ["Address released."]})
+ "floating_address": floating_ip_ref['address']}})
+ return {'releaseResponse': ["Address released."]}
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
def associate_address(self, context, instance_id, public_ip, **kwargs):
- instance_ref = db.instance_get_by_str(context, instance_id)
- fixed_ip_ref = db.fixed_ip_get_by_instance(context, instance_ref['id'])
+ instance_ref = db.instance_get_by_ec2_id(context, instance_id)
+ fixed_address = db.instance_get_fixed_address(context,
+ instance_ref['id'])
floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
- network_topic = yield self._get_network_topic(context)
+ network_topic = self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "associate_floating_ip",
"args": {"context": None,
- "floating_address": floating_ip_ref['str_id'],
- "fixed_address": fixed_ip_ref['str_id']}})
- defer.returnValue({'associateResponse': ["Address associated."]})
+ "floating_address": floating_ip_ref['address'],
+ "fixed_address": fixed_address}})
+ return {'associateResponse': ["Address associated."]}
- @rbac.allow('netadmin')
- @defer.inlineCallbacks
def disassociate_address(self, context, public_ip, **kwargs):
floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
- network_topic = yield self._get_network_topic(context)
+ network_topic = self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "disassociate_floating_ip",
"args": {"context": None,
- "floating_address": floating_ip_ref['str_id']}})
- defer.returnValue({'disassociateResponse': ["Address disassociated."]})
+ "floating_address": floating_ip_ref['address']}})
+ return {'disassociateResponse': ["Address disassociated."]}
- @defer.inlineCallbacks
def _get_network_topic(self, context):
"""Retrieves the network host for a project"""
network_ref = db.project_get_network(context, context.project.id)
host = network_ref['host']
if not host:
- host = yield rpc.call(FLAGS.network_topic,
+ host = rpc.call(FLAGS.network_topic,
{"method": "set_network_host",
"args": {"context": None,
"project_id": context.project.id}})
- defer.returnValue(db.queue_get_for(context, FLAGS.network_topic, host))
+ return db.queue_get_for(context, FLAGS.network_topic, host)
- @rbac.allow('projectmanager', 'sysadmin')
- @defer.inlineCallbacks
def run_instances(self, context, **kwargs):
instance_type = kwargs.get('instance_type', 'm1.small')
if instance_type not in INSTANCE_TYPES:
@@ -571,11 +578,10 @@ class CloudController(object):
launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
key_data = None
if kwargs.has_key('key_name'):
- key_pair = context.user.get_key_pair(kwargs['key_name'])
- if not key_pair:
- raise exception.ApiError('Key Pair %s not found' %
- kwargs['key_name'])
- key_data = key_pair.public_key
+ key_pair_ref = db.key_pair_get(context,
+ context.user.id,
+ kwargs['key_name'])
+ key_data = key_pair_ref['public_key']
# TODO: Get the real security group of launch in here
security_group = "default"
@@ -594,6 +600,8 @@ class CloudController(object):
base_options['user_data'] = kwargs.get('user_data', '')
base_options['security_group'] = security_group
base_options['instance_type'] = instance_type
+ base_options['display_name'] = kwargs.get('display_name')
+ base_options['display_description'] = kwargs.get('display_description')
type_data = INSTANCE_TYPES[instance_type]
base_options['memory_mb'] = type_data['memory_mb']
@@ -607,7 +615,7 @@ class CloudController(object):
inst = {}
inst['mac_address'] = utils.generate_mac()
inst['launch_index'] = num
- inst['hostname'] = instance_ref['str_id']
+ inst['hostname'] = instance_ref['ec2_id']
db.instance_update(context, inst_id, inst)
address = self.network_manager.allocate_fixed_ip(context,
inst_id,
@@ -615,7 +623,7 @@ class CloudController(object):
# TODO(vish): This probably should be done in the scheduler
# network is setup when host is assigned
- network_topic = yield self._get_network_topic(context)
+ network_topic = self._get_network_topic(context)
rpc.call(network_topic,
{"method": "setup_fixed_ip",
"args": {"context": None,
@@ -628,18 +636,15 @@ class CloudController(object):
"instance_id": inst_id}})
logging.debug("Casting to scheduler for %s/%s's instance %s" %
(context.project.name, context.user.name, inst_id))
- defer.returnValue(self._format_run_instances(context,
- reservation_id))
+ return self._format_run_instances(context, reservation_id)
- @rbac.allow('projectmanager', 'sysadmin')
- @defer.inlineCallbacks
def terminate_instances(self, context, instance_id, **kwargs):
logging.debug("Going to start terminating instances")
for id_str in instance_id:
logging.debug("Going to try and terminate %s" % id_str)
try:
- instance_ref = db.instance_get_by_str(context, id_str)
+ instance_ref = db.instance_get_by_ec2_id(context, id_str)
except exception.NotFound:
logging.warning("Instance %s was not found during terminate"
% id_str)
@@ -657,11 +662,11 @@ class CloudController(object):
# NOTE(vish): Right now we don't really care if the ip is
# disassociated. We may need to worry about
# checking this later. Perhaps in the scheduler?
- network_topic = yield self._get_network_topic(context)
+ network_topic = self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "disassociate_floating_ip",
"args": {"context": None,
- "address": address}})
+ "floating_address": address}})
address = db.instance_get_fixed_address(context,
instance_ref['id'])
@@ -680,24 +685,29 @@ class CloudController(object):
"instance_id": instance_ref['id']}})
else:
db.instance_destroy(context, instance_ref['id'])
- defer.returnValue(True)
+ return True
- @rbac.allow('projectmanager', 'sysadmin')
def reboot_instances(self, context, instance_id, **kwargs):
"""instance_id is a list of instance ids"""
for id_str in instance_id:
- instance_ref = db.instance_get_by_str(context, id_str)
- host = instance_ref['host']
- rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
- {"method": "reboot_instance",
- "args": {"context": None,
- "instance_id": instance_ref['id']}})
- return defer.succeed(True)
+ cloud.reboot(id_str, context=context)
+ return True
+
+ def update_instance(self, context, instance_id, **kwargs):
+ updatable_fields = ['display_name', 'display_description']
+ changes = {}
+ for field in updatable_fields:
+ if field in kwargs:
+ changes[field] = kwargs[field]
+ if changes:
+ db_context = {}
+ inst = db.instance_get_by_ec2_id(db_context, instance_id)
+ db.instance_update(db_context, inst['id'], kwargs)
+ return True
- @rbac.allow('projectmanager', 'sysadmin')
def delete_volume(self, context, volume_id, **kwargs):
# TODO: return error if not authorized
- volume_ref = db.volume_get_by_str(context, volume_id)
+ volume_ref = db.volume_get_by_ec2_id(context, volume_id)
if volume_ref['status'] != "available":
raise exception.ApiError("Volume status must be available")
now = datetime.datetime.utcnow()
@@ -707,31 +717,26 @@ class CloudController(object):
{"method": "delete_volume",
"args": {"context": None,
"volume_id": volume_ref['id']}})
- return defer.succeed(True)
+ return True
- @rbac.allow('all')
def describe_images(self, context, image_id=None, **kwargs):
# The objectstore does its own authorization for describe
imageSet = images.list(context, image_id)
- return defer.succeed({'imagesSet': imageSet})
+ return {'imagesSet': imageSet}
- @rbac.allow('projectmanager', 'sysadmin')
def deregister_image(self, context, image_id, **kwargs):
# FIXME: should the objectstore be doing these authorization checks?
images.deregister(context, image_id)
- return defer.succeed({'imageId': image_id})
+ return {'imageId': image_id}
- @rbac.allow('projectmanager', 'sysadmin')
def register_image(self, context, image_location=None, **kwargs):
# FIXME: should the objectstore be doing these authorization checks?
if image_location is None and kwargs.has_key('name'):
image_location = kwargs['name']
image_id = images.register(context, image_location)
logging.debug("Registered %s as %s" % (image_location, image_id))
+ return {'imageId': image_id}
- return defer.succeed({'imageId': image_id})
-
- @rbac.allow('all')
def describe_image_attribute(self, context, image_id, attribute, **kwargs):
if attribute != 'launchPermission':
raise exception.ApiError('attribute not supported: %s' % attribute)
@@ -742,9 +747,8 @@ class CloudController(object):
result = {'image_id': image_id, 'launchPermission': []}
if image['isPublic']:
result['launchPermission'].append({'group': 'all'})
- return defer.succeed(result)
+ return result
- @rbac.allow('projectmanager', 'sysadmin')
def modify_image_attribute(self, context, image_id, attribute, operation_type, **kwargs):
# TODO(devcamcar): Support users and groups other than 'all'.
if attribute != 'launchPermission':
@@ -755,5 +759,8 @@ class CloudController(object):
raise exception.ApiError('only group "all" is supported')
if not operation_type in ['add', 'remove']:
raise exception.ApiError('operation_type must be add or remove')
- result = images.modify(context, image_id, operation_type)
- return defer.succeed(result)
+ return images.modify(context, image_id, operation_type)
+
+ def update_image(self, context, image_id, **kwargs):
+ result = images.update(context, image_id, dict(kwargs))
+ return result
diff --git a/nova/endpoint/images.py b/nova/api/ec2/images.py
index 4579cd81a..cb54cdda2 100644
--- a/nova/endpoint/images.py
+++ b/nova/api/ec2/images.py
@@ -43,6 +43,14 @@ def modify(context, image_id, operation):
return True
+def update(context, image_id, attributes):
+ """update an image's attributes / info.json"""
+ attributes.update({"image_id": image_id})
+ conn(context).make_request(
+ method='POST',
+ bucket='_images',
+ query_args=qs(attributes))
+ return True
def register(context, image_location):
""" rpc call to register a new image based from a manifest """
diff --git a/nova/api/ec2/metadatarequesthandler.py b/nova/api/ec2/metadatarequesthandler.py
new file mode 100644
index 000000000..08a8040ca
--- /dev/null
+++ b/nova/api/ec2/metadatarequesthandler.py
@@ -0,0 +1,73 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Metadata request handler."""
+
+import logging
+
+import webob.dec
+import webob.exc
+
+from nova.api.ec2 import cloud
+
+
+class MetadataRequestHandler(object):
+
+ """Serve metadata from the EC2 API."""
+
+ def print_data(self, data):
+ if isinstance(data, dict):
+ output = ''
+ for key in data:
+ if key == '_name':
+ continue
+ output += key
+ if isinstance(data[key], dict):
+ if '_name' in data[key]:
+ output += '=' + str(data[key]['_name'])
+ else:
+ output += '/'
+ output += '\n'
+ return output[:-1] # cut off last \n
+ elif isinstance(data, list):
+ return '\n'.join(data)
+ else:
+ return str(data)
+
+ def lookup(self, path, data):
+ items = path.split('/')
+ for item in items:
+ if item:
+ if not isinstance(data, dict):
+ return data
+ if not item in data:
+ return None
+ data = data[item]
+ return data
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ cc = cloud.CloudController()
+ meta_data = cc.get_metadata(req.remote_addr)
+ if meta_data is None:
+ logging.error('Failed to get metadata for ip: %s' % req.remote_addr)
+ raise webob.exc.HTTPNotFound()
+ data = self.lookup(req.path_info, meta_data)
+ if data is None:
+ raise webob.exc.HTTPNotFound()
+ return self.print_data(data)
diff --git a/nova/api/rackspace/__init__.py b/nova/api/rackspace/__init__.py
index b4d666d63..89a4693ad 100644
--- a/nova/api/rackspace/__init__.py
+++ b/nova/api/rackspace/__init__.py
@@ -26,44 +26,122 @@ import time
import routes
import webob.dec
import webob.exc
+import webob
from nova import flags
+from nova import utils
from nova import wsgi
+from nova.api.rackspace import faults
+from nova.api.rackspace import backup_schedules
from nova.api.rackspace import flavors
from nova.api.rackspace import images
+from nova.api.rackspace import ratelimiting
from nova.api.rackspace import servers
from nova.api.rackspace import sharedipgroups
from nova.auth import manager
+FLAGS = flags.FLAGS
+flags.DEFINE_string('nova_api_auth',
+ 'nova.api.rackspace.auth.BasicApiAuthManager',
+ 'The auth mechanism to use for the Rackspace API implemenation')
+
class API(wsgi.Middleware):
"""WSGI entry point for all Rackspace API requests."""
def __init__(self):
- app = AuthMiddleware(APIRouter())
+ app = AuthMiddleware(RateLimitingMiddleware(APIRouter()))
super(API, self).__init__(app)
-
class AuthMiddleware(wsgi.Middleware):
"""Authorize the rackspace API request or return an HTTP Forbidden."""
- #TODO(gundlach): isn't this the old Nova API's auth? Should it be replaced
- #with correct RS API auth?
+ def __init__(self, application):
+ self.auth_driver = utils.import_class(FLAGS.nova_api_auth)()
+ super(AuthMiddleware, self).__init__(application)
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ if not req.headers.has_key("X-Auth-Token"):
+ return self.auth_driver.authenticate(req)
+
+ user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"])
+
+ if not user:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ if not req.environ.has_key('nova.context'):
+ req.environ['nova.context'] = {}
+ req.environ['nova.context']['user'] = user
+ return self.application
+
+class RateLimitingMiddleware(wsgi.Middleware):
+ """Rate limit incoming requests according to the OpenStack rate limits."""
+
+ def __init__(self, application, service_host=None):
+ """Create a rate limiting middleware that wraps the given application.
+
+ By default, rate counters are stored in memory. If service_host is
+ specified, the middleware instead relies on the ratelimiting.WSGIApp
+ at the given host+port to keep rate counters.
+ """
+ super(RateLimitingMiddleware, self).__init__(application)
+ if not service_host:
+ #TODO(gundlach): These limits were based on limitations of Cloud
+ #Servers. We should revisit them in Nova.
+ self.limiter = ratelimiting.Limiter(limits={
+ 'DELETE': (100, ratelimiting.PER_MINUTE),
+ 'PUT': (10, ratelimiting.PER_MINUTE),
+ 'POST': (10, ratelimiting.PER_MINUTE),
+ 'POST servers': (50, ratelimiting.PER_DAY),
+ 'GET changes-since': (3, ratelimiting.PER_MINUTE),
+ })
+ else:
+ self.limiter = ratelimiting.WSGIAppProxy(service_host)
@webob.dec.wsgify
def __call__(self, req):
- context = {}
- if "HTTP_X_AUTH_TOKEN" in req.environ:
- context['user'] = manager.AuthManager().get_user_from_access_key(
- req.environ['HTTP_X_AUTH_TOKEN'])
- if context['user']:
- context['project'] = manager.AuthManager().get_project(
- context['user'].name)
- if "user" not in context:
- return webob.exc.HTTPForbidden()
- req.environ['nova.context'] = context
+ """Rate limit the request.
+
+ If the request should be rate limited, return a 413 status with a
+ Retry-After header giving the time when the request would succeed.
+ """
+ username = req.headers['X-Auth-User']
+ action_name = self.get_action_name(req)
+ if not action_name: # not rate limited
+ return self.application
+ delay = self.get_delay(action_name, username)
+ if delay:
+ # TODO(gundlach): Get the retry-after format correct.
+ exc = webob.exc.HTTPRequestEntityTooLarge(
+ explanation='Too many requests.',
+ headers={'Retry-After': time.time() + delay})
+ raise faults.Fault(exc)
return self.application
+ def get_delay(self, action_name, username):
+ """Return the delay for the given action and username, or None if
+ the action would not be rate limited.
+ """
+ if action_name == 'POST servers':
+ # "POST servers" is a POST, so it counts against "POST" too.
+ # Attempt the "POST" first, lest we are rate limited by "POST" but
+ # use up a precious "POST servers" call.
+ delay = self.limiter.perform("POST", username=username)
+ if delay:
+ return delay
+ return self.limiter.perform(action_name, username=username)
+
+ def get_action_name(self, req):
+ """Return the action name for this request."""
+ if req.method == 'GET' and 'changes-since' in req.GET:
+ return 'GET changes-since'
+ if req.method == 'POST' and req.path_info.startswith('/servers'):
+ return 'POST servers'
+ if req.method in ['PUT', 'POST', 'DELETE']:
+ return req.method
+ return None
+
class APIRouter(wsgi.Router):
"""
@@ -73,11 +151,40 @@ class APIRouter(wsgi.Router):
def __init__(self):
mapper = routes.Mapper()
- mapper.resource("server", "servers", controller=servers.Controller())
+ mapper.resource("server", "servers", controller=servers.Controller(),
+ collection={ 'detail': 'GET'},
+ member={'action':'POST'})
+
+ mapper.resource("backup_schedule", "backup_schedules",
+ controller=backup_schedules.Controller(),
+ parent_resource=dict(member_name='server',
+ collection_name = 'servers'))
+
mapper.resource("image", "images", controller=images.Controller(),
collection={'detail': 'GET'})
mapper.resource("flavor", "flavors", controller=flavors.Controller(),
collection={'detail': 'GET'})
mapper.resource("sharedipgroup", "sharedipgroups",
controller=sharedipgroups.Controller())
+
super(APIRouter, self).__init__(mapper)
+
+
+def limited(items, req):
+ """Return a slice of items according to requested offset and limit.
+
+ items - a sliceable
+ req - wobob.Request possibly containing offset and limit GET variables.
+ offset is where to start in the list, and limit is the maximum number
+ of items to return.
+
+ If limit is not specified, 0, or > 1000, defaults to 1000.
+ """
+ offset = int(req.GET.get('offset', 0))
+ limit = int(req.GET.get('limit', 0))
+ if not limit:
+ limit = 1000
+ limit = min(1000, limit)
+ range_end = offset + limit
+ return items[offset:range_end]
+
diff --git a/nova/api/rackspace/_id_translator.py b/nova/api/rackspace/_id_translator.py
index aec5fb6a5..333aa8434 100644
--- a/nova/api/rackspace/_id_translator.py
+++ b/nova/api/rackspace/_id_translator.py
@@ -37,6 +37,6 @@ class RackspaceAPIIdTranslator(object):
# every int id be used.)
return int(self._store.hget(self._fwd_key, str(opaque_id)))
- def from_rs_id(self, strategy_name, rs_id):
+ def from_rs_id(self, rs_id):
"""Convert a Rackspace id to a strategy-specific one."""
return self._store.hget(self._rev_key, rs_id)
diff --git a/nova/api/rackspace/auth.py b/nova/api/rackspace/auth.py
new file mode 100644
index 000000000..c45156ebd
--- /dev/null
+++ b/nova/api/rackspace/auth.py
@@ -0,0 +1,101 @@
+import datetime
+import hashlib
+import json
+import time
+
+import webob.exc
+import webob.dec
+
+from nova import auth
+from nova import db
+from nova import flags
+from nova import manager
+from nova import utils
+from nova.api.rackspace import faults
+
+FLAGS = flags.FLAGS
+
+class Context(object):
+ pass
+
+class BasicApiAuthManager(object):
+ """ Implements a somewhat rudimentary version of Rackspace Auth"""
+
+ def __init__(self, host=None, db_driver=None):
+ if not host:
+ host = FLAGS.host
+ self.host = host
+ if not db_driver:
+ db_driver = FLAGS.db_driver
+ self.db = utils.import_object(db_driver)
+ self.auth = auth.manager.AuthManager()
+ self.context = Context()
+ super(BasicApiAuthManager, self).__init__()
+
+ def authenticate(self, req):
+ # Unless the request is explicitly made against /<version>/ don't
+ # honor it
+ path_info = req.path_info
+ if len(path_info) > 1:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ try:
+ username, key = req.headers['X-Auth-User'], \
+ req.headers['X-Auth-Key']
+ except KeyError:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ username, key = req.headers['X-Auth-User'], req.headers['X-Auth-Key']
+ token, user = self._authorize_user(username, key)
+ if user and token:
+ res = webob.Response()
+ res.headers['X-Auth-Token'] = token['token_hash']
+ res.headers['X-Server-Management-Url'] = \
+ token['server_management_url']
+ res.headers['X-Storage-Url'] = token['storage_url']
+ res.headers['X-CDN-Management-Url'] = token['cdn_management_url']
+ res.content_type = 'text/plain'
+ res.status = '204'
+ return res
+ else:
+ return faults.Fault(webob.exc.HTTPUnauthorized())
+
+ def authorize_token(self, token_hash):
+ """ retrieves user information from the datastore given a token
+
+ If the token has expired, returns None
+ If the token is not found, returns None
+ Otherwise returns the token
+
+ This method will also remove the token if the timestamp is older than
+ 2 days ago.
+ """
+ token = self.db.auth_get_token(self.context, token_hash)
+ if token:
+ delta = datetime.datetime.now() - token['created_at']
+ if delta.days >= 2:
+ self.db.auth_destroy_token(self.context, token)
+ else:
+ user = self.auth.get_user(token['user_id'])
+ return { 'id':user['uid'] }
+ return None
+
+ def _authorize_user(self, username, key):
+ """ Generates a new token and assigns it to a user """
+ user = self.auth.get_user_from_access_key(key)
+ if user and user['name'] == username:
+ token_hash = hashlib.sha1('%s%s%f' % (username, key,
+ time.time())).hexdigest()
+ token = {}
+ token['token_hash'] = token_hash
+ token['cdn_management_url'] = ''
+ token['server_management_url'] = self._get_server_mgmt_url()
+ token['storage_url'] = ''
+ token['user_id'] = user['uid']
+ self.db.auth_create_token(self.context, token)
+ return token, user
+ return None, None
+
+ def _get_server_mgmt_url(self):
+ return 'https://%s/v1.0/' % self.host
+
diff --git a/nova/api/rackspace/backup_schedules.py b/nova/api/rackspace/backup_schedules.py
new file mode 100644
index 000000000..cb83023bc
--- /dev/null
+++ b/nova/api/rackspace/backup_schedules.py
@@ -0,0 +1,39 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import time
+from webob import exc
+
+from nova import wsgi
+from nova.api.rackspace import _id_translator
+from nova.api.rackspace import faults
+import nova.image.service
+
+class Controller(wsgi.Controller):
+ def __init__(self):
+ pass
+
+ def index(self, req, server_id):
+ return faults.Fault(exc.HTTPNotFound())
+
+ def create(self, req, server_id):
+ """ No actual update method required, since the existing API allows
+ both create and update through a POST """
+ return faults.Fault(exc.HTTPNotFound())
+
+ def delete(self, req, server_id):
+ return faults.Fault(exc.HTTPNotFound())
diff --git a/nova/api/rackspace/base.py b/nova/api/rackspace/context.py
index dd2c6543c..77394615b 100644
--- a/nova/api/rackspace/base.py
+++ b/nova/api/rackspace/context.py
@@ -15,16 +15,19 @@
# License for the specific language governing permissions and limitations
# under the License.
-from nova import wsgi
+"""
+APIRequestContext
+"""
+import random
-class Controller(wsgi.Controller):
- """TODO(eday): Base controller for all rackspace controllers. What is this
- for? Is this just Rackspace specific? """
-
- @classmethod
- def render(cls, instance):
- if isinstance(instance, list):
- return {cls.entity_name: cls.render(instance)}
- else:
- return {"TODO": "TODO"}
+class Project(object):
+ def __init__(self, user_id):
+ self.id = user_id
+
+class APIRequestContext(object):
+ """ This is an adapter class to get around all of the assumptions made in
+ the FlatNetworking """
+ def __init__(self, user_id):
+ self.user_id = user_id
+ self.project = Project(user_id)
diff --git a/nova/api/rackspace/faults.py b/nova/api/rackspace/faults.py
new file mode 100644
index 000000000..32e5c866f
--- /dev/null
+++ b/nova/api/rackspace/faults.py
@@ -0,0 +1,62 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import webob.dec
+import webob.exc
+
+from nova import wsgi
+
+
+class Fault(webob.exc.HTTPException):
+
+ """An RS API fault response."""
+
+ _fault_names = {
+ 400: "badRequest",
+ 401: "unauthorized",
+ 403: "resizeNotAllowed",
+ 404: "itemNotFound",
+ 405: "badMethod",
+ 409: "inProgress",
+ 413: "overLimit",
+ 415: "badMediaType",
+ 501: "notImplemented",
+ 503: "serviceUnavailable"}
+
+ def __init__(self, exception):
+ """Create a Fault for the given webob.exc.exception."""
+ self.wrapped_exc = exception
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ """Generate a WSGI response based on the exception passed to ctor."""
+ # Replace the body with fault details.
+ code = self.wrapped_exc.status_int
+ fault_name = self._fault_names.get(code, "cloudServersFault")
+ fault_data = {
+ fault_name: {
+ 'code': code,
+ 'message': self.wrapped_exc.explanation}}
+ if code == 413:
+ retry = self.wrapped_exc.headers['Retry-After']
+ fault_data[fault_name]['retryAfter'] = retry
+ # 'code' is an attribute on the fault tag itself
+ metadata = {'application/xml': {'attributes': {fault_name: 'code'}}}
+ serializer = wsgi.Serializer(req.environ, metadata)
+ self.wrapped_exc.body = serializer.to_content_type(fault_data)
+ return self.wrapped_exc
diff --git a/nova/api/rackspace/flavors.py b/nova/api/rackspace/flavors.py
index 60b35c939..916449854 100644
--- a/nova/api/rackspace/flavors.py
+++ b/nova/api/rackspace/flavors.py
@@ -15,11 +15,14 @@
# License for the specific language governing permissions and limitations
# under the License.
-from nova.api.rackspace import base
-from nova.compute import instance_types
from webob import exc
-class Controller(base.Controller):
+from nova.api.rackspace import faults
+from nova.compute import instance_types
+from nova import wsgi
+import nova.api.rackspace
+
+class Controller(wsgi.Controller):
"""Flavor controller for the Rackspace API."""
_serialization_metadata = {
@@ -38,6 +41,7 @@ class Controller(base.Controller):
def detail(self, req):
"""Return all flavors in detail."""
items = [self.show(req, id)['flavor'] for id in self._all_ids()]
+ items = nova.api.rackspace.limited(items, req)
return dict(flavors=items)
def show(self, req, id):
@@ -47,7 +51,7 @@ class Controller(base.Controller):
item = dict(ram=val['memory_mb'], disk=val['local_gb'],
id=val['flavorid'], name=name)
return dict(flavor=item)
- raise exc.HTTPNotFound()
+ raise faults.Fault(exc.HTTPNotFound())
def _all_ids(self):
"""Return the list of all flavorids."""
diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py
index 2f3e928b9..4a7dd489c 100644
--- a/nova/api/rackspace/images.py
+++ b/nova/api/rackspace/images.py
@@ -15,12 +15,15 @@
# License for the specific language governing permissions and limitations
# under the License.
-import nova.image.service
-from nova.api.rackspace import base
-from nova.api.rackspace import _id_translator
from webob import exc
-class Controller(base.Controller):
+from nova import wsgi
+from nova.api.rackspace import _id_translator
+import nova.api.rackspace
+import nova.image.service
+from nova.api.rackspace import faults
+
+class Controller(wsgi.Controller):
_serialization_metadata = {
'application/xml': {
@@ -44,6 +47,7 @@ class Controller(base.Controller):
def detail(self, req):
"""Return all public images in detail."""
data = self._service.index()
+ data = nova.api.rackspace.limited(data, req)
for img in data:
img['id'] = self._id_translator.to_rs_id(img['id'])
return dict(images=data)
@@ -57,14 +61,14 @@ class Controller(base.Controller):
def delete(self, req, id):
# Only public images are supported for now.
- raise exc.HTTPNotFound()
+ raise faults.Fault(exc.HTTPNotFound())
def create(self, req):
# Only public images are supported for now, so a request to
# make a backup of a server cannot be supproted.
- raise exc.HTTPNotFound()
+ raise faults.Fault(exc.HTTPNotFound())
def update(self, req, id):
# Users may not modify public images, and that's all that
# we support for now.
- raise exc.HTTPNotFound()
+ raise faults.Fault(exc.HTTPNotFound())
diff --git a/nova/api/rackspace/ratelimiting/__init__.py b/nova/api/rackspace/ratelimiting/__init__.py
new file mode 100644
index 000000000..f843bac0f
--- /dev/null
+++ b/nova/api/rackspace/ratelimiting/__init__.py
@@ -0,0 +1,122 @@
+"""Rate limiting of arbitrary actions."""
+
+import httplib
+import time
+import urllib
+import webob.dec
+import webob.exc
+
+
+# Convenience constants for the limits dictionary passed to Limiter().
+PER_SECOND = 1
+PER_MINUTE = 60
+PER_HOUR = 60 * 60
+PER_DAY = 60 * 60 * 24
+
+class Limiter(object):
+
+ """Class providing rate limiting of arbitrary actions."""
+
+ def __init__(self, limits):
+ """Create a rate limiter.
+
+ limits: a dict mapping from action name to a tuple. The tuple contains
+ the number of times the action may be performed, and the time period
+ (in seconds) during which the number must not be exceeded for this
+ action. Example: dict(reboot=(10, ratelimiting.PER_MINUTE)) would
+ allow 10 'reboot' actions per minute.
+ """
+ self.limits = limits
+ self._levels = {}
+
+ def perform(self, action_name, username='nobody'):
+ """Attempt to perform an action by the given username.
+
+ action_name: the string name of the action to perform. This must
+ be a key in the limits dict passed to the ctor.
+
+ username: an optional string name of the user performing the action.
+ Each user has her own set of rate limiting counters. Defaults to
+ 'nobody' (so that if you never specify a username when calling
+ perform(), a single set of counters will be used.)
+
+ Return None if the action may proceed. If the action may not proceed
+ because it has been rate limited, return the float number of seconds
+ until the action would succeed.
+ """
+ # Think of rate limiting as a bucket leaking water at 1cc/second. The
+ # bucket can hold as many ccs as there are seconds in the rate
+ # limiting period (e.g. 3600 for per-hour ratelimits), and if you can
+ # perform N actions in that time, each action fills the bucket by
+ # 1/Nth of its volume. You may only perform an action if the bucket
+ # would not overflow.
+ now = time.time()
+ key = '%s:%s' % (username, action_name)
+ last_time_performed, water_level = self._levels.get(key, (now, 0))
+ # The bucket leaks 1cc/second.
+ water_level -= (now - last_time_performed)
+ if water_level < 0:
+ water_level = 0
+ num_allowed_per_period, period_in_secs = self.limits[action_name]
+ # Fill the bucket by 1/Nth its capacity, and hope it doesn't overflow.
+ capacity = period_in_secs
+ new_level = water_level + (capacity * 1.0 / num_allowed_per_period)
+ if new_level > capacity:
+ # Delay this many seconds.
+ return new_level - capacity
+ self._levels[key] = (now, new_level)
+ return None
+
+
+# If one instance of this WSGIApps is unable to handle your load, put a
+# sharding app in front that shards by username to one of many backends.
+
+class WSGIApp(object):
+
+ """Application that tracks rate limits in memory. Send requests to it of
+ this form:
+
+ POST /limiter/<username>/<urlencoded action>
+
+ and receive a 200 OK, or a 403 Forbidden with an X-Wait-Seconds header
+ containing the number of seconds to wait before the action would succeed.
+ """
+
+ def __init__(self, limiter):
+ """Create the WSGI application using the given Limiter instance."""
+ self.limiter = limiter
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ parts = req.path_info.split('/')
+ # format: /limiter/<username>/<urlencoded action>
+ if req.method != 'POST':
+ raise webob.exc.HTTPMethodNotAllowed()
+ if len(parts) != 4 or parts[1] != 'limiter':
+ raise webob.exc.HTTPNotFound()
+ username = parts[2]
+ action_name = urllib.unquote(parts[3])
+ delay = self.limiter.perform(action_name, username)
+ if delay:
+ return webob.exc.HTTPForbidden(
+ headers={'X-Wait-Seconds': "%.2f" % delay})
+ else:
+ return '' # 200 OK
+
+
+class WSGIAppProxy(object):
+
+ """Limiter lookalike that proxies to a ratelimiting.WSGIApp."""
+
+ def __init__(self, service_host):
+ """Creates a proxy pointing to a ratelimiting.WSGIApp at the given
+ host."""
+ self.service_host = service_host
+
+ def perform(self, action, username='nobody'):
+ conn = httplib.HTTPConnection(self.service_host)
+ conn.request('POST', '/limiter/%s/%s' % (username, action))
+ resp = conn.getresponse()
+ if resp.status == 200:
+ return None # no delay
+ return float(resp.getheader('X-Wait-Seconds'))
diff --git a/nova/api/rackspace/ratelimiting/tests.py b/nova/api/rackspace/ratelimiting/tests.py
new file mode 100644
index 000000000..4c9510917
--- /dev/null
+++ b/nova/api/rackspace/ratelimiting/tests.py
@@ -0,0 +1,237 @@
+import httplib
+import StringIO
+import time
+import unittest
+import webob
+
+import nova.api.rackspace.ratelimiting as ratelimiting
+
+class LimiterTest(unittest.TestCase):
+
+ def setUp(self):
+ self.limits = {
+ 'a': (5, ratelimiting.PER_SECOND),
+ 'b': (5, ratelimiting.PER_MINUTE),
+ 'c': (5, ratelimiting.PER_HOUR),
+ 'd': (1, ratelimiting.PER_SECOND),
+ 'e': (100, ratelimiting.PER_SECOND)}
+ self.rl = ratelimiting.Limiter(self.limits)
+
+ def exhaust(self, action, times_until_exhausted, **kwargs):
+ for i in range(times_until_exhausted):
+ when = self.rl.perform(action, **kwargs)
+ self.assertEqual(when, None)
+ num, period = self.limits[action]
+ delay = period * 1.0 / num
+ # Verify that we are now thoroughly delayed
+ for i in range(10):
+ when = self.rl.perform(action, **kwargs)
+ self.assertAlmostEqual(when, delay, 2)
+
+ def test_second(self):
+ self.exhaust('a', 5)
+ time.sleep(0.2)
+ self.exhaust('a', 1)
+ time.sleep(1)
+ self.exhaust('a', 5)
+
+ def test_minute(self):
+ self.exhaust('b', 5)
+
+ def test_one_per_period(self):
+ def allow_once_and_deny_once():
+ when = self.rl.perform('d')
+ self.assertEqual(when, None)
+ when = self.rl.perform('d')
+ self.assertAlmostEqual(when, 1, 2)
+ return when
+ time.sleep(allow_once_and_deny_once())
+ time.sleep(allow_once_and_deny_once())
+ allow_once_and_deny_once()
+
+ def test_we_can_go_indefinitely_if_we_spread_out_requests(self):
+ for i in range(200):
+ when = self.rl.perform('e')
+ self.assertEqual(when, None)
+ time.sleep(0.01)
+
+ def test_users_get_separate_buckets(self):
+ self.exhaust('c', 5, username='alice')
+ self.exhaust('c', 5, username='bob')
+ self.exhaust('c', 5, username='chuck')
+ self.exhaust('c', 0, username='chuck')
+ self.exhaust('c', 0, username='bob')
+ self.exhaust('c', 0, username='alice')
+
+
+class FakeLimiter(object):
+ """Fake Limiter class that you can tell how to behave."""
+ def __init__(self, test):
+ self._action = self._username = self._delay = None
+ self.test = test
+ def mock(self, action, username, delay):
+ self._action = action
+ self._username = username
+ self._delay = delay
+ def perform(self, action, username):
+ self.test.assertEqual(action, self._action)
+ self.test.assertEqual(username, self._username)
+ return self._delay
+
+
+class WSGIAppTest(unittest.TestCase):
+
+ def setUp(self):
+ self.limiter = FakeLimiter(self)
+ self.app = ratelimiting.WSGIApp(self.limiter)
+
+ def test_invalid_methods(self):
+ requests = []
+ for method in ['GET', 'PUT', 'DELETE']:
+ req = webob.Request.blank('/limits/michael/breakdance',
+ dict(REQUEST_METHOD=method))
+ requests.append(req)
+ for req in requests:
+ self.assertEqual(req.get_response(self.app).status_int, 405)
+
+ def test_invalid_urls(self):
+ requests = []
+ for prefix in ['limit', '', 'limiter2', 'limiter/limits', 'limiter/1']:
+ req = webob.Request.blank('/%s/michael/breakdance' % prefix,
+ dict(REQUEST_METHOD='POST'))
+ requests.append(req)
+ for req in requests:
+ self.assertEqual(req.get_response(self.app).status_int, 404)
+
+ def verify(self, url, username, action, delay=None):
+ """Make sure that POSTing to the given url causes the given username
+ to perform the given action. Make the internal rate limiter return
+ delay and make sure that the WSGI app returns the correct response.
+ """
+ req = webob.Request.blank(url, dict(REQUEST_METHOD='POST'))
+ self.limiter.mock(action, username, delay)
+ resp = req.get_response(self.app)
+ if not delay:
+ self.assertEqual(resp.status_int, 200)
+ else:
+ self.assertEqual(resp.status_int, 403)
+ self.assertEqual(resp.headers['X-Wait-Seconds'], "%.2f" % delay)
+
+ def test_good_urls(self):
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot')
+
+ def test_escaping(self):
+ self.verify('/limiter/michael/jump%20up', 'michael', 'jump up')
+
+ def test_response_to_delays(self):
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1)
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1.56)
+ self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1000)
+
+
+class FakeHttplibSocket(object):
+ """a fake socket implementation for httplib.HTTPResponse, trivial"""
+
+ def __init__(self, response_string):
+ self._buffer = StringIO.StringIO(response_string)
+
+ def makefile(self, _mode, _other):
+ """Returns the socket's internal buffer"""
+ return self._buffer
+
+
+class FakeHttplibConnection(object):
+ """A fake httplib.HTTPConnection
+
+ Requests made via this connection actually get translated and routed into
+ our WSGI app, we then wait for the response and turn it back into
+ an httplib.HTTPResponse.
+ """
+ def __init__(self, app, host, is_secure=False):
+ self.app = app
+ self.host = host
+
+ def request(self, method, path, data='', headers={}):
+ req = webob.Request.blank(path)
+ req.method = method
+ req.body = data
+ req.headers = headers
+ req.host = self.host
+ # Call the WSGI app, get the HTTP response
+ resp = str(req.get_response(self.app))
+ # For some reason, the response doesn't have "HTTP/1.0 " prepended; I
+ # guess that's a function the web server usually provides.
+ resp = "HTTP/1.0 %s" % resp
+ sock = FakeHttplibSocket(resp)
+ self.http_response = httplib.HTTPResponse(sock)
+ self.http_response.begin()
+
+ def getresponse(self):
+ return self.http_response
+
+
+def wire_HTTPConnection_to_WSGI(host, app):
+ """Monkeypatches HTTPConnection so that if you try to connect to host, you
+ are instead routed straight to the given WSGI app.
+
+ After calling this method, when any code calls
+
+ httplib.HTTPConnection(host)
+
+ the connection object will be a fake. Its requests will be sent directly
+ to the given WSGI app rather than through a socket.
+
+ Code connecting to hosts other than host will not be affected.
+
+ This method may be called multiple times to map different hosts to
+ different apps.
+ """
+ class HTTPConnectionDecorator(object):
+ """Wraps the real HTTPConnection class so that when you instantiate
+ the class you might instead get a fake instance."""
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+ def __call__(self, connection_host, *args, **kwargs):
+ if connection_host == host:
+ return FakeHttplibConnection(app, host)
+ else:
+ return self.wrapped(connection_host, *args, **kwargs)
+ httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)
+
+
+class WSGIAppProxyTest(unittest.TestCase):
+
+ def setUp(self):
+ """Our WSGIAppProxy is going to call across an HTTPConnection to a
+ WSGIApp running a limiter. The proxy will send input, and the proxy
+ should receive that same input, pass it to the limiter who gives a
+ result, and send the expected result back.
+
+ The HTTPConnection isn't real -- it's monkeypatched to point straight
+ at the WSGIApp. And the limiter isn't real -- it's a fake that
+ behaves the way we tell it to.
+ """
+ self.limiter = FakeLimiter(self)
+ app = ratelimiting.WSGIApp(self.limiter)
+ wire_HTTPConnection_to_WSGI('100.100.100.100:80', app)
+ self.proxy = ratelimiting.WSGIAppProxy('100.100.100.100:80')
+
+ def test_200(self):
+ self.limiter.mock('conquer', 'caesar', None)
+ when = self.proxy.perform('conquer', 'caesar')
+ self.assertEqual(when, None)
+
+ def test_403(self):
+ self.limiter.mock('grumble', 'proletariat', 1.5)
+ when = self.proxy.perform('grumble', 'proletariat')
+ self.assertEqual(when, 1.5)
+
+ def test_failure(self):
+ def shouldRaise():
+ self.limiter.mock('murder', 'brutus', None)
+ self.proxy.perform('stab', 'brutus')
+ self.assertRaises(AssertionError, shouldRaise)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py
index 1815f7523..11efd8aef 100644
--- a/nova/api/rackspace/servers.py
+++ b/nova/api/rackspace/servers.py
@@ -14,67 +14,286 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
+
import time
-from nova import db
+import webob
+from webob import exc
+
from nova import flags
from nova import rpc
from nova import utils
-from nova.api.rackspace import base
+from nova import wsgi
+from nova.api import cloud
+from nova.api.rackspace import _id_translator
+from nova.api.rackspace import context
+from nova.api.rackspace import faults
+from nova.compute import instance_types
+from nova.compute import power_state
+import nova.api.rackspace
+import nova.image.service
FLAGS = flags.FLAGS
-class Controller(base.Controller):
- entity_name = 'servers'
+flags.DEFINE_string('rs_network_manager', 'nova.network.manager.FlatManager',
+ 'Networking for rackspace')
+
+def _instance_id_translator():
+ """ Helper method for initializing an id translator for Rackspace instance
+ ids """
+ return _id_translator.RackspaceAPIIdTranslator( "instance", 'nova')
+
+def _image_service():
+ """ Helper method for initializing the image id translator """
+ service = nova.image.service.ImageService.load()
+ return (service, _id_translator.RackspaceAPIIdTranslator(
+ "image", service.__class__.__name__))
+
+def _filter_params(inst_dict):
+ """ Extracts all updatable parameters for a server update request """
+ keys = dict(name='name', admin_pass='adminPass')
+ new_attrs = {}
+ for k, v in keys.items():
+ if inst_dict.has_key(v):
+ new_attrs[k] = inst_dict[v]
+ return new_attrs
+
+def _entity_list(entities):
+ """ Coerces a list of servers into proper dictionary format """
+ return dict(servers=entities)
+
+def _entity_detail(inst):
+ """ Maps everything to Rackspace-like attributes for return"""
+ power_mapping = {
+ power_state.NOSTATE: 'build',
+ power_state.RUNNING: 'active',
+ power_state.BLOCKED: 'active',
+ power_state.PAUSED: 'suspended',
+ power_state.SHUTDOWN: 'active',
+ power_state.SHUTOFF: 'active',
+ power_state.CRASHED: 'error'
+ }
+ inst_dict = {}
+
+ mapped_keys = dict(status='state', imageId='image_id',
+ flavorId='instance_type', name='server_name', id='id')
+
+ for k, v in mapped_keys.iteritems():
+ inst_dict[k] = inst[v]
+
+ inst_dict['status'] = power_mapping[inst_dict['status']]
+ inst_dict['addresses'] = dict(public=[], private=[])
+ inst_dict['metadata'] = {}
+ inst_dict['hostId'] = ''
+
+ return dict(server=inst_dict)
+
+def _entity_inst(inst):
+ """ Filters all model attributes save for id and name """
+ return dict(server=dict(id=inst['id'], name=inst['server_name']))
- def index(self, **kwargs):
- instances = []
- for inst in db.instance_get_all(None):
- instances.append(instance_details(inst))
+class Controller(wsgi.Controller):
+ """ The Server API controller for the Openstack API """
- def show(self, **kwargs):
- instance_id = kwargs['id']
- return db.instance_get(None, instance_id)
+ _serialization_metadata = {
+ 'application/xml': {
+ "attributes": {
+ "server": [ "id", "imageId", "name", "flavorId", "hostId",
+ "status", "progress", "progress" ]
+ }
+ }
+ }
- def delete(self, **kwargs):
- instance_id = kwargs['id']
- instance = db.instance_get(None, instance_id)
- if not instance:
- raise ServerNotFound("The requested server was not found")
- instance.destroy()
- return True
+ def __init__(self, db_driver=None):
+ if not db_driver:
+ db_driver = FLAGS.db_driver
+ self.db_driver = utils.import_object(db_driver)
+ super(Controller, self).__init__()
+
+ def index(self, req):
+ """ Returns a list of server names and ids for a given user """
+ return self._items(req, entity_maker=_entity_inst)
+
+ def detail(self, req):
+ """ Returns a list of server details for a given user """
+ return self._items(req, entity_maker=_entity_detail)
+
+ def _items(self, req, entity_maker):
+ """Returns a list of servers for a given user.
+
+ entity_maker - either _entity_detail or _entity_inst
+ """
+ user_id = req.environ['nova.context']['user']['id']
+ instance_list = self.db_driver.instance_get_all_by_user(None, user_id)
+ limited_list = nova.api.rackspace.limited(instance_list, req)
+ res = [entity_maker(inst)['server'] for inst in limited_list]
+ return _entity_list(res)
+
+ def show(self, req, id):
+ """ Returns server details by server id """
+ inst_id_trans = _instance_id_translator()
+ inst_id = inst_id_trans.from_rs_id(id)
+
+ user_id = req.environ['nova.context']['user']['id']
+ inst = self.db_driver.instance_get_by_ec2_id(None, inst_id)
+ if inst:
+ if inst.user_id == user_id:
+ return _entity_detail(inst)
+ raise faults.Fault(exc.HTTPNotFound())
+
+ def delete(self, req, id):
+ """ Destroys a server """
+ inst_id_trans = _instance_id_translator()
+ inst_id = inst_id_trans.from_rs_id(id)
+
+ user_id = req.environ['nova.context']['user']['id']
+ instance = self.db_driver.instance_get_by_ec2_id(None, inst_id)
+ if instance and instance['user_id'] == user_id:
+ self.db_driver.instance_destroy(None, id)
+ return faults.Fault(exc.HTTPAccepted())
+ return faults.Fault(exc.HTTPNotFound())
+
+ def create(self, req):
+ """ Creates a new server for a given user """
+
+ env = self._deserialize(req.body, req)
+ if not env:
+ return faults.Fault(exc.HTTPUnprocessableEntity())
+
+ try:
+ inst = self._build_server_instance(req, env)
+ except Exception, e:
+ return faults.Fault(exc.HTTPUnprocessableEntity())
- def create(self, **kwargs):
- inst = self.build_server_instance(kwargs['server'])
rpc.cast(
FLAGS.compute_topic, {
"method": "run_instance",
"args": {"instance_id": inst['id']}})
+ return _entity_inst(inst)
+
+ def update(self, req, id):
+ """ Updates the server name or password """
+ inst_id_trans = _instance_id_translator()
+ inst_id = inst_id_trans.from_rs_id(id)
+ user_id = req.environ['nova.context']['user']['id']
- def update(self, **kwargs):
- instance_id = kwargs['id']
- instance = db.instance_get(None, instance_id)
- if not instance:
- raise ServerNotFound("The requested server was not found")
- instance.update(kwargs['server'])
- instance.save()
+ inst_dict = self._deserialize(req.body, req)
+
+ if not inst_dict:
+ return faults.Fault(exc.HTTPUnprocessableEntity())
- def build_server_instance(self, env):
+ instance = self.db_driver.instance_get_by_ec2_id(None, inst_id)
+ if not instance or instance.user_id != user_id:
+ return faults.Fault(exc.HTTPNotFound())
+
+ self.db_driver.instance_update(None, id,
+ _filter_params(inst_dict['server']))
+ return faults.Fault(exc.HTTPNoContent())
+
+ def action(self, req, id):
+ """ multi-purpose method used to reboot, rebuild, and
+ resize a server """
+ input_dict = self._deserialize(req.body, req)
+ try:
+ reboot_type = input_dict['reboot']['type']
+ except Exception:
+ raise faults.Fault(webob.exc.HTTPNotImplemented())
+ opaque_id = _instance_id_translator().from_rs_id(id)
+ cloud.reboot(opaque_id)
+
+ def _build_server_instance(self, req, env):
"""Build instance data structure and save it to the data store."""
- reservation = utils.generate_uid('r')
ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
inst = {}
- inst['name'] = env['server']['name']
- inst['image_id'] = env['server']['imageId']
- inst['instance_type'] = env['server']['flavorId']
- inst['user_id'] = env['user']['id']
- inst['project_id'] = env['project']['id']
- inst['reservation_id'] = reservation
+
+ inst_id_trans = _instance_id_translator()
+
+ user_id = req.environ['nova.context']['user']['id']
+
+ flavor_id = env['server']['flavorId']
+
+ instance_type, flavor = [(k, v) for k, v in
+ instance_types.INSTANCE_TYPES.iteritems()
+ if v['flavorid'] == flavor_id][0]
+
+ image_id = env['server']['imageId']
+
+ img_service, image_id_trans = _image_service()
+
+ opaque_image_id = image_id_trans.to_rs_id(image_id)
+ image = img_service.show(opaque_image_id)
+
+ if not image:
+ raise Exception, "Image not found"
+
+ inst['server_name'] = env['server']['name']
+ inst['image_id'] = opaque_image_id
+ inst['user_id'] = user_id
inst['launch_time'] = ltime
inst['mac_address'] = utils.generate_mac()
- inst_id = db.instance_create(None, inst)['id']
- address = self.network_manager.allocate_fixed_ip(None, inst_id)
- # key_data, key_name, ami_launch_index
- # TODO(todd): key data or root password
- inst.save()
+ inst['project_id'] = user_id
+
+ inst['state_description'] = 'scheduling'
+ inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel)
+ inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk)
+ inst['reservation_id'] = utils.generate_uid('r')
+
+ inst['display_name'] = env['server']['name']
+ inst['display_description'] = env['server']['name']
+
+ #TODO(dietz) this may be ill advised
+ key_pair_ref = self.db_driver.key_pair_get_all_by_user(
+ None, user_id)[0]
+
+ inst['key_data'] = key_pair_ref['public_key']
+ inst['key_name'] = key_pair_ref['name']
+
+ #TODO(dietz) stolen from ec2 api, see TODO there
+ inst['security_group'] = 'default'
+
+ # Flavor related attributes
+ inst['instance_type'] = instance_type
+ inst['memory_mb'] = flavor['memory_mb']
+ inst['vcpus'] = flavor['vcpus']
+ inst['local_gb'] = flavor['local_gb']
+
+ ref = self.db_driver.instance_create(None, inst)
+ inst['id'] = inst_id_trans.to_rs_id(ref.ec2_id)
+
+ # TODO(dietz): this isn't explicitly necessary, but the networking
+ # calls depend on an object with a project_id property, and therefore
+ # should be cleaned up later
+ api_context = context.APIRequestContext(user_id)
+
+ inst['mac_address'] = utils.generate_mac()
+
+ #TODO(dietz) is this necessary?
+ inst['launch_index'] = 0
+
+ inst['hostname'] = ref.ec2_id
+ self.db_driver.instance_update(None, inst['id'], inst)
+
+ network_manager = utils.import_object(FLAGS.rs_network_manager)
+ address = network_manager.allocate_fixed_ip(api_context,
+ inst['id'])
+
+ # TODO(vish): This probably should be done in the scheduler
+ # network is setup when host is assigned
+ network_topic = self._get_network_topic(user_id)
+ rpc.call(network_topic,
+ {"method": "setup_fixed_ip",
+ "args": {"context": None,
+ "address": address}})
return inst
+
+ def _get_network_topic(self, user_id):
+ """Retrieves the network host for a project"""
+ network_ref = self.db_driver.project_get_network(None,
+ user_id)
+ host = network_ref['host']
+ if not host:
+ host = rpc.call(FLAGS.network_topic,
+ {"method": "set_network_host",
+ "args": {"context": None,
+ "project_id": user_id}})
+ return self.db_driver.queue_get_for(None, FLAGS.network_topic, host)
diff --git a/nova/api/rackspace/sharedipgroups.py b/nova/api/rackspace/sharedipgroups.py
index 986f11434..4d2d0ede1 100644
--- a/nova/api/rackspace/sharedipgroups.py
+++ b/nova/api/rackspace/sharedipgroups.py
@@ -15,4 +15,6 @@
# License for the specific language governing permissions and limitations
# under the License.
-class Controller(object): pass
+from nova import wsgi
+
+class Controller(wsgi.Controller): pass
diff --git a/nova/auth/dbdriver.py b/nova/auth/dbdriver.py
new file mode 100644
index 000000000..09d15018b
--- /dev/null
+++ b/nova/auth/dbdriver.py
@@ -0,0 +1,236 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Auth driver using the DB as its backend.
+"""
+
+import logging
+import sys
+
+from nova import exception
+from nova import db
+
+
+class DbDriver(object):
+ """DB Auth driver
+
+ Defines enter and exit and therefore supports the with/as syntax.
+ """
+
+ def __init__(self):
+ """Imports the LDAP module"""
+ pass
+ db
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ def get_user(self, uid):
+ """Retrieve user by id"""
+ return self._db_user_to_auth_user(db.user_get({}, uid))
+
+ def get_user_from_access_key(self, access):
+ """Retrieve user by access key"""
+ return self._db_user_to_auth_user(db.user_get_by_access_key({}, access))
+
+ def get_project(self, pid):
+ """Retrieve project by id"""
+ return self._db_project_to_auth_projectuser(db.project_get({}, pid))
+
+ def get_users(self):
+ """Retrieve list of users"""
+ return [self._db_user_to_auth_user(user) for user in db.user_get_all({})]
+
+ def get_projects(self, uid=None):
+ """Retrieve list of projects"""
+ if uid:
+ result = db.project_get_by_user({}, uid)
+ else:
+ result = db.project_get_all({})
+ return [self._db_project_to_auth_projectuser(proj) for proj in result]
+
+ def create_user(self, name, access_key, secret_key, is_admin):
+ """Create a user"""
+ values = { 'id' : name,
+ 'access_key' : access_key,
+ 'secret_key' : secret_key,
+ 'is_admin' : is_admin
+ }
+ try:
+ user_ref = db.user_create({}, values)
+ return self._db_user_to_auth_user(user_ref)
+ except exception.Duplicate, e:
+ raise exception.Duplicate('User %s already exists' % name)
+
+ def _db_user_to_auth_user(self, user_ref):
+ return { 'id' : user_ref['id'],
+ 'name' : user_ref['id'],
+ 'access' : user_ref['access_key'],
+ 'secret' : user_ref['secret_key'],
+ 'admin' : user_ref['is_admin'] }
+
+ def _db_project_to_auth_projectuser(self, project_ref):
+ return { 'id' : project_ref['id'],
+ 'name' : project_ref['name'],
+ 'project_manager_id' : project_ref['project_manager'],
+ 'description' : project_ref['description'],
+ 'member_ids' : [member['id'] for member in project_ref['members']] }
+
+ def create_project(self, name, manager_uid,
+ description=None, member_uids=None):
+ """Create a project"""
+ manager = db.user_get({}, manager_uid)
+ if not manager:
+ raise exception.NotFound("Project can't be created because "
+ "manager %s doesn't exist" % manager_uid)
+
+ # description is a required attribute
+ if description is None:
+ description = name
+
+ # First, we ensure that all the given users exist before we go
+ # on to create the project. This way we won't have to destroy
+ # the project again because a user turns out to be invalid.
+ members = set([manager])
+ if member_uids != None:
+ for member_uid in member_uids:
+ member = db.user_get({}, member_uid)
+ if not member:
+ raise exception.NotFound("Project can't be created "
+ "because user %s doesn't exist"
+ % member_uid)
+ members.add(member)
+
+ values = { 'id' : name,
+ 'name' : name,
+ 'project_manager' : manager['id'],
+ 'description': description }
+
+ try:
+ project = db.project_create({}, values)
+ except exception.Duplicate:
+ raise exception.Duplicate("Project can't be created because "
+ "project %s already exists" % name)
+
+ for member in members:
+ db.project_add_member({}, project['id'], member['id'])
+
+ # This looks silly, but ensures that the members element has been
+ # correctly populated
+ project_ref = db.project_get({}, project['id'])
+ return self._db_project_to_auth_projectuser(project_ref)
+
+ def modify_project(self, project_id, manager_uid=None, description=None):
+ """Modify an existing project"""
+ if not manager_uid and not description:
+ return
+ values = {}
+ if manager_uid:
+ manager = db.user_get({}, manager_uid)
+ if not manager:
+ raise exception.NotFound("Project can't be modified because "
+ "manager %s doesn't exist" %
+ manager_uid)
+ values['project_manager'] = manager['id']
+ if description:
+ values['description'] = description
+
+ db.project_update({}, project_id, values)
+
+ def add_to_project(self, uid, project_id):
+ """Add user to project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ db.project_add_member({}, project['id'], user['id'])
+
+ def remove_from_project(self, uid, project_id):
+ """Remove user from project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ db.project_remove_member({}, project['id'], user['id'])
+
+ def is_in_project(self, uid, project_id):
+ """Check if user is in project"""
+ user, project = self._validate_user_and_project(uid, project_id)
+ return user in project.members
+
+ def has_role(self, uid, role, project_id=None):
+ """Check if user has role
+
+ If project is specified, it checks for local role, otherwise it
+ checks for global role
+ """
+
+ return role in self.get_user_roles(uid, project_id)
+
+ def add_role(self, uid, role, project_id=None):
+ """Add role for user (or user and project)"""
+ if not project_id:
+ db.user_add_role({}, uid, role)
+ return
+ db.user_add_project_role({}, uid, project_id, role)
+
+ def remove_role(self, uid, role, project_id=None):
+ """Remove role for user (or user and project)"""
+ if not project_id:
+ db.user_remove_role({}, uid, role)
+ return
+ db.user_remove_project_role({}, uid, project_id, role)
+
+ def get_user_roles(self, uid, project_id=None):
+ """Retrieve list of roles for user (or user and project)"""
+ if project_id is None:
+ roles = db.user_get_roles({}, uid)
+ return roles
+ else:
+ roles = db.user_get_roles_for_project({}, uid, project_id)
+ return roles
+
+ def delete_user(self, id):
+ """Delete a user"""
+ user = db.user_get({}, id)
+ db.user_delete({}, user['id'])
+
+ def delete_project(self, project_id):
+ """Delete a project"""
+ db.project_delete({}, project_id)
+
+ def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
+ """Modify an existing user"""
+ if not access_key and not secret_key and admin is None:
+ return
+ values = {}
+ if access_key:
+ values['access_key'] = access_key
+ if secret_key:
+ values['secret_key'] = secret_key
+ if admin is not None:
+ values['is_admin'] = admin
+ db.user_update({}, uid, values)
+
+ def _validate_user_and_project(self, user_id, project_id):
+ user = db.user_get({}, user_id)
+ if not user:
+ raise exception.NotFound('User "%s" not found' % user_id)
+ project = db.project_get({}, project_id)
+ if not project:
+ raise exception.NotFound('Project "%s" not found' % project_id)
+ return user, project
+
diff --git a/nova/auth/fakeldap.py b/nova/auth/fakeldap.py
index bfc3433c5..2791dfde6 100644
--- a/nova/auth/fakeldap.py
+++ b/nova/auth/fakeldap.py
@@ -33,6 +33,7 @@ SCOPE_ONELEVEL = 1 # not implemented
SCOPE_SUBTREE = 2
MOD_ADD = 0
MOD_DELETE = 1
+MOD_REPLACE = 2
class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
@@ -175,7 +176,7 @@ class FakeLDAP(object):
Args:
dn -- a dn
attrs -- a list of tuples in the following form:
- ([MOD_ADD | MOD_DELETE], attribute, value)
+ ([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
"""
redis = datastore.Redis.instance()
@@ -185,6 +186,8 @@ class FakeLDAP(object):
values = _from_json(redis.hget(key, k))
if cmd == MOD_ADD:
values.append(v)
+ elif cmd == MOD_REPLACE:
+ values = [v]
else:
values.remove(v)
values = redis.hset(key, k, _to_json(values))
diff --git a/nova/auth/ldapdriver.py b/nova/auth/ldapdriver.py
index 74ba011b5..640ea169e 100644
--- a/nova/auth/ldapdriver.py
+++ b/nova/auth/ldapdriver.py
@@ -99,13 +99,6 @@ class LdapDriver(object):
dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query))
- def get_key_pair(self, uid, key_name):
- """Retrieve key pair by uid and key name"""
- dn = 'cn=%s,%s' % (key_name,
- self.__uid_to_dn(uid))
- attr = self.__find_object(dn, '(objectclass=novaKeyPair)')
- return self.__to_key_pair(uid, attr)
-
def get_project(self, pid):
"""Retrieve project by id"""
dn = 'cn=%s,%s' % (pid,
@@ -119,12 +112,6 @@ class LdapDriver(object):
'(objectclass=novaUser)')
return [self.__to_user(attr) for attr in attrs]
- def get_key_pairs(self, uid):
- """Retrieve list of key pairs"""
- attrs = self.__find_objects(self.__uid_to_dn(uid),
- '(objectclass=novaKeyPair)')
- return [self.__to_key_pair(uid, attr) for attr in attrs]
-
def get_projects(self, uid=None):
"""Retrieve list of projects"""
pattern = '(objectclass=novaProject)'
@@ -154,21 +141,6 @@ class LdapDriver(object):
self.conn.add_s(self.__uid_to_dn(name), attr)
return self.__to_user(dict(attr))
- def create_key_pair(self, uid, key_name, public_key, fingerprint):
- """Create a key pair"""
- # TODO(vish): possibly refactor this to store keys in their own ou
- # and put dn reference in the user object
- attr = [
- ('objectclass', ['novaKeyPair']),
- ('cn', [key_name]),
- ('sshPublicKey', [public_key]),
- ('keyFingerprint', [fingerprint]),
- ]
- self.conn.add_s('cn=%s,%s' % (key_name,
- self.__uid_to_dn(uid)),
- attr)
- return self.__to_key_pair(uid, dict(attr))
-
def create_project(self, name, manager_uid,
description=None, member_uids=None):
"""Create a project"""
@@ -202,6 +174,24 @@ class LdapDriver(object):
self.conn.add_s('cn=%s,%s' % (name, FLAGS.ldap_project_subtree), attr)
return self.__to_project(dict(attr))
+ def modify_project(self, project_id, manager_uid=None, description=None):
+ """Modify an existing project"""
+ if not manager_uid and not description:
+ return
+ attr = []
+ if manager_uid:
+ if not self.__user_exists(manager_uid):
+ raise exception.NotFound("Project can't be modified because "
+ "manager %s doesn't exist" %
+ manager_uid)
+ manager_dn = self.__uid_to_dn(manager_uid)
+ attr.append((self.ldap.MOD_REPLACE, 'projectManager', manager_dn))
+ if description:
+ attr.append((self.ldap.MOD_REPLACE, 'description', description))
+ self.conn.modify_s('cn=%s,%s' % (project_id,
+ FLAGS.ldap_project_subtree),
+ attr)
+
def add_to_project(self, uid, project_id):
"""Add user to project"""
dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
@@ -265,18 +255,8 @@ class LdapDriver(object):
"""Delete a user"""
if not self.__user_exists(uid):
raise exception.NotFound("User %s doesn't exist" % uid)
- self.__delete_key_pairs(uid)
self.__remove_from_all(uid)
- self.conn.delete_s('uid=%s,%s' % (uid,
- FLAGS.ldap_user_subtree))
-
- def delete_key_pair(self, uid, key_name):
- """Delete a key pair"""
- if not self.__key_pair_exists(uid, key_name):
- raise exception.NotFound("Key Pair %s doesn't exist for user %s" %
- (key_name, uid))
- self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid,
- FLAGS.ldap_user_subtree))
+ self.conn.delete_s(self.__uid_to_dn(uid))
def delete_project(self, project_id):
"""Delete a project"""
@@ -284,14 +264,23 @@ class LdapDriver(object):
self.__delete_roles(project_dn)
self.__delete_group(project_dn)
+ def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
+ """Modify an existing project"""
+ if not access_key and not secret_key and admin is None:
+ return
+ attr = []
+ if access_key:
+ attr.append((self.ldap.MOD_REPLACE, 'accessKey', access_key))
+ if secret_key:
+ attr.append((self.ldap.MOD_REPLACE, 'secretKey', secret_key))
+ if admin is not None:
+ attr.append((self.ldap.MOD_REPLACE, 'isAdmin', str(admin).upper()))
+ self.conn.modify_s(self.__uid_to_dn(uid), attr)
+
def __user_exists(self, uid):
"""Check if user exists"""
return self.get_user(uid) != None
- def __key_pair_exists(self, uid, key_name):
- """Check if key pair exists"""
- return self.get_key_pair(uid, key_name) != None
-
def __project_exists(self, project_id):
"""Check if project exists"""
return self.get_project(project_id) != None
@@ -341,13 +330,6 @@ class LdapDriver(object):
"""Check if group exists"""
return self.__find_object(dn, '(objectclass=groupOfNames)') != None
- def __delete_key_pairs(self, uid):
- """Delete all key pairs for user"""
- keys = self.get_key_pairs(uid)
- if keys != None:
- for key in keys:
- self.delete_key_pair(uid, key['name'])
-
@staticmethod
def __role_to_dn(role, project_id=None):
"""Convert role to corresponding dn"""
@@ -472,18 +454,6 @@ class LdapDriver(object):
'secret': attr['secretKey'][0],
'admin': (attr['isAdmin'][0] == 'TRUE')}
- @staticmethod
- def __to_key_pair(owner, attr):
- """Convert ldap attributes to KeyPair object"""
- if attr == None:
- return None
- return {
- 'id': attr['cn'][0],
- 'name': attr['cn'][0],
- 'owner_id': owner,
- 'public_key': attr['sshPublicKey'][0],
- 'fingerprint': attr['keyFingerprint'][0]}
-
def __to_project(self, attr):
"""Convert ldap attributes to Project object"""
if attr == None:
diff --git a/nova/auth/manager.py b/nova/auth/manager.py
index d5fbec7c5..ce8a294df 100644
--- a/nova/auth/manager.py
+++ b/nova/auth/manager.py
@@ -44,7 +44,7 @@ flags.DEFINE_list('allowed_roles',
# NOTE(vish): a user with one of these roles will be a superuser and
# have access to all api commands
flags.DEFINE_list('superuser_roles', ['cloudadmin'],
- 'Roles that ignore rbac checking completely')
+ 'Roles that ignore authorization checking completely')
# NOTE(vish): a user with one of these roles will have it for every
# project, even if he or she is not a member of the project
@@ -69,7 +69,7 @@ flags.DEFINE_string('credential_cert_subject',
'/C=US/ST=California/L=MountainView/O=AnsoLabs/'
'OU=NovaDev/CN=%s-%s',
'Subject for certificate for users')
-flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver',
+flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver',
'Driver that auth manager uses')
@@ -128,24 +128,6 @@ class User(AuthBase):
def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project)
- def generate_key_pair(self, name):
- return AuthManager().generate_key_pair(self.id, name)
-
- def create_key_pair(self, name, public_key, fingerprint):
- return AuthManager().create_key_pair(self.id,
- name,
- public_key,
- fingerprint)
-
- def get_key_pair(self, name):
- return AuthManager().get_key_pair(self.id, name)
-
- def delete_key_pair(self, name):
- return AuthManager().delete_key_pair(self.id, name)
-
- def get_key_pairs(self):
- return AuthManager().get_key_pairs(self.id)
-
def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
@@ -154,29 +136,6 @@ class User(AuthBase):
self.admin)
-class KeyPair(AuthBase):
- """Represents an ssh key returned from the datastore
-
- Even though this object is named KeyPair, only the public key and
- fingerprint is stored. The user's private key is not saved.
- """
-
- def __init__(self, id, name, owner_id, public_key, fingerprint):
- AuthBase.__init__(self)
- self.id = id
- self.name = name
- self.owner_id = owner_id
- self.public_key = public_key
- self.fingerprint = fingerprint
-
- def __repr__(self):
- return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id,
- self.name,
- self.owner_id,
- self.public_key,
- self.fingerprint)
-
-
class Project(AuthBase):
"""Represents a Project returned from the datastore"""
@@ -307,7 +266,7 @@ class AuthManager(object):
# NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user
- if project_id is '':
+ if project_id == '':
project_id = user.name
project = self.get_project(project_id)
@@ -345,7 +304,7 @@ class AuthManager(object):
return "%s:%s" % (user.access, Project.safe_id(project))
def is_superuser(self, user):
- """Checks for superuser status, allowing user to bypass rbac
+ """Checks for superuser status, allowing user to bypass authorization
@type user: User or uid
@param user: User to check.
@@ -533,6 +492,26 @@ class AuthManager(object):
raise
return project
+ def modify_project(self, project, manager_user=None, description=None):
+ """Modify a project
+
+ @type name: Project or project_id
+ @param project: The project to modify.
+
+ @type manager_user: User or uid
+ @param manager_user: This user will be the new project manager.
+
+ @type description: str
+ @param project: This will be the new description of the project.
+
+ """
+ if manager_user:
+ manager_user = User.safe_id(manager_user)
+ with self.driver() as drv:
+ drv.modify_project(Project.safe_id(project),
+ manager_user,
+ description)
+
def add_to_project(self, user, project):
"""Add user to project"""
with self.driver() as drv:
@@ -643,67 +622,19 @@ class AuthManager(object):
return User(**user_dict)
def delete_user(self, user):
- """Deletes a user"""
- with self.driver() as drv:
- drv.delete_user(User.safe_id(user))
-
- def generate_key_pair(self, user, key_name):
- """Generates a key pair for a user
-
- Generates a public and private key, stores the public key using the
- key_name, and returns the private key and fingerprint.
-
- @type user: User or uid
- @param user: User for which to create key pair.
-
- @type key_name: str
- @param key_name: Name to use for the generated KeyPair.
+ """Deletes a user
- @rtype: tuple (private_key, fingerprint)
- @return: A tuple containing the private_key and fingerprint.
- """
- # NOTE(vish): generating key pair is slow so check for legal
- # creation before creating keypair
+ Additionally deletes all users key_pairs"""
uid = User.safe_id(user)
+ db.key_pair_destroy_all_by_user(None, uid)
with self.driver() as drv:
- if not drv.get_user(uid):
- raise exception.NotFound("User %s doesn't exist" % user)
- if drv.get_key_pair(uid, key_name):
- raise exception.Duplicate("The keypair %s already exists"
- % key_name)
- private_key, public_key, fingerprint = crypto.generate_key_pair()
- self.create_key_pair(uid, key_name, public_key, fingerprint)
- return private_key, fingerprint
-
- def create_key_pair(self, user, key_name, public_key, fingerprint):
- """Creates a key pair for user"""
- with self.driver() as drv:
- kp_dict = drv.create_key_pair(User.safe_id(user),
- key_name,
- public_key,
- fingerprint)
- if kp_dict:
- return KeyPair(**kp_dict)
-
- def get_key_pair(self, user, key_name):
- """Retrieves a key pair for user"""
- with self.driver() as drv:
- kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
- if kp_dict:
- return KeyPair(**kp_dict)
+ drv.delete_user(uid)
- def get_key_pairs(self, user):
- """Retrieves all key pairs for user"""
- with self.driver() as drv:
- kp_list = drv.get_key_pairs(User.safe_id(user))
- if not kp_list:
- return []
- return [KeyPair(**kp_dict) for kp_dict in kp_list]
-
- def delete_key_pair(self, user, key_name):
- """Deletes a key pair for user"""
+ def modify_user(self, user, access_key=None, secret_key=None, admin=None):
+ """Modify credentials for a user"""
+ uid = User.safe_id(user)
with self.driver() as drv:
- drv.delete_key_pair(User.safe_id(user), key_name)
+ drv.modify_user(uid, access_key, secret_key, admin)
def get_credentials(self, user, project=None):
"""Get credential zip for user in project"""
diff --git a/nova/auth/rbac.py b/nova/auth/rbac.py
deleted file mode 100644
index d157f44b3..000000000
--- a/nova/auth/rbac.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""Role-based access control decorators to use fpr wrapping other
-methods with."""
-
-from nova import exception
-
-
-def allow(*roles):
- """Allow the given roles access the wrapped function."""
-
- def wrap(func): # pylint: disable-msg=C0111
-
- def wrapped_func(self, context, *args,
- **kwargs): # pylint: disable-msg=C0111
- if context.user.is_superuser():
- return func(self, context, *args, **kwargs)
- for role in roles:
- if __matches_role(context, role):
- return func(self, context, *args, **kwargs)
- raise exception.NotAuthorized()
-
- return wrapped_func
-
- return wrap
-
-
-def deny(*roles):
- """Deny the given roles access the wrapped function."""
-
- def wrap(func): # pylint: disable-msg=C0111
-
- def wrapped_func(self, context, *args,
- **kwargs): # pylint: disable-msg=C0111
- if context.user.is_superuser():
- return func(self, context, *args, **kwargs)
- for role in roles:
- if __matches_role(context, role):
- raise exception.NotAuthorized()
- return func(self, context, *args, **kwargs)
-
- return wrapped_func
-
- return wrap
-
-
-def __matches_role(context, role):
- """Check if a role is allowed."""
- if role == 'all':
- return True
- if role == 'none':
- return False
- return context.project.has_role(context.user.id, role)
diff --git a/nova/cloudpipe/api.py b/nova/cloudpipe/api.py
deleted file mode 100644
index 56aa89834..000000000
--- a/nova/cloudpipe/api.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Tornado REST API Request Handlers for CloudPipe
-"""
-
-import logging
-import urllib
-
-import tornado.web
-
-from nova import crypto
-from nova.auth import manager
-
-
-_log = logging.getLogger("api")
-_log.setLevel(logging.DEBUG)
-
-
-class CloudPipeRequestHandler(tornado.web.RequestHandler):
- def get(self, path):
- path = self.request.path
- _log.debug( "Cloudpipe path is %s" % path)
- if path.endswith("/getca/"):
- self.send_root_ca()
- self.finish()
-
- def get_project_id_from_ip(self, ip):
- cc = self.application.controllers['Cloud']
- instance = cc.get_instance_by_ip(ip)
- instance['project_id']
-
- def send_root_ca(self):
- _log.debug( "Getting root ca")
- project_id = self.get_project_id_from_ip(self.request.remote_ip)
- self.set_header("Content-Type", "text/plain")
- self.write(crypto.fetch_ca(project_id))
-
- def post(self, *args, **kwargs):
- project_id = self.get_project_id_from_ip(self.request.remote_ip)
- cert = self.get_argument('cert', '')
- self.write(crypto.sign_csr(urllib.unquote(cert), project_id))
- self.finish()
diff --git a/nova/cloudpipe/pipelib.py b/nova/cloudpipe/pipelib.py
index 2867bcb21..706a175d9 100644
--- a/nova/cloudpipe/pipelib.py
+++ b/nova/cloudpipe/pipelib.py
@@ -32,7 +32,9 @@ from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
-from nova.endpoint import api
+# TODO(eday): Eventually changes these to something not ec2-specific
+from nova.api.ec2 import cloud
+from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -42,8 +44,8 @@ flags.DEFINE_string('boot_script_template',
class CloudPipe(object):
- def __init__(self, cloud_controller):
- self.controller = cloud_controller
+ def __init__(self):
+ self.controller = cloud.CloudController()
self.manager = manager.AuthManager()
def launch_vpn_instance(self, project_id):
@@ -58,9 +60,9 @@ class CloudPipe(object):
z.write(FLAGS.boot_script_template,'autorun.sh')
z.close()
- key_name = self.setup_keypair(project.project_manager_id, project_id)
+ key_name = self.setup_key_pair(project.project_manager_id, project_id)
zippy = open(zippath, "r")
- context = api.APIRequestContext(handler=None, user=project.project_manager, project=project)
+ context = context.APIRequestContext(user=project.project_manager, project=project)
reservation = self.controller.run_instances(context,
# run instances expects encoded userdata, it is decoded in the get_metadata_call
@@ -74,7 +76,7 @@ class CloudPipe(object):
security_groups=["vpn-secgroup"])
zippy.close()
- def setup_keypair(self, user_id, project_id):
+ def setup_key_pair(self, user_id, project_id):
key_name = '%s%s' % (project_id, FLAGS.vpn_key_suffix)
try:
private_key, fingerprint = self.manager.generate_key_pair(user_id, key_name)
diff --git a/nova/compute/manager.py b/nova/compute/manager.py
index cb6434694..8b98a2b0e 100644
--- a/nova/compute/manager.py
+++ b/nova/compute/manager.py
@@ -67,7 +67,7 @@ class ComputeManager(manager.Manager):
def run_instance(self, context, instance_id, **_kwargs):
"""Launch a new instance with specified options."""
instance_ref = self.db.instance_get(context, instance_id)
- if instance_ref['str_id'] in self.driver.list_instances():
+ if instance_ref['ec2_id'] in self.driver.list_instances():
raise exception.Error("Instance has already been created")
logging.debug("instance %s: starting...", instance_id)
project_id = instance_ref['project_id']
@@ -129,7 +129,7 @@ class ComputeManager(manager.Manager):
raise exception.Error(
'trying to reboot a non-running'
'instance: %s (state: %s excepted: %s)' %
- (instance_ref['str_id'],
+ (instance_ref['ec2_id'],
instance_ref['state'],
power_state.RUNNING))
@@ -158,7 +158,7 @@ class ComputeManager(manager.Manager):
instance_ref = self.db.instance_get(context, instance_id)
dev_path = yield self.volume_manager.setup_compute_volume(context,
volume_id)
- yield self.driver.attach_volume(instance_ref['str_id'],
+ yield self.driver.attach_volume(instance_ref['ec2_id'],
dev_path,
mountpoint)
self.db.volume_attached(context, volume_id, instance_id, mountpoint)
@@ -173,7 +173,7 @@ class ComputeManager(manager.Manager):
volume_id)
instance_ref = self.db.instance_get(context, instance_id)
volume_ref = self.db.volume_get(context, volume_id)
- self.driver.detach_volume(instance_ref['str_id'],
- volume_ref['mountpoint'])
+ yield self.driver.detach_volume(instance_ref['ec2_id'],
+ volume_ref['mountpoint'])
self.db.volume_detached(context, volume_id)
defer.returnValue(True)
diff --git a/nova/crypto.py b/nova/crypto.py
index b05548ea1..1c6fe57ad 100644
--- a/nova/crypto.py
+++ b/nova/crypto.py
@@ -18,7 +18,7 @@
"""
Wrappers around standard crypto, including root and intermediate CAs,
-SSH keypairs and x509 certificates.
+SSH key_pairs and x509 certificates.
"""
import base64
diff --git a/nova/db/api.py b/nova/db/api.py
index 9f6ff99c3..a6d1f405a 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -161,20 +161,20 @@ def floating_ip_get_all(context):
def floating_ip_get_all_by_host(context, host):
- """Get all floating ips."""
+ """Get all floating ips by host."""
return IMPL.floating_ip_get_all_by_host(context, host)
+def floating_ip_get_all_by_project(context, project_id):
+ """Get all floating ips by project."""
+ return IMPL.floating_ip_get_all_by_project(context, project_id)
+
+
def floating_ip_get_by_address(context, address):
"""Get a floating ip by address or raise if it doesn't exist."""
return IMPL.floating_ip_get_by_address(context, address)
-def floating_ip_get_instance(context, address):
- """Get an instance for a floating ip by address."""
- return IMPL.floating_ip_get_instance(context, address)
-
-
####################
@@ -204,6 +204,11 @@ def fixed_ip_disassociate(context, address):
return IMPL.fixed_ip_disassociate(context, address)
+def fixed_ip_disassociate_all_by_timeout(context, host, time):
+ """Disassociate old fixed ips from host"""
+ return IMPL.fixed_ip_disassociate_all_by_timeout(context, host, time)
+
+
def fixed_ip_get_by_address(context, address):
"""Get a fixed ip by address or raise if it does not exist."""
return IMPL.fixed_ip_get_by_address(context, address)
@@ -251,15 +256,18 @@ def instance_get_all(context):
"""Get all instances."""
return IMPL.instance_get_all(context)
+def instance_get_all_by_user(context, user_id):
+ """Get all instances."""
+ return IMPL.instance_get_all(context, user_id)
-def instance_get_by_project(context, project_id):
+def instance_get_all_by_project(context, project_id):
"""Get all instance belonging to a project."""
- return IMPL.instance_get_by_project(context, project_id)
+ return IMPL.instance_get_all_by_project(context, project_id)
-def instance_get_by_reservation(context, reservation_id):
+def instance_get_all_by_reservation(context, reservation_id):
"""Get all instance belonging to a reservation."""
- return IMPL.instance_get_by_reservation(context, reservation_id)
+ return IMPL.instance_get_all_by_reservation(context, reservation_id)
def instance_get_fixed_address(context, instance_id):
@@ -272,9 +280,9 @@ def instance_get_floating_address(context, instance_id):
return IMPL.instance_get_floating_address(context, instance_id)
-def instance_get_by_str(context, str_id):
- """Get an instance by string id."""
- return IMPL.instance_get_by_str(context, str_id)
+def instance_get_by_ec2_id(context, ec2_id):
+ """Get an instance by ec2 id."""
+ return IMPL.instance_get_by_ec2_id(context, ec2_id)
def instance_is_vpn(context, instance_id):
@@ -296,6 +304,34 @@ def instance_update(context, instance_id, values):
return IMPL.instance_update(context, instance_id, values)
+###################
+
+
+def key_pair_create(context, values):
+ """Create a key_pair from the values dictionary."""
+ return IMPL.key_pair_create(context, values)
+
+
+def key_pair_destroy(context, user_id, name):
+ """Destroy the key_pair or raise if it does not exist."""
+ return IMPL.key_pair_destroy(context, user_id, name)
+
+
+def key_pair_destroy_all_by_user(context, user_id):
+ """Destroy all key_pairs by user."""
+ return IMPL.key_pair_destroy_all_by_user(context, user_id)
+
+
+def key_pair_get(context, user_id, name):
+ """Get a key_pair or raise if it does not exist."""
+ return IMPL.key_pair_get(context, user_id, name)
+
+
+def key_pair_get_all_by_user(context, user_id):
+ """Get all key_pairs by user."""
+ return IMPL.key_pair_get_all_by_user(context, user_id)
+
+
####################
@@ -365,9 +401,12 @@ def network_index_count(context):
return IMPL.network_index_count(context)
-def network_index_create(context, values):
- """Create a network index from the values dict"""
- return IMPL.network_index_create(context, values)
+def network_index_create_safe(context, values):
+ """Create a network index from the values dict
+
+ The index is not returned. If the create violates the unique
+ constraints because the index already exists, no exception is raised."""
+ return IMPL.network_index_create_safe(context, values)
def network_set_cidr(context, network_id, cidr):
@@ -420,6 +459,21 @@ def export_device_create(context, values):
###################
+def auth_destroy_token(context, token):
+ """Destroy an auth token"""
+ return IMPL.auth_destroy_token(context, token)
+
+def auth_get_token(context, token_hash):
+ """Retrieves a token given the hash representing it"""
+ return IMPL.auth_get_token(context, token_hash)
+
+def auth_create_token(context, token):
+ """Creates a new token"""
+ return IMPL.auth_create_token(context, token_hash, token)
+
+
+###################
+
def quota_create(context, values):
"""Create a quota from the values dictionary."""
@@ -489,14 +543,14 @@ def volume_get_instance(context, volume_id):
return IMPL.volume_get_instance(context, volume_id)
-def volume_get_by_project(context, project_id):
+def volume_get_all_by_project(context, project_id):
"""Get all volumes belonging to a project."""
- return IMPL.volume_get_by_project(context, project_id)
+ return IMPL.volume_get_all_by_project(context, project_id)
-def volume_get_by_str(context, str_id):
- """Get a volume by string id."""
- return IMPL.volume_get_by_str(context, str_id)
+def volume_get_by_ec2_id(context, ec2_id):
+ """Get a volume by ec2 id."""
+ return IMPL.volume_get_by_ec2_id(context, ec2_id)
def volume_get_shelf_and_blade(context, volume_id):
@@ -511,3 +565,122 @@ def volume_update(context, volume_id, values):
"""
return IMPL.volume_update(context, volume_id, values)
+
+
+###################
+
+
+def user_get(context, id):
+ """Get user by id"""
+ return IMPL.user_get(context, id)
+
+
+def user_get_by_uid(context, uid):
+ """Get user by uid"""
+ return IMPL.user_get_by_uid(context, uid)
+
+
+def user_get_by_access_key(context, access_key):
+ """Get user by access key"""
+ return IMPL.user_get_by_access_key(context, access_key)
+
+
+def user_create(context, values):
+ """Create a new user"""
+ return IMPL.user_create(context, values)
+
+
+def user_delete(context, id):
+ """Delete a user"""
+ return IMPL.user_delete(context, id)
+
+
+def user_get_all(context):
+ """Create a new user"""
+ return IMPL.user_get_all(context)
+
+
+def user_add_role(context, user_id, role):
+ """Add another global role for user"""
+ return IMPL.user_add_role(context, user_id, role)
+
+
+def user_remove_role(context, user_id, role):
+ """Remove global role from user"""
+ return IMPL.user_remove_role(context, user_id, role)
+
+
+def user_get_roles(context, user_id):
+ """Get global roles for user"""
+ return IMPL.user_get_roles(context, user_id)
+
+
+def user_add_project_role(context, user_id, project_id, role):
+ """Add project role for user"""
+ return IMPL.user_add_project_role(context, user_id, project_id, role)
+
+
+def user_remove_project_role(context, user_id, project_id, role):
+ """Remove project role from user"""
+ return IMPL.user_remove_project_role(context, user_id, project_id, role)
+
+
+def user_get_roles_for_project(context, user_id, project_id):
+ """Return list of roles a user holds on project"""
+ return IMPL.user_get_roles_for_project(context, user_id, project_id)
+
+
+def user_update(context, user_id, values):
+ """Update user"""
+ return IMPL.user_update(context, user_id, values)
+
+
+def project_get(context, id):
+ """Get project by id"""
+ return IMPL.project_get(context, id)
+
+
+def project_create(context, values):
+ """Create a new project"""
+ return IMPL.project_create(context, values)
+
+
+def project_add_member(context, project_id, user_id):
+ """Add user to project"""
+ return IMPL.project_add_member(context, project_id, user_id)
+
+
+def project_get_all(context):
+ """Get all projects"""
+ return IMPL.project_get_all(context)
+
+
+def project_get_by_user(context, user_id):
+ """Get all projects of which the given user is a member"""
+ return IMPL.project_get_by_user(context, user_id)
+
+
+def project_remove_member(context, project_id, user_id):
+ """Remove the given user from the given project"""
+ return IMPL.project_remove_member(context, project_id, user_id)
+
+
+def project_update(context, project_id, values):
+ """Update Remove the given user from the given project"""
+ return IMPL.project_update(context, project_id, values)
+
+
+def project_delete(context, project_id):
+ """Delete project"""
+ return IMPL.project_delete(context, project_id)
+
+
+###################
+
+
+def host_get_networks(context, host):
+ """Return all networks for which the given host is the designated
+ network host
+ """
+ return IMPL.host_get_networks(context, host)
+
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index d612fe669..e0c6a34b8 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -19,62 +19,140 @@
Implementation of SQLAlchemy backend
"""
+import warnings
+
from nova import db
from nova import exception
from nova import flags
+from nova import utils
from nova.db.sqlalchemy import models
from nova.db.sqlalchemy.session import get_session
from sqlalchemy import or_
-from sqlalchemy.orm import joinedload_all
-from sqlalchemy.sql import func
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import joinedload, joinedload_all
+from sqlalchemy.sql import exists, func
FLAGS = flags.FLAGS
-# NOTE(vish): disabling docstring pylint because the docstrings are
-# in the interface definition
-# pylint: disable-msg=C0111
-def _deleted(context):
- """Calculates whether to include deleted objects based on context.
+def is_admin_context(context):
+ """Indicates if the request context is an administrator."""
+ if not context:
+ warnings.warn('Use of empty request context is deprecated',
+ DeprecationWarning)
+ return True
+ return context.is_admin
+
+
+def is_user_context(context):
+ """Indicates if the request context is a normal user."""
+ if not context:
+ return False
+ if not context.user or not context.project:
+ return False
+ return True
+
+
+def authorize_project_context(context, project_id):
+ """Ensures that the request context has permission to access the
+ given project.
+ """
+ if is_user_context(context):
+ if not context.project:
+ raise exception.NotAuthorized()
+ elif context.project.id != project_id:
+ raise exception.NotAuthorized()
+
- Currently just looks for a flag called deleted in the context dict.
+def authorize_user_context(context, user_id):
+ """Ensures that the request context has permission to access the
+ given user.
"""
- if not hasattr(context, 'get'):
+ if is_user_context(context):
+ if not context.user:
+ raise exception.NotAuthorized()
+ elif context.user.id != user_id:
+ raise exception.NotAuthorized()
+
+
+def can_read_deleted(context):
+ """Indicates if the context has access to deleted objects."""
+ if not context:
return False
- return context.get('deleted', False)
+ return context.read_deleted
-###################
+def require_admin_context(f):
+ """Decorator used to indicate that the method requires an
+ administrator context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+
+
+def require_context(f):
+ """Decorator used to indicate that the method requires either
+ an administrator or normal user context.
+ """
+ def wrapper(*args, **kwargs):
+ if not is_admin_context(args[0]) and not is_user_context(args[0]):
+ raise exception.NotAuthorized()
+ return f(*args, **kwargs)
+ return wrapper
+###################
+
+@require_admin_context
def service_destroy(context, service_id):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = service_get(context, service_id, session=session)
service_ref.delete(session=session)
-def service_get(_context, service_id):
- return models.Service.find(service_id)
+@require_admin_context
+def service_get(context, service_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Service
+ ).filter_by(id=service_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No service for id %s' % service_id)
+
+ return result
+
+@require_admin_context
def service_get_all_by_topic(context, topic):
session = get_session()
return session.query(models.Service
).filter_by(deleted=False
+ ).filter_by(disabled=False
).filter_by(topic=topic
).all()
-def _service_get_all_topic_subquery(_context, session, topic, subq, label):
+@require_admin_context
+def _service_get_all_topic_subquery(context, session, topic, subq, label):
sort_value = getattr(subq.c, label)
return session.query(models.Service, func.coalesce(sort_value, 0)
).filter_by(topic=topic
).filter_by(deleted=False
+ ).filter_by(disabled=False
).outerjoin((subq, models.Service.host == subq.c.host)
).order_by(sort_value
).all()
+@require_admin_context
def service_get_all_compute_sorted(context):
session = get_session()
with session.begin():
@@ -99,6 +177,7 @@ def service_get_all_compute_sorted(context):
label)
+@require_admin_context
def service_get_all_network_sorted(context):
session = get_session()
with session.begin():
@@ -116,6 +195,7 @@ def service_get_all_network_sorted(context):
label)
+@require_admin_context
def service_get_all_volume_sorted(context):
session = get_session()
with session.begin():
@@ -133,11 +213,22 @@ def service_get_all_volume_sorted(context):
label)
-def service_get_by_args(_context, host, binary):
- return models.Service.find_by_args(host, binary)
+@require_admin_context
+def service_get_by_args(context, host, binary):
+ session = get_session()
+ result = session.query(models.Service
+ ).filter_by(host=host
+ ).filter_by(binary=binary
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No service for %s, %s' % (host, binary))
+
+ return result
-def service_create(_context, values):
+@require_admin_context
+def service_create(context, values):
service_ref = models.Service()
for (key, value) in values.iteritems():
service_ref[key] = value
@@ -145,10 +236,11 @@ def service_create(_context, values):
return service_ref
-def service_update(_context, service_id, values):
+@require_admin_context
+def service_update(context, service_id, values):
session = get_session()
with session.begin():
- service_ref = models.Service.find(service_id, session=session)
+ service_ref = service_get(context, service_id, session=session)
for (key, value) in values.iteritems():
service_ref[key] = value
service_ref.save(session=session)
@@ -157,12 +249,15 @@ def service_update(_context, service_id, values):
###################
-def floating_ip_allocate_address(_context, host, project_id):
+@require_context
+def floating_ip_allocate_address(context, host, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
with session.begin():
floating_ip_ref = session.query(models.FloatingIp
).filter_by(host=host
).filter_by(fixed_ip_id=None
+ ).filter_by(project_id=None
).filter_by(deleted=False
).with_lockmode('update'
).first()
@@ -175,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id):
return floating_ip_ref['address']
-def floating_ip_create(_context, values):
+@require_context
+def floating_ip_create(context, values):
floating_ip_ref = models.FloatingIp()
for (key, value) in values.iteritems():
floating_ip_ref[key] = value
@@ -183,7 +279,9 @@ def floating_ip_create(_context, values):
return floating_ip_ref['address']
-def floating_ip_count_by_project(_context, project_id):
+@require_context
+def floating_ip_count_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
return session.query(models.FloatingIp
).filter_by(project_id=project_id
@@ -191,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id):
).count()
-def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address):
+@require_context
+def floating_ip_fixed_ip_associate(context, floating_address, fixed_address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(floating_address,
- session=session)
- fixed_ip_ref = models.FixedIp.find_by_str(fixed_address,
- session=session)
+ # TODO(devcamcar): How to ensure floating_id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ floating_address,
+ session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ fixed_address,
+ session=session)
floating_ip_ref.fixed_ip = fixed_ip_ref
floating_ip_ref.save(session=session)
-def floating_ip_deallocate(_context, address):
+@require_context
+def floating_ip_deallocate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): How to ensure floating id belongs to user?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
floating_ip_ref['project_id'] = None
floating_ip_ref.save(session=session)
-def floating_ip_destroy(_context, address):
+@require_context
+def floating_ip_destroy(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ floating_ip_ref = get_floating_ip_by_address(context,
+ address,
+ session=session)
floating_ip_ref.delete(session=session)
-def floating_ip_disassociate(_context, address):
+@require_context
+def floating_ip_disassociate(context, address):
session = get_session()
with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
+ # TODO(devcamcar): Ensure address belongs to user.
+ # Does get_floating_ip_by_address handle this?
+ floating_ip_ref = floating_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref = floating_ip_ref.fixed_ip
if fixed_ip_ref:
fixed_ip_address = fixed_ip_ref['address']
@@ -234,7 +346,8 @@ def floating_ip_disassociate(_context, address):
return fixed_ip_address
-def floating_ip_get_all(_context):
+@require_admin_context
+def floating_ip_get_all(context):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -242,7 +355,8 @@ def floating_ip_get_all(_context):
).all()
-def floating_ip_get_all_by_host(_context, host):
+@require_admin_context
+def floating_ip_get_all_by_host(context, host):
session = get_session()
return session.query(models.FloatingIp
).options(joinedload_all('fixed_ip.instance')
@@ -250,24 +364,42 @@ def floating_ip_get_all_by_host(_context, host):
).filter_by(deleted=False
).all()
-def floating_ip_get_by_address(_context, address):
- return models.FloatingIp.find_by_str(address)
-
-def floating_ip_get_instance(_context, address):
+@require_context
+def floating_ip_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
session = get_session()
- with session.begin():
- floating_ip_ref = models.FloatingIp.find_by_str(address,
- session=session)
- return floating_ip_ref.fixed_ip.instance
+ return session.query(models.FloatingIp
+ ).options(joinedload_all('fixed_ip.instance')
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def floating_ip_get_by_address(context, address, session=None):
+ # TODO(devcamcar): Ensure the address belongs to user.
+ if not session:
+ session = get_session()
+
+ result = session.query(models.FloatingIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No fixed ip for address %s' % address)
+
+ return result
###################
-def fixed_ip_associate(_context, address, instance_id):
+@require_context
+def fixed_ip_associate(context, address, instance_id):
session = get_session()
with session.begin():
+ instance = instance_get(context, instance_id, session=session)
fixed_ip_ref = session.query(models.FixedIp
).filter_by(address=address
).filter_by(deleted=False
@@ -278,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id):
# then this has concurrency issues
if not fixed_ip_ref:
raise db.NoMoreAddresses()
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.instance = instance
session.add(fixed_ip_ref)
-def fixed_ip_associate_pool(_context, network_id, instance_id):
+@require_admin_context
+def fixed_ip_associate_pool(context, network_id, instance_id):
session = get_session()
with session.begin():
network_or_none = or_(models.FixedIp.network_id == network_id,
@@ -300,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id):
if not fixed_ip_ref:
raise db.NoMoreAddresses()
if not fixed_ip_ref.network:
- fixed_ip_ref.network = models.Network.find(network_id,
- session=session)
- fixed_ip_ref.instance = models.Instance.find(instance_id,
- session=session)
+ fixed_ip_ref.network = network_get(context,
+ network_id,
+ session=session)
+ fixed_ip_ref.instance = instance_get(context,
+ instance_id,
+ session=session)
session.add(fixed_ip_ref)
return fixed_ip_ref['address']
+@require_context
def fixed_ip_create(_context, values):
fixed_ip_ref = models.FixedIp()
for (key, value) in values.iteritems():
@@ -316,34 +451,72 @@ def fixed_ip_create(_context, values):
return fixed_ip_ref['address']
-def fixed_ip_disassociate(_context, address):
+@require_context
+def fixed_ip_disassociate(context, address):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
fixed_ip_ref.instance = None
fixed_ip_ref.save(session=session)
-def fixed_ip_get_by_address(_context, address):
- return models.FixedIp.find_by_str(address)
+@require_admin_context
+def fixed_ip_disassociate_all_by_timeout(_context, host, time):
+ session = get_session()
+ # NOTE(vish): The nested select is because sqlite doesn't support
+ # JOINs in UPDATEs.
+ result = session.execute('UPDATE fixed_ips SET instance_id = NULL, '
+ 'leased = 0 '
+ 'WHERE network_id IN (SELECT id FROM networks '
+ 'WHERE host = :host) '
+ 'AND updated_at < :time '
+ 'AND instance_id IS NOT NULL '
+ 'AND allocated = 0',
+ {'host': host,
+ 'time': time.isoformat()})
+ return result.rowcount
+
+
+@require_context
+def fixed_ip_get_by_address(context, address, session=None):
+ if not session:
+ session = get_session()
+ result = session.query(models.FixedIp
+ ).filter_by(address=address
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('network')
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('No floating ip for address %s' % address)
+ if is_user_context(context):
+ authorize_project_context(context, result.instance.project_id)
-def fixed_ip_get_instance(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).instance
+ return result
-def fixed_ip_get_network(_context, address):
- session = get_session()
- with session.begin():
- return models.FixedIp.find_by_str(address, session=session).network
+@require_context
+def fixed_ip_get_instance(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.instance
+
+
+@require_admin_context
+def fixed_ip_get_network(context, address):
+ fixed_ip_ref = fixed_ip_get_by_address(context, address)
+ return fixed_ip_ref.network
-def fixed_ip_update(_context, address, values):
+@require_context
+def fixed_ip_update(context, address, values):
session = get_session()
with session.begin():
- fixed_ip_ref = models.FixedIp.find_by_str(address, session=session)
+ fixed_ip_ref = fixed_ip_get_by_address(context,
+ address,
+ session=session)
for (key, value) in values.iteritems():
fixed_ip_ref[key] = value
fixed_ip_ref.save(session=session)
@@ -352,15 +525,24 @@ def fixed_ip_update(_context, address, values):
###################
-def instance_create(_context, values):
+@require_context
+def instance_create(context, values):
instance_ref = models.Instance()
for (key, value) in values.iteritems():
instance_ref[key] = value
- instance_ref.save()
+
+ session = get_session()
+ with session.begin():
+ while instance_ref.ec2_id == None:
+ ec2_id = utils.generate_uid(instance_ref.__prefix__)
+ if not instance_ec2_id_exists(context, ec2_id, session=session):
+ instance_ref.ec2_id = ec2_id
+ instance_ref.save(session=session)
return instance_ref
-def instance_data_get_for_project(_context, project_id):
+@require_admin_context
+def instance_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Instance.id),
func.sum(models.Instance.vcpus)
@@ -371,60 +553,130 @@ def instance_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def instance_destroy(_context, instance_id):
+@require_context
+def instance_destroy(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
instance_ref.delete(session=session)
-def instance_get(context, instance_id):
- return models.Instance.find(instance_id, deleted=_deleted(context))
+@require_context
+def instance_get(context, instance_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=instance_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No instance for id %s' % instance_id)
+ return result
+
+
+@require_admin_context
def instance_get_all(context):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
-def instance_get_by_project(context, project_id):
+@require_admin_context
+def instance_get_all_by_user(context, user_id):
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
+ ).filter_by(user_id=user_id
).all()
-def instance_get_by_reservation(_context, reservation_id):
+@require_context
+def instance_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
+
session = get_session()
return session.query(models.Instance
).options(joinedload_all('fixed_ip.floating_ips')
- ).filter_by(reservation_id=reservation_id
- ).filter_by(deleted=False
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
).all()
-def instance_get_by_str(context, str_id):
- return models.Instance.find_by_str(str_id, deleted=_deleted(context))
+@require_context
+def instance_get_all_by_reservation(context, reservation_id):
+ session = get_session()
+
+ if is_admin_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+ elif is_user_context(context):
+ return session.query(models.Instance
+ ).options(joinedload_all('fixed_ip.floating_ips')
+ ).filter_by(project_id=context.project.id
+ ).filter_by(reservation_id=reservation_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+@require_context
+def instance_get_by_ec2_id(context, ec2_id):
+ session = get_session()
+
+ if is_admin_context(context):
+ result = session.query(models.Instance
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Instance
+ ).filter_by(project_id=context.project.id
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('Instance %s not found' % (ec2_id))
+
+ return result
+
+@require_context
+def instance_ec2_id_exists(context, ec2_id, session=None):
+ if not session:
+ session = get_session()
+ return session.query(exists().where(models.Instance.id==ec2_id)).one()[0]
-def instance_get_fixed_address(_context, instance_id):
+
+@require_context
+def instance_get_fixed_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
return instance_ref.fixed_ip['address']
-def instance_get_floating_address(_context, instance_id):
+@require_context
+def instance_get_floating_address(context, instance_id):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
if not instance_ref.fixed_ip:
return None
if not instance_ref.fixed_ip.floating_ips:
@@ -433,12 +685,14 @@ def instance_get_floating_address(_context, instance_id):
return instance_ref.fixed_ip.floating_ips[0]['address']
+@require_admin_context
def instance_is_vpn(context, instance_id):
# TODO(vish): Move this into image code somewhere
instance_ref = instance_get(context, instance_id)
return instance_ref['image_id'] == FLAGS.vpn_image_id
+@require_admin_context
def instance_set_state(context, instance_id, state, description=None):
# TODO(devcamcar): Move this out of models and into driver
from nova.compute import power_state
@@ -450,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None):
'state_description': description})
-def instance_update(_context, instance_id, values):
+@require_context
+def instance_update(context, instance_id, values):
session = get_session()
with session.begin():
- instance_ref = models.Instance.find(instance_id, session=session)
+ instance_ref = instance_get(context, instance_id, session=session)
for (key, value) in values.iteritems():
instance_ref[key] = value
instance_ref.save(session=session)
@@ -462,11 +717,75 @@ def instance_update(_context, instance_id, values):
###################
-def network_count(_context):
- return models.Network.count()
+@require_context
+def key_pair_create(context, values):
+ key_pair_ref = models.KeyPair()
+ for (key, value) in values.iteritems():
+ key_pair_ref[key] = value
+ key_pair_ref.save()
+ return key_pair_ref
+
+
+@require_context
+def key_pair_destroy(context, user_id, name):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ with session.begin():
+ key_pair_ref = key_pair_get(context, user_id, name, session=session)
+ key_pair_ref.delete(session=session)
+
+
+@require_context
+def key_pair_destroy_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ with session.begin():
+ # TODO(vish): do we have to use sql here?
+ session.execute('update key_pairs set deleted=1 where user_id=:id',
+ {'id': user_id})
+
+
+@require_context
+def key_pair_get(context, user_id, name, session=None):
+ authorize_user_context(context, user_id)
+
+ if not session:
+ session = get_session()
+
+ result = session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(name=name
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('no keypair for user %s, name %s' %
+ (user_id, name))
+ return result
+
+
+@require_context
+def key_pair_get_all_by_user(context, user_id):
+ authorize_user_context(context, user_id)
+ session = get_session()
+ return session.query(models.KeyPair
+ ).filter_by(user_id=user_id
+ ).filter_by(deleted=False
+ ).all()
+
+
+###################
-def network_count_allocated_ips(_context, network_id):
+@require_admin_context
+def network_count(context):
+ session = get_session()
+ return session.query(models.Network
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
+
+
+@require_admin_context
+def network_count_allocated_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -475,7 +794,8 @@ def network_count_allocated_ips(_context, network_id):
).count()
-def network_count_available_ips(_context, network_id):
+@require_admin_context
+def network_count_available_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -485,7 +805,8 @@ def network_count_available_ips(_context, network_id):
).count()
-def network_count_reserved_ips(_context, network_id):
+@require_admin_context
+def network_count_reserved_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
).filter_by(network_id=network_id
@@ -494,7 +815,8 @@ def network_count_reserved_ips(_context, network_id):
).count()
-def network_create(_context, values):
+@require_admin_context
+def network_create(context, values):
network_ref = models.Network()
for (key, value) in values.iteritems():
network_ref[key] = value
@@ -502,7 +824,8 @@ def network_create(_context, values):
return network_ref
-def network_destroy(_context, network_id):
+@require_admin_context
+def network_destroy(context, network_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -520,34 +843,59 @@ def network_destroy(_context, network_id):
{'id': network_id})
-def network_get(_context, network_id):
- return models.Network.find(network_id)
+@require_context
+def network_get(context, network_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Network
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No network for id %s' % network_id)
+
+ return result
# NOTE(vish): pylint complains because of the long method name, but
# it fits with the names of the rest of the methods
# pylint: disable-msg=C0103
-def network_get_associated_fixed_ips(_context, network_id):
+@require_admin_context
+def network_get_associated_fixed_ips(context, network_id):
session = get_session()
return session.query(models.FixedIp
+ ).options(joinedload_all('instance')
).filter_by(network_id=network_id
).filter(models.FixedIp.instance_id != None
).filter_by(deleted=False
).all()
-def network_get_by_bridge(_context, bridge):
+@require_admin_context
+def network_get_by_bridge(context, bridge):
session = get_session()
- rv = session.query(models.Network
+ result = session.query(models.Network
).filter_by(bridge=bridge
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for bridge %s' % bridge)
- return rv
+
+ return result
-def network_get_index(_context, network_id):
+@require_admin_context
+def network_get_index(context, network_id):
session = get_session()
with session.begin():
network_index = session.query(models.NetworkIndex
@@ -555,48 +903,63 @@ def network_get_index(_context, network_id):
).filter_by(deleted=False
).with_lockmode('update'
).first()
+
if not network_index:
raise db.NoMoreNetworks()
- network_index['network'] = models.Network.find(network_id,
- session=session)
+
+ network_index['network'] = network_get(context,
+ network_id,
+ session=session)
session.add(network_index)
+
return network_index['index']
-def network_index_count(_context):
- return models.NetworkIndex.count()
+@require_admin_context
+def network_index_count(context):
+ session = get_session()
+ return session.query(models.NetworkIndex
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def network_index_create(_context, values):
+@require_admin_context
+def network_index_create_safe(context, values):
network_index_ref = models.NetworkIndex()
for (key, value) in values.iteritems():
network_index_ref[key] = value
- network_index_ref.save()
+ try:
+ network_index_ref.save()
+ except IntegrityError:
+ pass
-def network_set_host(_context, network_id, host_id):
+@require_admin_context
+def network_set_host(context, network_id, host_id):
session = get_session()
with session.begin():
- network = session.query(models.Network
- ).filter_by(id=network_id
- ).filter_by(deleted=False
- ).with_lockmode('update'
- ).first()
- if not network:
- raise exception.NotFound("Couldn't find network with %s" %
- network_id)
+ network_ref = session.query(models.Network
+ ).filter_by(id=network_id
+ ).filter_by(deleted=False
+ ).with_lockmode('update'
+ ).first()
+ if not network_ref:
+ raise exception.NotFound('No network for id %s' % network_id)
+
# NOTE(vish): if with_lockmode isn't supported, as in sqlite,
# then this has concurrency issues
- if not network['host']:
- network['host'] = host_id
- session.add(network)
- return network['host']
+ if not network_ref['host']:
+ network_ref['host'] = host_id
+ session.add(network_ref)
+ return network_ref['host']
-def network_update(_context, network_id, values):
+
+@require_context
+def network_update(context, network_id, values):
session = get_session()
with session.begin():
- network_ref = models.Network.find(network_id, session=session)
+ network_ref = network_get(context, network_id, session=session)
for (key, value) in values.iteritems():
network_ref[key] = value
network_ref.save(session=session)
@@ -605,15 +968,18 @@ def network_update(_context, network_id, values):
###################
-def project_get_network(_context, project_id):
+@require_context
+def project_get_network(context, project_id):
session = get_session()
- rv = session.query(models.Network
+ result= session.query(models.Network
).filter_by(project_id=project_id
).filter_by(deleted=False
).first()
- if not rv:
+
+ if not result:
raise exception.NotFound('No network for project: %s' % project_id)
- return rv
+
+ return result
###################
@@ -623,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id):
# FIXME(ja): this should be servername?
return "%s.%s" % (topic, physical_node_id)
+
###################
-def export_device_count(_context):
- return models.ExportDevice.count()
+@require_admin_context
+def export_device_count(context):
+ session = get_session()
+ return session.query(models.ExportDevice
+ ).filter_by(deleted=can_read_deleted(context)
+ ).count()
-def export_device_create(_context, values):
+@require_admin_context
+def export_device_create(context, values):
export_device_ref = models.ExportDevice()
for (key, value) in values.iteritems():
export_device_ref[key] = value
@@ -641,7 +1013,46 @@ def export_device_create(_context, values):
###################
-def quota_create(_context, values):
+def auth_destroy_token(_context, token):
+ session = get_session()
+ session.delete(token)
+
+def auth_get_token(_context, token_hash):
+ session = get_session()
+ tk = session.query(models.AuthToken
+ ).filter_by(token_hash=token_hash)
+ if not tk:
+ raise exception.NotFound('Token %s does not exist' % token_hash)
+ return tk
+
+def auth_create_token(_context, token):
+ tk = models.AuthToken()
+ for k,v in token.iteritems():
+ tk[k] = v
+ tk.save()
+ return tk
+
+
+###################
+
+
+@require_admin_context
+def quota_get(context, project_id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Quota
+ ).filter_by(project_id=project_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ if not result:
+ raise exception.NotFound('No quota for project_id %s' % project_id)
+
+ return result
+
+
+@require_admin_context
+def quota_create(context, values):
quota_ref = models.Quota()
for (key, value) in values.iteritems():
quota_ref[key] = value
@@ -649,30 +1060,29 @@ def quota_create(_context, values):
return quota_ref
-def quota_get(_context, project_id):
- return models.Quota.find_by_str(project_id)
-
-
-def quota_update(_context, project_id, values):
+@require_admin_context
+def quota_update(context, project_id, values):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
for (key, value) in values.iteritems():
quota_ref[key] = value
quota_ref.save(session=session)
-def quota_destroy(_context, project_id):
+@require_admin_context
+def quota_destroy(context, project_id):
session = get_session()
with session.begin():
- quota_ref = models.Quota.find_by_str(project_id, session=session)
+ quota_ref = quota_get(context, project_id, session=session)
quota_ref.delete(session=session)
###################
-def volume_allocate_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_allocate_shelf_and_blade(context, volume_id):
session = get_session()
with session.begin():
export_device = session.query(models.ExportDevice
@@ -689,27 +1099,36 @@ def volume_allocate_shelf_and_blade(_context, volume_id):
return (export_device.shelf_id, export_device.blade_id)
-def volume_attached(_context, volume_id, instance_id, mountpoint):
+@require_admin_context
+def volume_attached(context, volume_id, instance_id, mountpoint):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'in-use'
volume_ref['mountpoint'] = mountpoint
volume_ref['attach_status'] = 'attached'
- volume_ref.instance = models.Instance.find(instance_id,
- session=session)
+ volume_ref.instance = instance_get(context, instance_id, session=session)
volume_ref.save(session=session)
-def volume_create(_context, values):
+@require_context
+def volume_create(context, values):
volume_ref = models.Volume()
for (key, value) in values.iteritems():
volume_ref[key] = value
- volume_ref.save()
+
+ session = get_session()
+ with session.begin():
+ while volume_ref.ec2_id == None:
+ ec2_id = utils.generate_uid(volume_ref.__prefix__)
+ if not volume_ec2_id_exists(context, ec2_id, session=session):
+ volume_ref.ec2_id = ec2_id
+ volume_ref.save(session=session)
return volume_ref
-def volume_data_get_for_project(_context, project_id):
+@require_admin_context
+def volume_data_get_for_project(context, project_id):
session = get_session()
result = session.query(func.count(models.Volume.id),
func.sum(models.Volume.size)
@@ -720,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id):
return (result[0] or 0, result[1] or 0)
-def volume_destroy(_context, volume_id):
+@require_admin_context
+def volume_destroy(context, volume_id):
session = get_session()
with session.begin():
# TODO(vish): do we have to use sql here?
@@ -731,10 +1151,11 @@ def volume_destroy(_context, volume_id):
{'id': volume_id})
-def volume_detached(_context, volume_id):
+@require_admin_context
+def volume_detached(context, volume_id):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
volume_ref['status'] = 'available'
volume_ref['mountpoint'] = None
volume_ref['attach_status'] = 'detached'
@@ -742,46 +1163,334 @@ def volume_detached(_context, volume_id):
volume_ref.save(session=session)
-def volume_get(context, volume_id):
- return models.Volume.find(volume_id, deleted=_deleted(context))
+@require_context
+def volume_get(context, volume_id, session=None):
+ if not session:
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=False
+ ).first()
+ if not result:
+ raise exception.NotFound('No volume for id %s' % volume_id)
+
+ return result
+@require_admin_context
def volume_get_all(context):
- return models.Volume.all(deleted=_deleted(context))
+ return session.query(models.Volume
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+@require_context
+def volume_get_all_by_project(context, project_id):
+ authorize_project_context(context, project_id)
-def volume_get_by_project(context, project_id):
session = get_session()
return session.query(models.Volume
).filter_by(project_id=project_id
- ).filter_by(deleted=_deleted(context)
+ ).filter_by(deleted=can_read_deleted(context)
).all()
-def volume_get_by_str(context, str_id):
- return models.Volume.find_by_str(str_id, deleted=_deleted(context))
+@require_context
+def volume_get_by_ec2_id(context, ec2_id):
+ session = get_session()
+ result = None
+
+ if is_admin_context(context):
+ result = session.query(models.Volume
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+ elif is_user_context(context):
+ result = session.query(models.Volume
+ ).filter_by(project_id=context.project.id
+ ).filter_by(ec2_id=ec2_id
+ ).filter_by(deleted=False
+ ).first()
+ else:
+ raise exception.NotAuthorized()
+
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result
+
+
+@require_context
+def volume_ec2_id_exists(context, ec2_id, session=None):
+ if not session:
+ session = get_session()
+ return session.query(exists(
+ ).where(models.Volume.id==ec2_id)
+ ).one()[0]
-def volume_get_instance(_context, volume_id):
+
+@require_admin_context
+def volume_get_instance(context, volume_id):
session = get_session()
- with session.begin():
- return models.Volume.find(volume_id, session=session).instance
+ result = session.query(models.Volume
+ ).filter_by(id=volume_id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload('instance')
+ ).first()
+ if not result:
+ raise exception.NotFound('Volume %s not found' % ec2_id)
+
+ return result.instance
-def volume_get_shelf_and_blade(_context, volume_id):
+@require_admin_context
+def volume_get_shelf_and_blade(context, volume_id):
session = get_session()
- export_device = session.query(models.ExportDevice
- ).filter_by(volume_id=volume_id
- ).first()
- if not export_device:
- raise exception.NotFound()
- return (export_device.shelf_id, export_device.blade_id)
+ result = session.query(models.ExportDevice
+ ).filter_by(volume_id=volume_id
+ ).first()
+ if not result:
+ raise exception.NotFound('No export device found for volume %s' %
+ volume_id)
+ return (result.shelf_id, result.blade_id)
-def volume_update(_context, volume_id, values):
+
+@require_context
+def volume_update(context, volume_id, values):
session = get_session()
with session.begin():
- volume_ref = models.Volume.find(volume_id, session=session)
+ volume_ref = volume_get(context, volume_id, session=session)
for (key, value) in values.iteritems():
volume_ref[key] = value
volume_ref.save(session=session)
+
+
+###################
+
+
+@require_admin_context
+def user_get(context, id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.User
+ ).filter_by(id=id
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No user for id %s' % id)
+
+ return result
+
+
+@require_admin_context
+def user_get_by_access_key(context, access_key, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.User
+ ).filter_by(access_key=access_key
+ ).filter_by(deleted=can_read_deleted(context)
+ ).first()
+
+ if not result:
+ raise exception.NotFound('No user for id %s' % id)
+
+ return result
+
+
+@require_admin_context
+def user_create(_context, values):
+ user_ref = models.User()
+ for (key, value) in values.iteritems():
+ user_ref[key] = value
+ user_ref.save()
+ return user_ref
+
+
+@require_admin_context
+def user_delete(context, id):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_association where user_id=:id',
+ {'id': id})
+ session.execute('delete from user_role_association where user_id=:id',
+ {'id': id})
+ session.execute('delete from user_project_role_association where user_id=:id',
+ {'id': id})
+ user_ref = user_get(context, id, session=session)
+ session.delete(user_ref)
+
+
+def user_get_all(context):
+ session = get_session()
+ return session.query(models.User
+ ).filter_by(deleted=can_read_deleted(context)
+ ).all()
+
+
+def project_create(_context, values):
+ project_ref = models.Project()
+ for (key, value) in values.iteritems():
+ project_ref[key] = value
+ project_ref.save()
+ return project_ref
+
+
+def project_add_member(context, project_id, user_id):
+ session = get_session()
+ with session.begin():
+ project_ref = project_get(context, project_id, session=session)
+ user_ref = user_get(context, user_id, session=session)
+
+ project_ref.members += [user_ref]
+ project_ref.save(session=session)
+
+
+def project_get(context, id, session=None):
+ if not session:
+ session = get_session()
+
+ result = session.query(models.Project
+ ).filter_by(deleted=False
+ ).filter_by(id=id
+ ).options(joinedload_all('members')
+ ).first()
+
+ if not result:
+ raise exception.NotFound("No project with id %s" % id)
+
+ return result
+
+
+def project_get_all(context):
+ session = get_session()
+ return session.query(models.Project
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('members')
+ ).all()
+
+
+def project_get_by_user(context, user_id):
+ session = get_session()
+ user = session.query(models.User
+ ).filter_by(deleted=can_read_deleted(context)
+ ).options(joinedload_all('projects')
+ ).first()
+ return user.projects
+
+
+def project_remove_member(context, project_id, user_id):
+ session = get_session()
+ project = project_get(context, project_id, session=session)
+ user = user_get(context, user_id, session=session)
+
+ if user in project.members:
+ project.members.remove(user)
+ project.save(session=session)
+
+
+def user_update(context, user_id, values):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ for (key, value) in values.iteritems():
+ user_ref[key] = value
+ user_ref.save(session=session)
+
+
+def project_update(context, project_id, values):
+ session = get_session()
+ with session.begin():
+ project_ref = project_get(context, project_id, session=session)
+ for (key, value) in values.iteritems():
+ project_ref[key] = value
+ project_ref.save(session=session)
+
+
+def project_delete(context, id):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_association where project_id=:id',
+ {'id': id})
+ session.execute('delete from user_project_role_association where project_id=:id',
+ {'id': id})
+ project_ref = project_get(context, id, session=session)
+ session.delete(project_ref)
+
+
+def user_get_roles(context, user_id):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ return [role.role for role in user_ref['roles']]
+
+
+def user_get_roles_for_project(context, user_id, project_id):
+ session = get_session()
+ with session.begin():
+ res = session.query(models.UserProjectRoleAssociation
+ ).filter_by(user_id=user_id
+ ).filter_by(project_id=project_id
+ ).all()
+ return [association.role for association in res]
+
+def user_remove_project_role(context, user_id, project_id, role):
+ session = get_session()
+ with session.begin():
+ session.execute('delete from user_project_role_association where ' + \
+ 'user_id=:user_id and project_id=:project_id and ' + \
+ 'role=:role', { 'user_id' : user_id,
+ 'project_id' : project_id,
+ 'role' : role })
+
+
+def user_remove_role(context, user_id, role):
+ session = get_session()
+ with session.begin():
+ res = session.query(models.UserRoleAssociation
+ ).filter_by(user_id=user_id
+ ).filter_by(role=role
+ ).all()
+ for role in res:
+ session.delete(role)
+
+
+def user_add_role(context, user_id, role):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ models.UserRoleAssociation(user=user_ref, role=role).save(session=session)
+
+
+def user_add_project_role(context, user_id, project_id, role):
+ session = get_session()
+ with session.begin():
+ user_ref = user_get(context, user_id, session=session)
+ project_ref = project_get(context, project_id, session=session)
+ models.UserProjectRoleAssociation(user_id=user_ref['id'],
+ project_id=project_ref['id'],
+ role=role).save(session=session)
+
+
+###################
+
+
+def host_get_networks(context, host):
+ session = get_session()
+ with session.begin():
+ return session.query(models.Network
+ ).filter_by(deleted=False
+ ).filter_by(host=host
+ ).all()
diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py
index 41013f41b..673c8e94f 100644
--- a/nova/db/sqlalchemy/models.py
+++ b/nova/db/sqlalchemy/models.py
@@ -27,7 +27,9 @@ import datetime
from sqlalchemy.orm import relationship, backref, exc, object_mapper
from sqlalchemy import Column, Integer, String
from sqlalchemy import ForeignKey, DateTime, Boolean, Text
+from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.schema import ForeignKeyConstraint
from nova.db.sqlalchemy.session import get_session
@@ -50,44 +52,6 @@ class NovaBase(object):
deleted_at = Column(DateTime)
deleted = Column(Boolean, default=False)
- @classmethod
- def all(cls, session=None, deleted=False):
- """Get all objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).all()
-
- @classmethod
- def count(cls, session=None, deleted=False):
- """Count objects of this type"""
- if not session:
- session = get_session()
- return session.query(cls
- ).filter_by(deleted=deleted
- ).count()
-
- @classmethod
- def find(cls, obj_id, session=None, deleted=False):
- """Find object by id"""
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(id=obj_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for id %s" % obj_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- """Find object by str_id"""
- int_id = int(str_id.rpartition('-')[2])
- return cls.find(int_id, session=session, deleted=deleted)
-
@property
def str_id(self):
"""Get string id of object (generally prefix + '-' + id)"""
@@ -98,7 +62,13 @@ class NovaBase(object):
if not session:
session = get_session()
session.add(self)
- session.flush()
+ try:
+ session.flush()
+ except IntegrityError, e:
+ if str(e).endswith('is not unique'):
+ raise exception.Duplicate(str(e))
+ else:
+ raise
def delete(self, session=None):
"""Delete this object"""
@@ -126,6 +96,7 @@ class NovaBase(object):
# __tablename__ = 'images'
# __prefix__ = 'ami'
# id = Column(Integer, primary_key=True)
+# ec2_id = Column(String(12), unique=True)
# user_id = Column(String(255))
# project_id = Column(String(255))
# image_type = Column(String(255))
@@ -173,21 +144,7 @@ class Service(BASE, NovaBase):
binary = Column(String(255))
topic = Column(String(255))
report_count = Column(Integer, nullable=False, default=0)
-
- @classmethod
- def find_by_args(cls, host, binary, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(host=host
- ).filter_by(binary=binary
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for %s, %s" % (host,
- binary))
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
+ disabled = Column(Boolean, default=False)
class Instance(BASE, NovaBase):
@@ -195,6 +152,9 @@ class Instance(BASE, NovaBase):
__tablename__ = 'instances'
__prefix__ = 'i'
id = Column(Integer, primary_key=True)
+ ec2_id = Column(String(10), unique=True)
+
+ admin_pass = Column(String(255))
user_id = Column(String(255))
project_id = Column(String(255))
@@ -209,11 +169,14 @@ class Instance(BASE, NovaBase):
@property
def name(self):
- return self.str_id
+ return self.ec2_id
image_id = Column(String(255))
kernel_id = Column(String(255))
ramdisk_id = Column(String(255))
+
+ server_name = Column(String(255))
+
# image_id = Column(Integer, ForeignKey('images.id'), nullable=True)
# kernel_id = Column(Integer, ForeignKey('images.id'), nullable=True)
# ramdisk_id = Column(Integer, ForeignKey('images.id'), nullable=True)
@@ -233,7 +196,6 @@ class Instance(BASE, NovaBase):
vcpus = Column(Integer)
local_gb = Column(Integer)
-
hostname = Column(String(255))
host = Column(String(255)) # , ForeignKey('hosts.id'))
@@ -247,6 +209,10 @@ class Instance(BASE, NovaBase):
scheduled_at = Column(DateTime)
launched_at = Column(DateTime)
terminated_at = Column(DateTime)
+
+ display_name = Column(String(255))
+ display_description = Column(String(255))
+
# TODO(vish): see Ewan's email about state improvements, probably
# should be in a driver base class or some such
# vmstate_state = running, halted, suspended, paused
@@ -264,6 +230,7 @@ class Volume(BASE, NovaBase):
__tablename__ = 'volumes'
__prefix__ = 'vol'
id = Column(Integer, primary_key=True)
+ ec2_id = Column(String(12), unique=True)
user_id = Column(String(255))
project_id = Column(String(255))
@@ -272,7 +239,11 @@ class Volume(BASE, NovaBase):
size = Column(Integer)
availability_zone = Column(String(255)) # TODO(vish): foreign key?
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('volumes'))
+ instance = relationship(Instance,
+ backref=backref('volumes'),
+ foreign_keys=instance_id,
+ primaryjoin='and_(Volume.instance_id==Instance.id,'
+ 'Volume.deleted==False)')
mountpoint = Column(String(255))
attach_time = Column(String(255)) # TODO(vish): datetime
status = Column(String(255)) # TODO(vish): enum?
@@ -282,6 +253,10 @@ class Volume(BASE, NovaBase):
launched_at = Column(DateTime)
terminated_at = Column(DateTime)
+ display_name = Column(String(255))
+ display_description = Column(String(255))
+
+
class Quota(BASE, NovaBase):
"""Represents quota overrides for a project"""
__tablename__ = 'quotas'
@@ -299,18 +274,6 @@ class Quota(BASE, NovaBase):
def str_id(self):
return self.project_id
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(project_id=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for project_id %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
class ExportDevice(BASE, NovaBase):
"""Represates a shelf and blade that a volume can be exported on"""
@@ -319,8 +282,27 @@ class ExportDevice(BASE, NovaBase):
shelf_id = Column(Integer)
blade_id = Column(Integer)
volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True)
- volume = relationship(Volume, backref=backref('export_device',
- uselist=False))
+ volume = relationship(Volume,
+ backref=backref('export_device', uselist=False),
+ foreign_keys=volume_id,
+ primaryjoin='and_(ExportDevice.volume_id==Volume.id,'
+ 'ExportDevice.deleted==False)')
+
+
+class KeyPair(BASE, NovaBase):
+ """Represents a public key pair for ssh"""
+ __tablename__ = 'key_pairs'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(255))
+
+ user_id = Column(String(255))
+
+ fingerprint = Column(String(255))
+ public_key = Column(Text)
+
+ @property
+ def str_id(self):
+ return '%s.%s' % (self.user_id, self.name)
class Network(BASE, NovaBase):
@@ -355,10 +337,26 @@ class NetworkIndex(BASE, NovaBase):
"""
__tablename__ = 'network_indexes'
id = Column(Integer, primary_key=True)
- index = Column(Integer)
+ index = Column(Integer, unique=True)
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
- network = relationship(Network, backref=backref('network_index',
- uselist=False))
+ network = relationship(Network,
+ backref=backref('network_index', uselist=False),
+ foreign_keys=network_id,
+ primaryjoin='and_(NetworkIndex.network_id==Network.id,'
+ 'NetworkIndex.deleted==False)')
+
+
+class AuthToken(BASE, NovaBase):
+ """Represents an authorization token for all API transactions. Fields
+ are a string representing the actual token and a user id for mapping
+ to the actual user"""
+ __tablename__ = 'auth_tokens'
+ token_hash = Column(String(255), primary_key=True)
+ user_id = Column(Integer)
+ server_manageent_url = Column(String(255))
+ storage_url = Column(String(255))
+ cdn_management_url = Column(String(255))
+
# TODO(vish): can these both come from the same baseclass?
@@ -370,8 +368,11 @@ class FixedIp(BASE, NovaBase):
network_id = Column(Integer, ForeignKey('networks.id'), nullable=True)
network = relationship(Network, backref=backref('fixed_ips'))
instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True)
- instance = relationship(Instance, backref=backref('fixed_ip',
- uselist=False))
+ instance = relationship(Instance,
+ backref=backref('fixed_ip', uselist=False),
+ foreign_keys=instance_id,
+ primaryjoin='and_(FixedIp.instance_id==Instance.id,'
+ 'FixedIp.deleted==False)')
allocated = Column(Boolean, default=False)
leased = Column(Boolean, default=False)
reserved = Column(Boolean, default=False)
@@ -380,18 +381,66 @@ class FixedIp(BASE, NovaBase):
def str_id(self):
return self.address
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
+
+class User(BASE, NovaBase):
+ """Represents a user"""
+ __tablename__ = 'users'
+ id = Column(String(255), primary_key=True)
+
+ name = Column(String(255))
+ access_key = Column(String(255))
+ secret_key = Column(String(255))
+
+ is_admin = Column(Boolean)
+
+
+class Project(BASE, NovaBase):
+ """Represents a project"""
+ __tablename__ = 'projects'
+ id = Column(String(255), primary_key=True)
+ name = Column(String(255))
+ description = Column(String(255))
+
+ project_manager = Column(String(255), ForeignKey(User.id))
+
+ members = relationship(User,
+ secondary='user_project_association',
+ backref='projects')
+
+
+class UserProjectRoleAssociation(BASE, NovaBase):
+ __tablename__ = 'user_project_role_association'
+ user_id = Column(String(255), primary_key=True)
+ user = relationship(User,
+ primaryjoin=user_id==User.id,
+ foreign_keys=[User.id],
+ uselist=False)
+
+ project_id = Column(String(255), primary_key=True)
+ project = relationship(Project,
+ primaryjoin=project_id==Project.id,
+ foreign_keys=[Project.id],
+ uselist=False)
+
+ role = Column(String(255), primary_key=True)
+ ForeignKeyConstraint(['user_id',
+ 'project_id'],
+ ['user_project_association.user_id',
+ 'user_project_association.project_id'])
+
+
+class UserRoleAssociation(BASE, NovaBase):
+ __tablename__ = 'user_role_association'
+ user_id = Column(String(255), ForeignKey('users.id'), primary_key=True)
+ user = relationship(User, backref='roles')
+ role = Column(String(255), primary_key=True)
+
+
+class UserProjectAssociation(BASE, NovaBase):
+ __tablename__ = 'user_project_association'
+ user_id = Column(String(255), ForeignKey(User.id), primary_key=True)
+ project_id = Column(String(255), ForeignKey(Project.id), primary_key=True)
+
class FloatingIp(BASE, NovaBase):
@@ -400,34 +449,21 @@ class FloatingIp(BASE, NovaBase):
id = Column(Integer, primary_key=True)
address = Column(String(255))
fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True)
- fixed_ip = relationship(FixedIp, backref=backref('floating_ips'))
-
+ fixed_ip = relationship(FixedIp,
+ backref=backref('floating_ips'),
+ foreign_keys=fixed_ip_id,
+ primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,'
+ 'FloatingIp.deleted==False)')
project_id = Column(String(255))
host = Column(String(255)) # , ForeignKey('hosts.id'))
- @property
- def str_id(self):
- return self.address
-
- @classmethod
- def find_by_str(cls, str_id, session=None, deleted=False):
- if not session:
- session = get_session()
- try:
- return session.query(cls
- ).filter_by(address=str_id
- ).filter_by(deleted=deleted
- ).one()
- except exc.NoResultFound:
- new_exc = exception.NotFound("No model for address %s" % str_id)
- raise new_exc.__class__, new_exc, sys.exc_info()[2]
-
def register_models():
"""Register Models and create metadata"""
from sqlalchemy import create_engine
models = (Service, Instance, Volume, ExportDevice,
- FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host)
+ FixedIp, FloatingIp, Network, NetworkIndex,
+ AuthToken, UserProjectAssociation, User, Project) # , Image, Host)
engine = create_engine(FLAGS.sql_connection, echo=False)
for model in models:
model.metadata.create_all(engine)
diff --git a/nova/endpoint/__init__.py b/nova/endpoint/__init__.py
deleted file mode 100644
index e69de29bb..000000000
--- a/nova/endpoint/__init__.py
+++ /dev/null
diff --git a/nova/endpoint/api.py b/nova/endpoint/api.py
deleted file mode 100755
index 12eedfe67..000000000
--- a/nova/endpoint/api.py
+++ /dev/null
@@ -1,347 +0,0 @@
-# vim: tabstop=4 shiftwidth=4 softtabstop=4
-
-# Copyright 2010 United States Government as represented by the
-# Administrator of the National Aeronautics and Space Administration.
-# All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may
-# not use this file except in compliance with the License. You may obtain
-# a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations
-# under the License.
-
-"""
-Tornado REST API Request Handlers for Nova functions
-Most calls are proxied into the responsible controller.
-"""
-
-import logging
-import multiprocessing
-import random
-import re
-import urllib
-# TODO(termie): replace minidom with etree
-from xml.dom import minidom
-
-import tornado.web
-from twisted.internet import defer
-
-from nova import crypto
-from nova import exception
-from nova import flags
-from nova import utils
-from nova.auth import manager
-import nova.cloudpipe.api
-from nova.endpoint import cloud
-
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer('cc_port', 8773, 'cloud controller port')
-
-
-_log = logging.getLogger("api")
-_log.setLevel(logging.DEBUG)
-
-
-_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
-
-
-def _camelcase_to_underscore(str):
- return _c2u.sub(r'_\1', str).lower().strip('_')
-
-
-def _underscore_to_camelcase(str):
- return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
-
-
-def _underscore_to_xmlcase(str):
- res = _underscore_to_camelcase(str)
- return res[:1].lower() + res[1:]
-
-
-class APIRequestContext(object):
- def __init__(self, handler, user, project):
- self.handler = handler
- self.user = user
- self.project = project
- self.request_id = ''.join(
- [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
- for x in xrange(20)]
- )
-
-
-class APIRequest(object):
- def __init__(self, controller, action):
- self.controller = controller
- self.action = action
-
- def send(self, context, **kwargs):
-
- try:
- method = getattr(self.controller,
- _camelcase_to_underscore(self.action))
- except AttributeError:
- _error = ('Unsupported API request: controller = %s,'
- 'action = %s') % (self.controller, self.action)
- _log.warning(_error)
- # TODO: Raise custom exception, trap in apiserver,
- # and reraise as 400 error.
- raise Exception(_error)
-
- args = {}
- for key, value in kwargs.items():
- parts = key.split(".")
- key = _camelcase_to_underscore(parts[0])
- if len(parts) > 1:
- d = args.get(key, {})
- d[parts[1]] = value[0]
- value = d
- else:
- value = value[0]
- args[key] = value
-
- for key in args.keys():
- if isinstance(args[key], dict):
- if args[key] != {} and args[key].keys()[0].isdigit():
- s = args[key].items()
- s.sort()
- args[key] = [v for k, v in s]
-
- d = defer.maybeDeferred(method, context, **args)
- d.addCallback(self._render_response, context.request_id)
- return d
-
- def _render_response(self, response_data, request_id):
- xml = minidom.Document()
-
- response_el = xml.createElement(self.action + 'Response')
- response_el.setAttribute('xmlns',
- 'http://ec2.amazonaws.com/doc/2009-11-30/')
- request_id_el = xml.createElement('requestId')
- request_id_el.appendChild(xml.createTextNode(request_id))
- response_el.appendChild(request_id_el)
- if(response_data == True):
- self._render_dict(xml, response_el, {'return': 'true'})
- else:
- self._render_dict(xml, response_el, response_data)
-
- xml.appendChild(response_el)
-
- response = xml.toxml()
- xml.unlink()
- _log.debug(response)
- return response
-
- def _render_dict(self, xml, el, data):
- try:
- for key in data.keys():
- val = data[key]
- el.appendChild(self._render_data(xml, key, val))
- except:
- _log.debug(data)
- raise
-
- def _render_data(self, xml, el_name, data):
- el_name = _underscore_to_xmlcase(el_name)
- data_el = xml.createElement(el_name)
-
- if isinstance(data, list):
- for item in data:
- data_el.appendChild(self._render_data(xml, 'item', item))
- elif isinstance(data, dict):
- self._render_dict(xml, data_el, data)
- elif hasattr(data, '__dict__'):
- self._render_dict(xml, data_el, data.__dict__)
- elif isinstance(data, bool):
- data_el.appendChild(xml.createTextNode(str(data).lower()))
- elif data != None:
- data_el.appendChild(xml.createTextNode(str(data)))
-
- return data_el
-
-
-class RootRequestHandler(tornado.web.RequestHandler):
- def get(self):
- # available api versions
- versions = [
- '1.0',
- '2007-01-19',
- '2007-03-01',
- '2007-08-29',
- '2007-10-10',
- '2007-12-15',
- '2008-02-01',
- '2008-09-01',
- '2009-04-04',
- ]
- for version in versions:
- self.write('%s\n' % version)
- self.finish()
-
-
-class MetadataRequestHandler(tornado.web.RequestHandler):
- def print_data(self, data):
- if isinstance(data, dict):
- output = ''
- for key in data:
- if key == '_name':
- continue
- output += key
- if isinstance(data[key], dict):
- if '_name' in data[key]:
- output += '=' + str(data[key]['_name'])
- else:
- output += '/'
- output += '\n'
- self.write(output[:-1]) # cut off last \n
- elif isinstance(data, list):
- self.write('\n'.join(data))
- else:
- self.write(str(data))
-
- def lookup(self, path, data):
- items = path.split('/')
- for item in items:
- if item:
- if not isinstance(data, dict):
- return data
- if not item in data:
- return None
- data = data[item]
- return data
-
- def get(self, path):
- cc = self.application.controllers['Cloud']
- meta_data = cc.get_metadata(self.request.remote_ip)
- if meta_data is None:
- _log.error('Failed to get metadata for ip: %s' %
- self.request.remote_ip)
- raise tornado.web.HTTPError(404)
- data = self.lookup(path, meta_data)
- if data is None:
- raise tornado.web.HTTPError(404)
- self.print_data(data)
- self.finish()
-
-
-class APIRequestHandler(tornado.web.RequestHandler):
- def get(self, controller_name):
- self.execute(controller_name)
-
- @tornado.web.asynchronous
- def execute(self, controller_name):
- # Obtain the appropriate controller for this request.
- try:
- controller = self.application.controllers[controller_name]
- except KeyError:
- self._error('unhandled', 'no controller named %s' % controller_name)
- return
-
- args = self.request.arguments
-
- # Read request signature.
- try:
- signature = args.pop('Signature')[0]
- except:
- raise tornado.web.HTTPError(400)
-
- # Make a copy of args for authentication and signature verification.
- auth_params = {}
- for key, value in args.items():
- auth_params[key] = value[0]
-
- # Get requested action and remove authentication args for final request.
- try:
- action = args.pop('Action')[0]
- access = args.pop('AWSAccessKeyId')[0]
- args.pop('SignatureMethod')
- args.pop('SignatureVersion')
- args.pop('Version')
- args.pop('Timestamp')
- except:
- raise tornado.web.HTTPError(400)
-
- # Authenticate the request.
- try:
- (user, project) = manager.AuthManager().authenticate(
- access,
- signature,
- auth_params,
- self.request.method,
- self.request.host,
- self.request.path
- )
-
- except exception.Error, ex:
- logging.debug("Authentication Failure: %s" % ex)
- raise tornado.web.HTTPError(403)
-
- _log.debug('action: %s' % action)
-
- for key, value in args.items():
- _log.debug('arg: %s\t\tval: %s' % (key, value))
-
- request = APIRequest(controller, action)
- context = APIRequestContext(self, user, project)
- d = request.send(context, **args)
- # d.addCallback(utils.debug)
-
- # TODO: Wrap response in AWS XML format
- d.addCallbacks(self._write_callback, self._error_callback)
-
- def _write_callback(self, data):
- self.set_header('Content-Type', 'text/xml')
- self.write(data)
- self.finish()
-
- def _error_callback(self, failure):
- try:
- failure.raiseException()
- except exception.ApiError as ex:
- if ex.code:
- self._error(ex.code, ex.message)
- else:
- self._error(type(ex).__name__, ex.message)
- # TODO(vish): do something more useful with unknown exceptions
- except Exception as ex:
- self._error(type(ex).__name__, str(ex))
- raise
-
- def post(self, controller_name):
- self.execute(controller_name)
-
- def _error(self, code, message):
- self._status_code = 400
- self.set_header('Content-Type', 'text/xml')
- self.write('<?xml version="1.0"?>\n')
- self.write('<Response><Errors><Error><Code>%s</Code>'
- '<Message>%s</Message></Error></Errors>'
- '<RequestID>?</RequestID></Response>' % (code, message))
- self.finish()
-
-
-class APIServerApplication(tornado.web.Application):
- def __init__(self, controllers):
- tornado.web.Application.__init__(self, [
- (r'/', RootRequestHandler),
- (r'/cloudpipe/(.*)', nova.cloudpipe.api.CloudPipeRequestHandler),
- (r'/cloudpipe', nova.cloudpipe.api.CloudPipeRequestHandler),
- (r'/services/([A-Za-z0-9]+)/', APIRequestHandler),
- (r'/latest/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2009-04-04/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2008-09-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2008-02-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-12-15/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-10-10/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-08-29/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-03-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/2007-01-19/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- (r'/1.0/([-A-Za-z0-9/]*)', MetadataRequestHandler),
- ], pool=multiprocessing.Pool(4))
- self.controllers = controllers
diff --git a/nova/flags.py b/nova/flags.py
index ed0baee65..c32cdd7a4 100644
--- a/nova/flags.py
+++ b/nova/flags.py
@@ -167,6 +167,9 @@ def DECLARE(name, module_string, flag_values=FLAGS):
# Define any app-specific flags in their own files, docs at:
# http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#39
+DEFINE_list('region_list',
+ [],
+ 'list of region=url pairs separated by commas')
DEFINE_string('connection_type', 'libvirt', 'libvirt, xenapi or fake')
DEFINE_integer('s3_port', 3333, 's3 port')
DEFINE_string('s3_host', '127.0.0.1', 's3 host')
@@ -185,6 +188,8 @@ DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')
DEFINE_string('control_exchange', 'nova', 'the main exchange to connect to')
+DEFINE_string('cc_host', '127.0.0.1', 'ip of api server')
+DEFINE_integer('cc_port', 8773, 'cloud controller port')
DEFINE_string('ec2_url', 'http://127.0.0.1:8773/services/Cloud',
'Url to ec2 api server')
diff --git a/nova/manager.py b/nova/manager.py
index e9aa50c56..56ba7d3f6 100644
--- a/nova/manager.py
+++ b/nova/manager.py
@@ -22,6 +22,7 @@ Base class for managers of different parts of the system
from nova import utils
from nova import flags
+from twisted.internet import defer
FLAGS = flags.FLAGS
flags.DEFINE_string('db_driver', 'nova.db.api',
@@ -37,3 +38,15 @@ class Manager(object):
if not db_driver:
db_driver = FLAGS.db_driver
self.db = utils.import_object(db_driver) # pylint: disable-msg=C0103
+
+ @defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield
+
+ def init_host(self):
+ """Do any initialization that needs to be run if this is a standalone service.
+
+ Child classes should override this method.
+ """
+ pass
diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py
index 41aeb5da7..37f9c8253 100644
--- a/nova/network/linux_net.py
+++ b/nova/network/linux_net.py
@@ -28,6 +28,11 @@ from nova import flags
from nova import utils
+def _bin_file(script):
+ """Return the absolute path to scipt in the bin directory"""
+ return os.path.abspath(os.path.join(__file__, "../../../bin", script))
+
+
FLAGS = flags.FLAGS
flags.DEFINE_string('dhcpbridge_flagfile',
'/etc/nova/nova-dhcpbridge.conf',
@@ -36,13 +41,36 @@ flags.DEFINE_string('dhcpbridge_flagfile',
flags.DEFINE_string('networks_path', utils.abspath('../networks'),
'Location to keep network config files')
flags.DEFINE_string('public_interface', 'vlan1',
- 'Interface for public IP addresses')
+ 'Interface for public IP addresses')
flags.DEFINE_string('bridge_dev', 'eth0',
'network device for bridges')
-
+flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'),
+ 'location of nova-dhcpbridge')
+flags.DEFINE_string('routing_source_ip', '127.0.0.1',
+ 'Public IP of network host')
+flags.DEFINE_bool('use_nova_chains', False,
+ 'use the nova_ routing chains instead of default')
DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)]
+def init_host():
+ """Basic networking setup goes here"""
+ # NOTE(devcamcar): Cloud public DNAT entries, CloudPipe port
+ # forwarding entries and a default DNAT entry.
+ _confirm_rule("PREROUTING", "-t nat -s 0.0.0.0/0 "
+ "-d 169.254.169.254/32 -p tcp -m tcp --dport 80 -j DNAT "
+ "--to-destination %s:%s" % (FLAGS.cc_host, FLAGS.cc_port))
+
+ # NOTE(devcamcar): Cloud public SNAT entries and the default
+ # SNAT rule for outbound traffic.
+ _confirm_rule("POSTROUTING", "-t nat -s %s "
+ "-j SNAT --to-source %s"
+ % (FLAGS.private_range, FLAGS.routing_source_ip))
+
+ _confirm_rule("POSTROUTING", "-t nat -s %s -j MASQUERADE" %
+ FLAGS.private_range)
+ _confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" %
+ {'range': FLAGS.private_range})
def bind_floating_ip(floating_ip):
"""Bind ip to public interface"""
@@ -58,37 +86,37 @@ def unbind_floating_ip(floating_ip):
def ensure_vlan_forward(public_ip, port, private_ip):
"""Sets up forwarding rules for vlan"""
- _confirm_rule("FORWARD -d %s -p udp --dport 1194 -j ACCEPT" % private_ip)
- _confirm_rule(
- "PREROUTING -t nat -d %s -p udp --dport %s -j DNAT --to %s:1194"
+ _confirm_rule("FORWARD", "-d %s -p udp --dport 1194 -j ACCEPT" %
+ private_ip)
+ _confirm_rule("PREROUTING",
+ "-t nat -d %s -p udp --dport %s -j DNAT --to %s:1194"
% (public_ip, port, private_ip))
def ensure_floating_forward(floating_ip, fixed_ip):
"""Ensure floating ip forwarding rule"""
- _confirm_rule("PREROUTING -t nat -d %s -j DNAT --to %s"
+ _confirm_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s"
% (floating_ip, fixed_ip))
- _confirm_rule("POSTROUTING -t nat -s %s -j SNAT --to %s"
+ _confirm_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s"
% (fixed_ip, floating_ip))
# TODO(joshua): Get these from the secgroup datastore entries
- _confirm_rule("FORWARD -d %s -p icmp -j ACCEPT"
+ _confirm_rule("FORWARD", "-d %s -p icmp -j ACCEPT"
% (fixed_ip))
for (protocol, port) in DEFAULT_PORTS:
- _confirm_rule(
- "FORWARD -d %s -p %s --dport %s -j ACCEPT"
+ _confirm_rule("FORWARD","-d %s -p %s --dport %s -j ACCEPT"
% (fixed_ip, protocol, port))
def remove_floating_forward(floating_ip, fixed_ip):
"""Remove forwarding for floating ip"""
- _remove_rule("PREROUTING -t nat -d %s -j DNAT --to %s"
+ _remove_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s"
% (floating_ip, fixed_ip))
- _remove_rule("POSTROUTING -t nat -s %s -j SNAT --to %s"
+ _remove_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s"
% (fixed_ip, floating_ip))
- _remove_rule("FORWARD -d %s -p icmp -j ACCEPT"
+ _remove_rule("FORWARD", "-d %s -p icmp -j ACCEPT"
% (fixed_ip))
for (protocol, port) in DEFAULT_PORTS:
- _remove_rule("FORWARD -d %s -p %s --dport %s -j ACCEPT"
+ _remove_rule("FORWARD", "-d %s -p %s --dport %s -j ACCEPT"
% (fixed_ip, protocol, port))
@@ -118,22 +146,24 @@ def ensure_bridge(bridge, interface, net_attrs=None):
# _execute("sudo brctl setageing %s 10" % bridge)
_execute("sudo brctl stp %s off" % bridge)
_execute("sudo brctl addif %s %s" % (bridge, interface))
- if net_attrs:
- _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \
- (bridge,
- net_attrs['gateway'],
- net_attrs['broadcast'],
- net_attrs['netmask']))
- _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge)
- else:
- _execute("sudo ifconfig %s up" % bridge)
+ if net_attrs:
+ _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \
+ (bridge,
+ net_attrs['gateway'],
+ net_attrs['broadcast'],
+ net_attrs['netmask']))
+ else:
+ _execute("sudo ifconfig %s up" % bridge)
+ _confirm_rule("FORWARD", "--in-interface %s -j ACCEPT" % bridge)
+ _confirm_rule("FORWARD", "--out-interface %s -j ACCEPT" % bridge)
def get_dhcp_hosts(context, network_id):
"""Get a string containing a network's hosts config in dnsmasq format"""
hosts = []
- for fixed_ip in db.network_get_associated_fixed_ips(context, network_id):
- hosts.append(_host_dhcp(fixed_ip['str_id']))
+ for fixed_ip_ref in db.network_get_associated_fixed_ips(context,
+ network_id):
+ hosts.append(_host_dhcp(fixed_ip_ref))
return '\n'.join(hosts)
@@ -149,9 +179,14 @@ def update_dhcp(context, network_id):
signal causing it to reload, otherwise spawn a new instance
"""
network_ref = db.network_get(context, network_id)
- with open(_dhcp_file(network_ref['vlan'], 'conf'), 'w') as f:
+
+ conffile = _dhcp_file(network_ref['vlan'], 'conf')
+ with open(conffile, 'w') as f:
f.write(get_dhcp_hosts(context, network_id))
+ # Make sure dnsmasq can actually read it (it setuid()s to "nobody")
+ os.chmod(conffile, 0644)
+
pid = _dnsmasq_pid_for(network_ref['vlan'])
# if dnsmasq is already running, then tell it to reload
@@ -159,7 +194,7 @@ def update_dhcp(context, network_id):
# TODO(ja): use "/proc/%d/cmdline" % (pid) to determine if pid refers
# correct dnsmasq process
try:
- os.kill(pid, signal.SIGHUP)
+ _execute('sudo kill -HUP %d' % pid)
return
except Exception as exc: # pylint: disable-msg=W0703
logging.debug("Hupping dnsmasq threw %s", exc)
@@ -171,12 +206,12 @@ def update_dhcp(context, network_id):
_execute(command, addl_env=env)
-def _host_dhcp(address):
+def _host_dhcp(fixed_ip_ref):
"""Return a host string for an address"""
- instance_ref = db.fixed_ip_get_instance(None, address)
+ instance_ref = fixed_ip_ref['instance']
return "%s,%s.novalocal,%s" % (instance_ref['mac_address'],
instance_ref['hostname'],
- address)
+ fixed_ip_ref['address'])
def _execute(cmd, *args, **kwargs):
@@ -194,15 +229,19 @@ def _device_exists(device):
return not err
-def _confirm_rule(cmd):
+def _confirm_rule(chain, cmd):
"""Delete and re-add iptables rule"""
- _execute("sudo iptables --delete %s" % (cmd), check_exit_code=False)
- _execute("sudo iptables -I %s" % (cmd))
+ if FLAGS.use_nova_chains:
+ chain = "nova_%s" % chain.lower()
+ _execute("sudo iptables --delete %s %s" % (chain, cmd), check_exit_code=False)
+ _execute("sudo iptables -I %s %s" % (chain, cmd))
-def _remove_rule(cmd):
+def _remove_rule(chain, cmd):
"""Remove iptables rule"""
- _execute("sudo iptables --delete %s" % (cmd))
+ if FLAGS.use_nova_chains:
+ chain = "%S" % chain.lower()
+ _execute("sudo iptables --delete %s %s" % (chain, cmd))
def _dnsmasq_cmd(net):
@@ -216,7 +255,7 @@ def _dnsmasq_cmd(net):
' --except-interface=lo',
' --dhcp-range=%s,static,120s' % net['dhcp_start'],
' --dhcp-hostsfile=%s' % _dhcp_file(net['vlan'], 'conf'),
- ' --dhcp-script=%s' % _bin_file('nova-dhcpbridge'),
+ ' --dhcp-script=%s' % FLAGS.dhcpbridge,
' --leasefile-ro']
return ''.join(cmd)
@@ -227,7 +266,7 @@ def _stop_dnsmasq(network):
if pid:
try:
- os.kill(pid, signal.SIGTERM)
+ _execute('sudo kill -TERM %d' % pid)
except Exception as exc: # pylint: disable-msg=W0703
logging.debug("Killing dnsmasq threw %s", exc)
@@ -235,12 +274,10 @@ def _stop_dnsmasq(network):
def _dhcp_file(vlan, kind):
"""Return path to a pid, leases or conf file for a vlan"""
- return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind))
-
+ if not os.path.exists(FLAGS.networks_path):
+ os.makedirs(FLAGS.networks_path)
-def _bin_file(script):
- """Return the absolute path to scipt in the bin directory"""
- return os.path.abspath(os.path.join(__file__, "../../../bin", script))
+ return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind))
def _dnsmasq_pid_for(vlan):
diff --git a/nova/network/manager.py b/nova/network/manager.py
index 191c1d364..9c1846dd9 100644
--- a/nova/network/manager.py
+++ b/nova/network/manager.py
@@ -20,10 +20,12 @@
Network Hosts are responsible for allocating ips and setting up network
"""
+import datetime
import logging
import math
import IPy
+from twisted.internet import defer
from nova import db
from nova import exception
@@ -62,7 +64,9 @@ flags.DEFINE_integer('cnt_vpn_clients', 5,
flags.DEFINE_string('network_driver', 'nova.network.linux_net',
'Driver to use for network creation')
flags.DEFINE_bool('update_dhcp_on_disassociate', False,
- 'Whether to update dhcp when fixed_ip is disassocated')
+ 'Whether to update dhcp when fixed_ip is disassociated')
+flags.DEFINE_integer('fixed_ip_disassociate_timeout', 600,
+ 'Seconds after which a deallocated ip is disassociated')
class AddressAlreadyAllocated(exception.Error):
@@ -81,6 +85,12 @@ class NetworkManager(manager.Manager):
self.driver = utils.import_object(network_driver)
super(NetworkManager, self).__init__(*args, **kwargs)
+ def init_host(self):
+ # Set up networking for the projects for which we're already
+ # the designated network host.
+ for network in self.db.host_get_networks(None, self.host):
+ self._on_set_network_host(None, network['id'])
+
def set_network_host(self, context, project_id):
"""Safely sets the host of the projects network"""
logging.debug("setting network host")
@@ -88,7 +98,7 @@ class NetworkManager(manager.Manager):
# TODO(vish): can we minimize db access by just getting the
# id here instead of the ref?
network_id = network_ref['id']
- host = self.db.network_set_host(context,
+ host = self.db.network_set_host(None,
network_id,
self.host)
self._on_set_network_host(context, network_id)
@@ -218,6 +228,27 @@ class FlatManager(NetworkManager):
class VlanManager(NetworkManager):
"""Vlan network with dhcp"""
+
+ @defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield super(VlanManager, self).periodic_tasks(context)
+ now = datetime.datetime.utcnow()
+ timeout = FLAGS.fixed_ip_disassociate_timeout
+ time = now - datetime.timedelta(seconds=timeout)
+ num = self.db.fixed_ip_disassociate_all_by_timeout(context,
+ self.host,
+ time)
+ if num:
+ logging.debug("Dissassociated %s stale fixed ip(s)", num)
+
+ def init_host(self):
+ """Do any initialization that needs to be run if this is a
+ standalone service.
+ """
+ super(VlanManager, self).init_host()
+ self.driver.init_host()
+
def allocate_fixed_ip(self, context, instance_id, *args, **kwargs):
"""Gets a fixed ip from the pool"""
network_ref = self.db.project_get_network(context, context.project.id)
@@ -225,7 +256,7 @@ class VlanManager(NetworkManager):
address = network_ref['vpn_private_address']
self.db.fixed_ip_associate(context, address, instance_id)
else:
- address = self.db.fixed_ip_associate_pool(context,
+ address = self.db.fixed_ip_associate_pool(None,
network_ref['id'],
instance_id)
self.db.fixed_ip_update(context, address, {'allocated': True})
@@ -235,14 +266,6 @@ class VlanManager(NetworkManager):
"""Returns a fixed ip to the pool"""
self.db.fixed_ip_update(context, address, {'allocated': False})
fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
- if not fixed_ip_ref['leased']:
- self.db.fixed_ip_disassociate(context, address)
- # NOTE(vish): dhcp server isn't updated until next setup, this
- # means there will stale entries in the conf file
- # the code below will update the file if necessary
- if FLAGS.update_dhcp_on_disassociate:
- network_ref = self.db.fixed_ip_get_network(context, address)
- self.driver.update_dhcp(context, network_ref['id'])
def setup_fixed_ip(self, context, address):
@@ -259,10 +282,7 @@ class VlanManager(NetworkManager):
"""Called by dhcp-bridge when ip is leased"""
logging.debug("Leasing IP %s", address)
fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
- if not fixed_ip_ref['allocated']:
- logging.warn("IP %s leased that was already deallocated", address)
- return
- instance_ref = self.db.fixed_ip_get_instance(context, address)
+ instance_ref = fixed_ip_ref['instance']
if not instance_ref:
raise exception.Error("IP %s leased that isn't associated" %
address)
@@ -270,24 +290,27 @@ class VlanManager(NetworkManager):
raise exception.Error("IP %s leased to bad mac %s vs %s" %
(address, instance_ref['mac_address'], mac))
self.db.fixed_ip_update(context,
- fixed_ip_ref['str_id'],
+ fixed_ip_ref['address'],
{'leased': True})
+ if not fixed_ip_ref['allocated']:
+ logging.warn("IP %s leased that was already deallocated", address)
def release_fixed_ip(self, context, mac, address):
"""Called by dhcp-bridge when ip is released"""
logging.debug("Releasing IP %s", address)
fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address)
- if not fixed_ip_ref['leased']:
- logging.warn("IP %s released that was not leased", address)
- return
- instance_ref = self.db.fixed_ip_get_instance(context, address)
+ instance_ref = fixed_ip_ref['instance']
if not instance_ref:
raise exception.Error("IP %s released that isn't associated" %
address)
if instance_ref['mac_address'] != mac:
raise exception.Error("IP %s released from bad mac %s vs %s" %
(address, instance_ref['mac_address'], mac))
- self.db.fixed_ip_update(context, address, {'leased': False})
+ if not fixed_ip_ref['leased']:
+ logging.warn("IP %s released that was not leased", address)
+ self.db.fixed_ip_update(context,
+ fixed_ip_ref['str_id'],
+ {'leased': False})
if not fixed_ip_ref['allocated']:
self.db.fixed_ip_disassociate(context, address)
# NOTE(vish): dhcp server isn't updated until next setup, this
@@ -343,7 +366,7 @@ class VlanManager(NetworkManager):
This could use a manage command instead of keying off of a flag"""
if not self.db.network_index_count(context):
for index in range(FLAGS.num_networks):
- self.db.network_index_create(context, {'index': index})
+ self.db.network_index_create_safe(context, {'index': index})
def _on_set_network_host(self, context, network_id):
"""Called when this host becomes the host for a project"""
@@ -351,6 +374,7 @@ class VlanManager(NetworkManager):
self.driver.ensure_vlan_bridge(network_ref['vlan'],
network_ref['bridge'],
network_ref)
+ self.driver.update_dhcp(context, network_id)
@property
def _bottom_reserved_ips(self):
diff --git a/nova/objectstore/__init__.py b/nova/objectstore/__init__.py
index b8890ac03..ecad9be7c 100644
--- a/nova/objectstore/__init__.py
+++ b/nova/objectstore/__init__.py
@@ -22,7 +22,7 @@
.. automodule:: nova.objectstore
:platform: Unix
- :synopsis: Currently a trivial file-based system, getting extended w/ mongo.
+ :synopsis: Currently a trivial file-based system, getting extended w/ swift.
.. moduleauthor:: Jesse Andrews <jesse@ansolabs.com>
.. moduleauthor:: Devin Carlen <devin.carlen@gmail.com>
.. moduleauthor:: Vishvananda Ishaya <vishvananda@yahoo.com>
diff --git a/nova/objectstore/handler.py b/nova/objectstore/handler.py
index 5c3dc286b..dfee64aca 100644
--- a/nova/objectstore/handler.py
+++ b/nova/objectstore/handler.py
@@ -55,7 +55,7 @@ from twisted.web import static
from nova import exception
from nova import flags
from nova.auth import manager
-from nova.endpoint import api
+from nova.api.ec2 import context
from nova.objectstore import bucket
from nova.objectstore import image
@@ -131,7 +131,7 @@ def get_context(request):
request.uri,
headers=request.getAllHeaders(),
check_type='s3')
- return api.APIRequestContext(None, user, project)
+ return context.APIRequestContext(user, project)
except exception.Error as ex:
logging.debug("Authentication Failure: %s", ex)
raise exception.NotAuthorized
@@ -352,6 +352,8 @@ class ImagesResource(resource.Resource):
m[u'imageType'] = m['type']
elif 'imageType' in m:
m[u'type'] = m['imageType']
+ if 'displayName' not in m:
+ m[u'displayName'] = u''
return m
request.write(json.dumps([decorate(i.metadata) for i in images]))
@@ -382,16 +384,25 @@ class ImagesResource(resource.Resource):
def render_POST(self, request): # pylint: disable-msg=R0201
"""Update image attributes: public/private"""
+ # image_id required for all requests
image_id = get_argument(request, 'image_id', u'')
- operation = get_argument(request, 'operation', u'')
-
image_object = image.Image(image_id)
-
if not image_object.is_authorized(request.context):
+ logging.debug("not authorized for render_POST in images")
raise exception.NotAuthorized
- image_object.set_public(operation=='add')
-
+ operation = get_argument(request, 'operation', u'')
+ if operation:
+ # operation implies publicity toggle
+ logging.debug("handling publicity toggle")
+ image_object.set_public(operation=='add')
+ else:
+ # other attributes imply update
+ logging.debug("update user fields")
+ clean_args = {}
+ for arg in request.args.keys():
+ clean_args[arg] = request.args[arg][0]
+ image_object.update_user_editable_fields(clean_args)
return ''
def render_DELETE(self, request): # pylint: disable-msg=R0201
diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py
index f3c02a425..def1b8167 100644
--- a/nova/objectstore/image.py
+++ b/nova/objectstore/image.py
@@ -82,6 +82,16 @@ class Image(object):
with open(os.path.join(self.path, 'info.json'), 'w') as f:
json.dump(md, f)
+ def update_user_editable_fields(self, args):
+ """args is from the request parameters, so requires extra cleaning"""
+ fields = {'display_name': 'displayName', 'description': 'description'}
+ info = self.metadata
+ for field in fields.keys():
+ if field in args:
+ info[fields[field]] = args[field]
+ with open(os.path.join(self.path, 'info.json'), 'w') as f:
+ json.dump(info, f)
+
@staticmethod
def all():
images = []
diff --git a/nova/quota.py b/nova/quota.py
index f0e51feeb..edbb83111 100644
--- a/nova/quota.py
+++ b/nova/quota.py
@@ -37,7 +37,7 @@ flags.DEFINE_integer('quota_gigabytes', 1000,
flags.DEFINE_integer('quota_floating_ips', 10,
'number of floating ips allowed per project')
-def _get_quota(context, project_id):
+def get_quota(context, project_id):
rval = {'instances': FLAGS.quota_instances,
'cores': FLAGS.quota_cores,
'volumes': FLAGS.quota_volumes,
@@ -57,7 +57,7 @@ def allowed_instances(context, num_instances, instance_type):
project_id = context.project.id
used_instances, used_cores = db.instance_data_get_for_project(context,
project_id)
- quota = _get_quota(context, project_id)
+ quota = get_quota(context, project_id)
allowed_instances = quota['instances'] - used_instances
allowed_cores = quota['cores'] - used_cores
type_cores = instance_types.INSTANCE_TYPES[instance_type]['vcpus']
@@ -72,9 +72,10 @@ def allowed_volumes(context, num_volumes, size):
project_id = context.project.id
used_volumes, used_gigabytes = db.volume_data_get_for_project(context,
project_id)
- quota = _get_quota(context, project_id)
+ quota = get_quota(context, project_id)
allowed_volumes = quota['volumes'] - used_volumes
allowed_gigabytes = quota['gigabytes'] - used_gigabytes
+ size = int(size)
num_gigabytes = num_volumes * size
allowed_volumes = min(allowed_volumes,
int(allowed_gigabytes // size))
@@ -85,7 +86,7 @@ def allowed_floating_ips(context, num_floating_ips):
"""Check quota and return min(num_floating_ips, allowed_floating_ips)"""
project_id = context.project.id
used_floating_ips = db.floating_ip_count_by_project(context, project_id)
- quota = _get_quota(context, project_id)
+ quota = get_quota(context, project_id)
allowed_floating_ips = quota['floating_ips'] - used_floating_ips
return min(num_floating_ips, allowed_floating_ips)
diff --git a/nova/rpc.py b/nova/rpc.py
index 84a9b5590..fe52ad35f 100644
--- a/nova/rpc.py
+++ b/nova/rpc.py
@@ -46,9 +46,9 @@ LOG.setLevel(logging.DEBUG)
class Connection(carrot_connection.BrokerConnection):
"""Connection instance object"""
@classmethod
- def instance(cls):
+ def instance(cls, new=False):
"""Returns the instance"""
- if not hasattr(cls, '_instance'):
+ if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port,
userid=FLAGS.rabbit_userid,
@@ -60,7 +60,10 @@ class Connection(carrot_connection.BrokerConnection):
# NOTE(vish): magic is fun!
# pylint: disable-msg=W0142
- cls._instance = cls(**params)
+ if new:
+ return cls(**params)
+ else:
+ cls._instance = cls(**params)
return cls._instance
@classmethod
@@ -81,21 +84,6 @@ class Consumer(messaging.Consumer):
self.failed_connection = False
super(Consumer, self).__init__(*args, **kwargs)
- # TODO(termie): it would be nice to give these some way of automatically
- # cleaning up after themselves
- def attach_to_tornado(self, io_inst=None):
- """Attach a callback to tornado that fires 10 times a second"""
- from tornado import ioloop
- if io_inst is None:
- io_inst = ioloop.IOLoop.instance()
-
- injected = ioloop.PeriodicCallback(
- lambda: self.fetch(enable_callbacks=True), 100, io_loop=io_inst)
- injected.start()
- return injected
-
- attachToTornado = attach_to_tornado
-
def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
"""Wraps the parent fetch with some logic for failed connections"""
# TODO(vish): the logic for failed connections and logging should be
@@ -123,6 +111,7 @@ class Consumer(messaging.Consumer):
"""Attach a callback to twisted that fires 10 times a second"""
loop = task.LoopingCall(self.fetch, enable_callbacks=True)
loop.start(interval=0.1)
+ return loop
class Publisher(messaging.Publisher):
@@ -265,6 +254,41 @@ def call(topic, msg):
msg.update({'_msg_id': msg_id})
LOG.debug("MSG_ID is %s" % (msg_id))
+ class WaitMessage(object):
+
+ def __call__(self, data, message):
+ """Acks message and sets result."""
+ message.ack()
+ if data['failure']:
+ self.result = RemoteError(*data['failure'])
+ else:
+ self.result = data['result']
+
+ wait_msg = WaitMessage()
+ conn = Connection.instance(True)
+ consumer = DirectConsumer(connection=conn, msg_id=msg_id)
+ consumer.register_callback(wait_msg)
+
+ conn = Connection.instance()
+ publisher = TopicPublisher(connection=conn, topic=topic)
+ publisher.send(msg)
+ publisher.close()
+
+ try:
+ consumer.wait(limit=1)
+ except StopIteration:
+ pass
+ consumer.close()
+ return wait_msg.result
+
+
+def call_twisted(topic, msg):
+ """Sends a message on a topic and wait for a response"""
+ LOG.debug("Making asynchronous call...")
+ msg_id = uuid.uuid4().hex
+ msg.update({'_msg_id': msg_id})
+ LOG.debug("MSG_ID is %s" % (msg_id))
+
conn = Connection.instance()
d = defer.Deferred()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
@@ -278,7 +302,7 @@ def call(topic, msg):
return d.callback(data['result'])
consumer.register_callback(deferred_receive)
- injected = consumer.attach_to_tornado()
+ injected = consumer.attach_to_twisted()
# clean up after the injected listened and return x
d.addCallback(lambda x: injected.stop() and x or x)
diff --git a/nova/scheduler/driver.py b/nova/scheduler/driver.py
index 2e6a5a835..c89d25a47 100644
--- a/nova/scheduler/driver.py
+++ b/nova/scheduler/driver.py
@@ -42,7 +42,8 @@ class Scheduler(object):
def service_is_up(service):
"""Check whether a service is up based on last heartbeat."""
last_heartbeat = service['updated_at'] or service['created_at']
- elapsed = datetime.datetime.now() - last_heartbeat
+ # Timestamps in DB are UTC.
+ elapsed = datetime.datetime.utcnow() - last_heartbeat
return elapsed < datetime.timedelta(seconds=FLAGS.service_down_time)
def hosts_up(self, context, topic):
diff --git a/nova/service.py b/nova/service.py
index 870dd6ceb..115e0ff32 100644
--- a/nova/service.py
+++ b/nova/service.py
@@ -37,7 +37,11 @@ from nova import utils
FLAGS = flags.FLAGS
flags.DEFINE_integer('report_interval', 10,
- 'seconds between nodes reporting state to cloud',
+ 'seconds between nodes reporting state to datastore',
+ lower_bound=1)
+
+flags.DEFINE_integer('periodic_interval', 60,
+ 'seconds between running periodic tasks',
lower_bound=1)
@@ -48,10 +52,17 @@ class Service(object, service.Service):
self.host = host
self.binary = binary
self.topic = topic
- manager_class = utils.import_class(manager)
- self.manager = manager_class(host=host, *args, **kwargs)
- self.model_disconnected = False
+ self.manager_class_name = manager
super(Service, self).__init__(*args, **kwargs)
+ self.saved_args, self.saved_kwargs = args, kwargs
+
+
+ def startService(self): # pylint: disable-msg C0103
+ manager_class = utils.import_class(self.manager_class_name)
+ self.manager = manager_class(host=self.host, *self.saved_args,
+ **self.saved_kwargs)
+ self.manager.init_host()
+ self.model_disconnected = False
try:
service_ref = db.service_get_by_args(None,
self.host,
@@ -80,7 +91,8 @@ class Service(object, service.Service):
binary=None,
topic=None,
manager=None,
- report_interval=None):
+ report_interval=None,
+ periodic_interval=None):
"""Instantiates class and passes back application object.
Args:
@@ -89,6 +101,7 @@ class Service(object, service.Service):
topic, defaults to bin_name - "nova-" part
manager, defaults to FLAGS.<topic>_manager
report_interval, defaults to FLAGS.report_interval
+ periodic_interval, defaults to FLAGS.periodic_interval
"""
if not host:
host = FLAGS.host
@@ -100,6 +113,8 @@ class Service(object, service.Service):
manager = FLAGS.get('%s_manager' % topic, None)
if not report_interval:
report_interval = FLAGS.report_interval
+ if not periodic_interval:
+ periodic_interval = FLAGS.periodic_interval
logging.warn("Starting %s node", topic)
service_obj = cls(host, binary, topic, manager)
conn = rpc.Connection.instance()
@@ -112,11 +127,14 @@ class Service(object, service.Service):
topic='%s.%s' % (topic, host),
proxy=service_obj)
+ consumer_all.attach_to_twisted()
+ consumer_node.attach_to_twisted()
+
pulse = task.LoopingCall(service_obj.report_state)
pulse.start(interval=report_interval, now=False)
- consumer_all.attach_to_twisted()
- consumer_node.attach_to_twisted()
+ pulse = task.LoopingCall(service_obj.periodic_tasks)
+ pulse.start(interval=periodic_interval, now=False)
# This is the parent service that twistd will be looking for when it
# parses this file, return it so that we can get it into globals.
@@ -132,6 +150,11 @@ class Service(object, service.Service):
logging.warn("Service killed that has no database entry")
@defer.inlineCallbacks
+ def periodic_tasks(self, context=None):
+ """Tasks to be run at a periodic interval"""
+ yield self.manager.periodic_tasks(context)
+
+ @defer.inlineCallbacks
def report_state(self, context=None):
"""Update the state of this service in the datastore."""
try:
diff --git a/nova/test.py b/nova/test.py
index c392c8a84..1f4b33272 100644
--- a/nova/test.py
+++ b/nova/test.py
@@ -33,6 +33,7 @@ from twisted.trial import unittest
from nova import fakerabbit
from nova import flags
+from nova import rpc
FLAGS = flags.FLAGS
@@ -62,19 +63,29 @@ class TrialTestCase(unittest.TestCase):
self.mox = mox.Mox()
self.stubs = stubout.StubOutForTesting()
self.flag_overrides = {}
+ self.injected = []
+ self._monkeyPatchAttach()
def tearDown(self): # pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
- super(TrialTestCase, self).tearDown()
self.reset_flags()
self.mox.UnsetStubs()
self.stubs.UnsetAll()
self.stubs.SmartUnsetAll()
self.mox.VerifyAll()
+
+ rpc.Consumer.attach_to_twisted = self.originalAttach
+ for x in self.injected:
+ try:
+ x.stop()
+ except AssertionError:
+ pass
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
+ super(TrialTestCase, self).tearDown()
+
def flags(self, **kw):
"""Override flag variables for a test"""
for k, v in kw.iteritems():
@@ -90,16 +101,51 @@ class TrialTestCase(unittest.TestCase):
for k, v in self.flag_overrides.iteritems():
setattr(FLAGS, k, v)
+ def run(self, result=None):
+ test_method = getattr(self, self._testMethodName)
+ setattr(self,
+ self._testMethodName,
+ self._maybeInlineCallbacks(test_method, result))
+ rv = super(TrialTestCase, self).run(result)
+ setattr(self, self._testMethodName, test_method)
+ return rv
+
+ def _maybeInlineCallbacks(self, func, result):
+ def _wrapped():
+ g = func()
+ if isinstance(g, defer.Deferred):
+ return g
+ if not hasattr(g, 'send'):
+ return defer.succeed(g)
+
+ inlined = defer.inlineCallbacks(func)
+ d = inlined()
+ return d
+ _wrapped.func_name = func.func_name
+ return _wrapped
+
+ def _monkeyPatchAttach(self):
+ self.originalAttach = rpc.Consumer.attach_to_twisted
+ def _wrapped(innerSelf):
+ rv = self.originalAttach(innerSelf)
+ self.injected.append(rv)
+ return rv
+
+ _wrapped.func_name = self.originalAttach.func_name
+ rpc.Consumer.attach_to_twisted = _wrapped
+
class BaseTestCase(TrialTestCase):
# TODO(jaypipes): Can this be moved into the TrialTestCase class?
- """Base test case class for all unit tests."""
+ """Base test case class for all unit tests.
+
+ DEPRECATED: This is being removed once Tornado is gone, use TrialTestCase.
+ """
def setUp(self): # pylint: disable-msg=C0103
"""Run before each test method to initialize test environment"""
super(BaseTestCase, self).setUp()
# TODO(termie): we could possibly keep a more global registry of
# the injected listeners... this is fine for now though
- self.injected = []
self.ioloop = ioloop.IOLoop.instance()
self._waiting = None
@@ -109,8 +155,6 @@ class BaseTestCase(TrialTestCase):
def tearDown(self):# pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
super(BaseTestCase, self).tearDown()
- for x in self.injected:
- x.stop()
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
diff --git a/nova/tests/access_unittest.py b/nova/tests/access_unittest.py
index 59e1683db..4b40ffd0a 100644
--- a/nova/tests/access_unittest.py
+++ b/nova/tests/access_unittest.py
@@ -18,19 +18,20 @@
import unittest
import logging
+import webob
from nova import exception
from nova import flags
from nova import test
+from nova.api import ec2
from nova.auth import manager
-from nova.auth import rbac
FLAGS = flags.FLAGS
class Context(object):
pass
-class AccessTestCase(test.BaseTestCase):
+class AccessTestCase(test.TrialTestCase):
def setUp(self):
super(AccessTestCase, self).setUp()
um = manager.AuthManager()
@@ -72,9 +73,17 @@ class AccessTestCase(test.BaseTestCase):
try:
self.project.add_role(self.testsys, 'sysadmin')
except: pass
- self.context = Context()
- self.context.project = self.project
#user is set in each test
+ def noopWSGIApp(environ, start_response):
+ start_response('200 OK', [])
+ return ['']
+ self.mw = ec2.Authorizer(noopWSGIApp)
+ self.mw.action_roles = {'str': {
+ '_allow_all': ['all'],
+ '_allow_none': [],
+ '_allow_project_manager': ['projectmanager'],
+ '_allow_sys_and_net': ['sysadmin', 'netadmin'],
+ '_allow_sysadmin': ['sysadmin']}}
def tearDown(self):
um = manager.AuthManager()
@@ -87,76 +96,46 @@ class AccessTestCase(test.BaseTestCase):
um.delete_user('testsys')
super(AccessTestCase, self).tearDown()
+ def response_status(self, user, methodName):
+ context = Context()
+ context.project = self.project
+ context.user = user
+ environ = {'ec2.context' : context,
+ 'ec2.controller': 'some string',
+ 'ec2.action': methodName}
+ req = webob.Request.blank('/', environ)
+ resp = req.get_response(self.mw)
+ return resp.status_int
+
+ def shouldAllow(self, user, methodName):
+ self.assertEqual(200, self.response_status(user, methodName))
+
+ def shouldDeny(self, user, methodName):
+ self.assertEqual(401, self.response_status(user, methodName))
+
def test_001_allow_all(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testpmsys
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testnet
- self.assertTrue(self._allow_all(self.context))
- self.context.user = self.testsys
- self.assertTrue(self._allow_all(self.context))
+ users = [self.testadmin, self.testpmsys, self.testnet, self.testsys]
+ for user in users:
+ self.shouldAllow(user, '_allow_all')
def test_002_allow_none(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_none(self.context))
- self.context.user = self.testpmsys
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
- self.context.user = self.testsys
- self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
+ self.shouldAllow(self.testadmin, '_allow_none')
+ users = [self.testpmsys, self.testnet, self.testsys]
+ for user in users:
+ self.shouldDeny(user, '_allow_none')
def test_003_allow_project_manager(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_project_manager(self.context))
- self.context.user = self.testpmsys
- self.assertTrue(self._allow_project_manager(self.context))
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
- self.context.user = self.testsys
- self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
+ for user in [self.testadmin, self.testpmsys]:
+ self.shouldAllow(user, '_allow_project_manager')
+ for user in [self.testnet, self.testsys]:
+ self.shouldDeny(user, '_allow_project_manager')
def test_004_allow_sys_and_net(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_sys_and_net(self.context))
- self.context.user = self.testpmsys # doesn't have the per project sysadmin
- self.assertRaises(exception.NotAuthorized, self._allow_sys_and_net, self.context)
- self.context.user = self.testnet
- self.assertTrue(self._allow_sys_and_net(self.context))
- self.context.user = self.testsys
- self.assertTrue(self._allow_sys_and_net(self.context))
-
- def test_005_allow_sys_no_pm(self):
- self.context.user = self.testadmin
- self.assertTrue(self._allow_sys_no_pm(self.context))
- self.context.user = self.testpmsys
- self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
- self.context.user = self.testnet
- self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
- self.context.user = self.testsys
- self.assertTrue(self._allow_sys_no_pm(self.context))
-
- @rbac.allow('all')
- def _allow_all(self, context):
- return True
-
- @rbac.allow('none')
- def _allow_none(self, context):
- return True
-
- @rbac.allow('projectmanager')
- def _allow_project_manager(self, context):
- return True
-
- @rbac.allow('sysadmin', 'netadmin')
- def _allow_sys_and_net(self, context):
- return True
-
- @rbac.allow('sysadmin')
- @rbac.deny('projectmanager')
- def _allow_sys_no_pm(self, context):
- return True
+ for user in [self.testadmin, self.testnet, self.testsys]:
+ self.shouldAllow(user, '_allow_sys_and_net')
+ # denied because it doesn't have the per project sysadmin
+ for user in [self.testpmsys]:
+ self.shouldDeny(user, '_allow_sys_and_net')
if __name__ == "__main__":
# TODO: Implement use_fake as an option
diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py
index 59c4adc3d..fc1ab9ae2 100644
--- a/nova/tests/api/__init__.py
+++ b/nova/tests/api/__init__.py
@@ -25,6 +25,7 @@ import stubout
import webob
import webob.dec
+import nova.exception
from nova import api
from nova.tests.api.test_helper import *
@@ -36,24 +37,46 @@ class Test(unittest.TestCase):
def tearDown(self): # pylint: disable-msg=C0103
self.stubs.UnsetAll()
+ def _request(self, url, subdomain, **kwargs):
+ environ_keys = {'HTTP_HOST': '%s.example.com' % subdomain}
+ environ_keys.update(kwargs)
+ req = webob.Request.blank(url, environ_keys)
+ return req.get_response(api.API())
+
def test_rackspace(self):
self.stubs.Set(api.rackspace, 'API', APIStub)
- result = webob.Request.blank('/v1.0/cloud').get_response(api.API())
+ result = self._request('/v1.0/cloud', 'rs')
self.assertEqual(result.body, "/cloud")
def test_ec2(self):
self.stubs.Set(api.ec2, 'API', APIStub)
- result = webob.Request.blank('/ec2/cloud').get_response(api.API())
+ result = self._request('/services/cloud', 'ec2')
self.assertEqual(result.body, "/cloud")
def test_not_found(self):
self.stubs.Set(api.ec2, 'API', APIStub)
self.stubs.Set(api.rackspace, 'API', APIStub)
- result = webob.Request.blank('/test/cloud').get_response(api.API())
+ result = self._request('/test/cloud', 'ec2')
self.assertNotEqual(result.body, "/cloud")
- def test_query_api_version(self):
- pass
+ def test_query_api_versions(self):
+ result = self._request('/', 'rs')
+ self.assertTrue('CURRENT' in result.body)
+
+ def test_metadata(self):
+ def go(url):
+ result = self._request(url, 'ec2',
+ REMOTE_ADDR='128.192.151.2')
+ # Each should get to the ORM layer and fail to find the IP
+ self.assertRaises(nova.exception.NotFound, go, '/latest/')
+ self.assertRaises(nova.exception.NotFound, go, '/2009-04-04/')
+ self.assertRaises(nova.exception.NotFound, go, '/1.0/')
+
+ def test_ec2_root(self):
+ result = self._request('/', 'ec2')
+ self.assertTrue('2007-12-15\n' in result.body)
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/nova/tests/api/rackspace/__init__.py b/nova/tests/api/rackspace/__init__.py
index e69de29bb..bfd0f87a7 100644
--- a/nova/tests/api/rackspace/__init__.py
+++ b/nova/tests/api/rackspace/__init__.py
@@ -0,0 +1,108 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from nova.api.rackspace import limited
+from nova.api.rackspace import RateLimitingMiddleware
+from nova.tests.api.test_helper import *
+from webob import Request
+
+
+class RateLimitingMiddlewareTest(unittest.TestCase):
+
+ def test_get_action_name(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ def verify(method, url, action_name):
+ req = Request.blank(url)
+ req.method = method
+ action = middleware.get_action_name(req)
+ self.assertEqual(action, action_name)
+ verify('PUT', '/servers/4', 'PUT')
+ verify('DELETE', '/servers/4', 'DELETE')
+ verify('POST', '/images/4', 'POST')
+ verify('POST', '/servers/4', 'POST servers')
+ verify('GET', '/foo?a=4&changes-since=never&b=5', 'GET changes-since')
+ verify('GET', '/foo?a=4&monkeys-since=never&b=5', None)
+ verify('GET', '/servers/4', None)
+ verify('HEAD', '/servers/4', None)
+
+ def exhaust(self, middleware, method, url, username, times):
+ req = Request.blank(url, dict(REQUEST_METHOD=method),
+ headers={'X-Auth-User': username})
+ for i in range(times):
+ resp = req.get_response(middleware)
+ self.assertEqual(resp.status_int, 200)
+ resp = req.get_response(middleware)
+ self.assertEqual(resp.status_int, 413)
+ self.assertTrue('Retry-After' in resp.headers)
+
+ def test_single_action(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.exhaust(middleware, 'DELETE', '/servers/4', 'usr1', 100)
+ self.exhaust(middleware, 'DELETE', '/servers/4', 'usr2', 100)
+
+ def test_POST_servers_action_implies_POST_action(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10)
+ self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10)
+ self.assertTrue(set(middleware.limiter._levels) ==
+ set(['usr1:POST', 'usr1:POST servers', 'usr2:POST']))
+
+ def test_POST_servers_action_correctly_ratelimited(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ # Use up all of our "POST" allowance for the minute, 5 times
+ for i in range(5):
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10)
+ # Reset the 'POST' action counter.
+ del middleware.limiter._levels['usr1:POST']
+ # All 50 daily "POST servers" actions should be all used up
+ self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 0)
+
+ def test_proxy_ctor_works(self):
+ middleware = RateLimitingMiddleware(APIStub())
+ self.assertEqual(middleware.limiter.__class__.__name__, "Limiter")
+ middleware = RateLimitingMiddleware(APIStub(), service_host='foobar')
+ self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy")
+
+
+class LimiterTest(unittest.TestCase):
+
+ def testLimiter(self):
+ items = range(2000)
+ req = Request.blank('/')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=0')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=3')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=2005')
+ self.assertEqual(limited(items, req), [])
+ req = Request.blank('/?limit=10')
+ self.assertEqual(limited(items, req), items[ :10])
+ req = Request.blank('/?limit=0')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?limit=3000')
+ self.assertEqual(limited(items, req), items[ :1000])
+ req = Request.blank('/?offset=1&limit=3')
+ self.assertEqual(limited(items, req), items[1:4])
+ req = Request.blank('/?offset=3&limit=0')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=3&limit=1500')
+ self.assertEqual(limited(items, req), items[3:1003])
+ req = Request.blank('/?offset=3000&limit=10')
+ self.assertEqual(limited(items, req), [])
diff --git a/nova/tests/api/rackspace/auth.py b/nova/tests/api/rackspace/auth.py
new file mode 100644
index 000000000..56677c2f4
--- /dev/null
+++ b/nova/tests/api/rackspace/auth.py
@@ -0,0 +1,108 @@
+import datetime
+import unittest
+
+import stubout
+import webob
+import webob.dec
+
+import nova.api
+import nova.api.rackspace.auth
+from nova import auth
+from nova.tests.api.rackspace import test_helper
+
+class Test(unittest.TestCase):
+ def setUp(self):
+ self.stubs = stubout.StubOutForTesting()
+ self.stubs.Set(nova.api.rackspace.auth.BasicApiAuthManager,
+ '__init__', test_helper.fake_auth_init)
+ test_helper.FakeAuthManager.auth_data = {}
+ test_helper.FakeAuthDatabase.data = {}
+ test_helper.stub_out_rate_limiting(self.stubs)
+ test_helper.stub_for_testing(self.stubs)
+
+ def tearDown(self):
+ self.stubs.UnsetAll()
+ test_helper.fake_data_store = {}
+
+ def test_authorize_user(self):
+ f = test_helper.FakeAuthManager()
+ f.add_user('derp', { 'uid': 1, 'name':'herp' } )
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '204 No Content')
+ self.assertEqual(len(result.headers['X-Auth-Token']), 40)
+ self.assertEqual(result.headers['X-CDN-Management-Url'],
+ "")
+ self.assertEqual(result.headers['X-Storage-Url'], "")
+
+ def test_authorize_token(self):
+ f = test_helper.FakeAuthManager()
+ f.add_user('derp', { 'uid': 1, 'name':'herp' } )
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '204 No Content')
+ self.assertEqual(len(result.headers['X-Auth-Token']), 40)
+ self.assertEqual(result.headers['X-Server-Management-Url'],
+ "https://foo/v1.0/")
+ self.assertEqual(result.headers['X-CDN-Management-Url'],
+ "")
+ self.assertEqual(result.headers['X-Storage-Url'], "")
+
+ token = result.headers['X-Auth-Token']
+ self.stubs.Set(nova.api.rackspace, 'APIRouter',
+ test_helper.FakeRouter)
+ req = webob.Request.blank('/v1.0/fake')
+ req.headers['X-Auth-Token'] = token
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '200 OK')
+ self.assertEqual(result.headers['X-Test-Success'], 'True')
+
+ def test_token_expiry(self):
+ self.destroy_called = False
+ token_hash = 'bacon'
+
+ def destroy_token_mock(meh, context, token):
+ self.destroy_called = True
+
+ def bad_token(meh, context, token_hash):
+ return { 'token_hash':token_hash,
+ 'created_at':datetime.datetime(1990, 1, 1) }
+
+ self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_destroy_token',
+ destroy_token_mock)
+
+ self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_get_token',
+ bad_token)
+
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-Token'] = 'bacon'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+ self.assertEqual(self.destroy_called, True)
+
+ def test_bad_user(self):
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-User'] = 'herp'
+ req.headers['X-Auth-Key'] = 'derp'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+ def test_no_user(self):
+ req = webob.Request.blank('/v1.0/')
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+ def test_bad_token(self):
+ req = webob.Request.blank('/v1.0/')
+ req.headers['X-Auth-Token'] = 'baconbaconbacon'
+ result = req.get_response(nova.api.API())
+ self.assertEqual(result.status, '401 Unauthorized')
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/rackspace/flavors.py b/nova/tests/api/rackspace/flavors.py
index fb8ba94a5..d25a2e2be 100644
--- a/nova/tests/api/rackspace/flavors.py
+++ b/nova/tests/api/rackspace/flavors.py
@@ -16,19 +16,31 @@
# under the License.
import unittest
+import stubout
+import nova.api
from nova.api.rackspace import flavors
+from nova.tests.api.rackspace import test_helper
from nova.tests.api.test_helper import *
class FlavorsTest(unittest.TestCase):
def setUp(self):
self.stubs = stubout.StubOutForTesting()
+ test_helper.FakeAuthManager.auth_data = {}
+ test_helper.FakeAuthDatabase.data = {}
+ test_helper.stub_for_testing(self.stubs)
+ test_helper.stub_out_rate_limiting(self.stubs)
+ test_helper.stub_out_auth(self.stubs)
def tearDown(self):
self.stubs.UnsetAll()
def test_get_flavor_list(self):
- pass
+ req = webob.Request.blank('/v1.0/flavors')
+ res = req.get_response(nova.api.API())
def test_get_flavor_by_id(self):
pass
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/nova/tests/api/rackspace/images.py b/nova/tests/api/rackspace/images.py
index 560d8c898..4c9987e8b 100644
--- a/nova/tests/api/rackspace/images.py
+++ b/nova/tests/api/rackspace/images.py
@@ -15,6 +15,7 @@
# License for the specific language governing permissions and limitations
# under the License.
+import stubout
import unittest
from nova.api.rackspace import images
diff --git a/nova/tests/api/rackspace/servers.py b/nova/tests/api/rackspace/servers.py
index 6d628e78a..69ad2c1d3 100644
--- a/nova/tests/api/rackspace/servers.py
+++ b/nova/tests/api/rackspace/servers.py
@@ -15,44 +15,231 @@
# License for the specific language governing permissions and limitations
# under the License.
+import json
import unittest
+import stubout
+
+from nova import db
+from nova import flags
+import nova.api.rackspace
from nova.api.rackspace import servers
+import nova.db.api
+from nova.db.sqlalchemy.models import Instance
+import nova.rpc
from nova.tests.api.test_helper import *
+from nova.tests.api.rackspace import test_helper
+
+FLAGS = flags.FLAGS
+
+def return_server(context, id):
+ return stub_instance(id)
+
+def return_servers(context, user_id=1):
+ return [stub_instance(i, user_id) for i in xrange(5)]
+
+
+def stub_instance(id, user_id=1):
+ return Instance(
+ id=id, state=0, image_id=10, server_name='server%s'%id,
+ user_id=user_id
+ )
class ServersTest(unittest.TestCase):
def setUp(self):
self.stubs = stubout.StubOutForTesting()
+ test_helper.FakeAuthManager.auth_data = {}
+ test_helper.FakeAuthDatabase.data = {}
+ test_helper.stub_for_testing(self.stubs)
+ test_helper.stub_out_rate_limiting(self.stubs)
+ test_helper.stub_out_auth(self.stubs)
+ test_helper.stub_out_id_translator(self.stubs)
+ test_helper.stub_out_key_pair_funcs(self.stubs)
+ test_helper.stub_out_image_service(self.stubs)
+ self.stubs.Set(nova.db.api, 'instance_get_all', return_servers)
+ self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server)
+ self.stubs.Set(nova.db.api, 'instance_get_all_by_user',
+ return_servers)
def tearDown(self):
self.stubs.UnsetAll()
+ def test_get_server_by_id(self):
+ req = webob.Request.blank('/v1.0/servers/1')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+ self.assertEqual(res_dict['server']['id'], '1')
+ self.assertEqual(res_dict['server']['name'], 'server1')
+
def test_get_server_list(self):
- pass
+ req = webob.Request.blank('/v1.0/servers')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+
+ i = 0
+ for s in res_dict['servers']:
+ self.assertEqual(s['id'], i)
+ self.assertEqual(s['name'], 'server%d'%i)
+ self.assertEqual(s.get('imageId', None), None)
+ i += 1
def test_create_instance(self):
- pass
+ def server_update(context, id, params):
+ pass
- def test_get_server_by_id(self):
- pass
+ def instance_create(context, inst):
+ class Foo(object):
+ ec2_id = 1
+ return Foo()
+
+ def fake_method(*args, **kwargs):
+ pass
+
+ def project_get_network(context, user_id):
+ return dict(id='1', host='localhost')
+
+ def queue_get_for(context, *args):
+ return 'network_topic'
+
+ self.stubs.Set(nova.db.api, 'project_get_network', project_get_network)
+ self.stubs.Set(nova.db.api, 'instance_create', instance_create)
+ self.stubs.Set(nova.rpc, 'cast', fake_method)
+ self.stubs.Set(nova.rpc, 'call', fake_method)
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+ self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for)
+ self.stubs.Set(nova.network.manager.FlatManager, 'allocate_fixed_ip',
+ fake_method)
+
+ test_helper.stub_out_id_translator(self.stubs)
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers')
+ req.method = 'POST'
+ req.body = json.dumps(body)
+
+ res = req.get_response(nova.api.API())
+
+ self.assertEqual(res.status_int, 200)
- def test_get_backup_schedule(self):
- pass
+ def test_update_no_body(self):
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status_int, 422)
- def test_get_server_details(self):
- pass
+ def test_update_bad_params(self):
+ """ Confirm that update is filtering params """
+ inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon')
+ self.body = json.dumps(dict(server=inst_dict))
- def test_get_server_ips(self):
- pass
+ def server_update(context, id, params):
+ self.update_called = True
+ filtered_dict = dict(name='server_test', admin_pass='bacon')
+ self.assertEqual(params, filtered_dict)
+
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ req.body = self.body
+ req.get_response(nova.api.API())
+
+ def test_update_server(self):
+ inst_dict = dict(name='server_test', adminPass='bacon')
+ self.body = json.dumps(dict(server=inst_dict))
+
+ def server_update(context, id, params):
+ filtered_dict = dict(name='server_test', admin_pass='bacon')
+ self.assertEqual(params, filtered_dict)
+
+ self.stubs.Set(nova.db.api, 'instance_update',
+ server_update)
+
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'PUT'
+ req.body = self.body
+ req.get_response(nova.api.API())
+
+ def test_create_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ req.method = 'POST'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_delete_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ req.method = 'DELETE'
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_get_server_backup_schedules(self):
+ req = webob.Request.blank('/v1.0/servers/1/backup_schedules')
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '404 Not Found')
+
+ def test_get_all_server_details(self):
+ req = webob.Request.blank('/v1.0/servers/detail')
+ res = req.get_response(nova.api.API())
+ res_dict = json.loads(res.body)
+
+ i = 0
+ for s in res_dict['servers']:
+ self.assertEqual(s['id'], i)
+ self.assertEqual(s['name'], 'server%d'%i)
+ self.assertEqual(s['imageId'], 10)
+ i += 1
def test_server_reboot(self):
- pass
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
def test_server_rebuild(self):
- pass
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
def test_server_resize(self):
- pass
+ body = dict(server=dict(
+ name='server_test', imageId=2, flavorId=2, metadata={},
+ personality = {}
+ ))
+ req = webob.Request.blank('/v1.0/servers/1/action')
+ req.method = 'POST'
+ req.content_type= 'application/json'
+ req.body = json.dumps(body)
+ res = req.get_response(nova.api.API())
def test_delete_server_instance(self):
- pass
+ req = webob.Request.blank('/v1.0/servers/1')
+ req.method = 'DELETE'
+
+ self.server_delete_called = False
+ def instance_destroy_mock(context, id):
+ self.server_delete_called = True
+
+ self.stubs.Set(nova.db.api, 'instance_destroy',
+ instance_destroy_mock)
+
+ res = req.get_response(nova.api.API())
+ self.assertEqual(res.status, '202 Accepted')
+ self.assertEqual(self.server_delete_called, True)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/nova/tests/api/rackspace/sharedipgroups.py b/nova/tests/api/rackspace/sharedipgroups.py
index b4b281db7..1906b54f5 100644
--- a/nova/tests/api/rackspace/sharedipgroups.py
+++ b/nova/tests/api/rackspace/sharedipgroups.py
@@ -15,6 +15,7 @@
# License for the specific language governing permissions and limitations
# under the License.
+import stubout
import unittest
from nova.api.rackspace import sharedipgroups
diff --git a/nova/tests/api/rackspace/test_helper.py b/nova/tests/api/rackspace/test_helper.py
new file mode 100644
index 000000000..2cf154f63
--- /dev/null
+++ b/nova/tests/api/rackspace/test_helper.py
@@ -0,0 +1,134 @@
+import datetime
+import json
+
+import webob
+import webob.dec
+
+from nova import auth
+from nova import utils
+from nova import flags
+import nova.api.rackspace.auth
+import nova.api.rackspace._id_translator
+from nova.image import service
+from nova.wsgi import Router
+
+FLAGS = flags.FLAGS
+
+class Context(object):
+ pass
+
+class FakeRouter(Router):
+ def __init__(self):
+ pass
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ res = webob.Response()
+ res.status = '200'
+ res.headers['X-Test-Success'] = 'True'
+ return res
+
+def fake_auth_init(self):
+ self.db = FakeAuthDatabase()
+ self.context = Context()
+ self.auth = FakeAuthManager()
+ self.host = 'foo'
+
+@webob.dec.wsgify
+def fake_wsgi(self, req):
+ req.environ['nova.context'] = dict(user=dict(id=1))
+ if req.body:
+ req.environ['inst_dict'] = json.loads(req.body)
+ return self.application
+
+def stub_out_key_pair_funcs(stubs):
+ def key_pair(context, user_id):
+ return [dict(name='key', public_key='public_key')]
+ stubs.Set(nova.db.api, 'key_pair_get_all_by_user',
+ key_pair)
+
+def stub_out_image_service(stubs):
+ def fake_image_show(meh, id):
+ return dict(kernelId=1, ramdiskId=1)
+
+ stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show)
+
+def stub_out_id_translator(stubs):
+ class FakeTranslator(object):
+ def __init__(self, id_type, service_name):
+ pass
+
+ def to_rs_id(self, id):
+ return id
+
+ def from_rs_id(self, id):
+ return id
+
+ stubs.Set(nova.api.rackspace._id_translator,
+ 'RackspaceAPIIdTranslator', FakeTranslator)
+
+def stub_out_auth(stubs):
+ def fake_auth_init(self, app):
+ self.application = app
+
+ stubs.Set(nova.api.rackspace.AuthMiddleware,
+ '__init__', fake_auth_init)
+ stubs.Set(nova.api.rackspace.AuthMiddleware,
+ '__call__', fake_wsgi)
+
+def stub_out_rate_limiting(stubs):
+ def fake_rate_init(self, app):
+ super(nova.api.rackspace.RateLimitingMiddleware, self).__init__(app)
+ self.application = app
+
+ stubs.Set(nova.api.rackspace.RateLimitingMiddleware,
+ '__init__', fake_rate_init)
+
+ stubs.Set(nova.api.rackspace.RateLimitingMiddleware,
+ '__call__', fake_wsgi)
+
+def stub_for_testing(stubs):
+ def get_my_ip():
+ return '127.0.0.1'
+ stubs.Set(nova.utils, 'get_my_ip', get_my_ip)
+ FLAGS.FAKE_subdomain = 'rs'
+
+class FakeAuthDatabase(object):
+ data = {}
+
+ @staticmethod
+ def auth_get_token(context, token_hash):
+ return FakeAuthDatabase.data.get(token_hash, None)
+
+ @staticmethod
+ def auth_create_token(context, token):
+ token['created_at'] = datetime.datetime.now()
+ FakeAuthDatabase.data[token['token_hash']] = token
+
+ @staticmethod
+ def auth_destroy_token(context, token):
+ if FakeAuthDatabase.data.has_key(token['token_hash']):
+ del FakeAuthDatabase.data['token_hash']
+
+class FakeAuthManager(object):
+ auth_data = {}
+
+ def add_user(self, key, user):
+ FakeAuthManager.auth_data[key] = user
+
+ def get_user(self, uid):
+ for k, v in FakeAuthManager.auth_data.iteritems():
+ if v['uid'] == uid:
+ return v
+ return None
+
+ def get_user_from_access_key(self, key):
+ return FakeAuthManager.auth_data.get(key, None)
+
+class FakeRateLimiter(object):
+ def __init__(self, application):
+ self.application = application
+
+ @webob.dec.wsgify
+ def __call__(self, req):
+ return self.application
diff --git a/nova/tests/api/rackspace/testfaults.py b/nova/tests/api/rackspace/testfaults.py
new file mode 100644
index 000000000..b2931bc98
--- /dev/null
+++ b/nova/tests/api/rackspace/testfaults.py
@@ -0,0 +1,40 @@
+import unittest
+import webob
+import webob.dec
+import webob.exc
+
+from nova.api.rackspace import faults
+
+class TestFaults(unittest.TestCase):
+
+ def test_fault_parts(self):
+ req = webob.Request.blank('/.xml')
+ f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram'))
+ resp = req.get_response(f)
+
+ first_two_words = resp.body.strip().split()[:2]
+ self.assertEqual(first_two_words, ['<badRequest', 'code="400">'])
+ body_without_spaces = ''.join(resp.body.split())
+ self.assertTrue('<message>scram</message>' in body_without_spaces)
+
+ def test_retry_header(self):
+ req = webob.Request.blank('/.xml')
+ exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry',
+ headers={'Retry-After': 4})
+ f = faults.Fault(exc)
+ resp = req.get_response(f)
+ first_two_words = resp.body.strip().split()[:2]
+ self.assertEqual(first_two_words, ['<overLimit', 'code="413">'])
+ body_sans_spaces = ''.join(resp.body.split())
+ self.assertTrue('<message>sorry</message>' in body_sans_spaces)
+ self.assertTrue('<retryAfter>4</retryAfter>' in body_sans_spaces)
+ self.assertEqual(resp.headers['Retry-After'], 4)
+
+ def test_raise(self):
+ @webob.dec.wsgify
+ def raiser(req):
+ raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?'))
+ req = webob.Request.blank('/.xml')
+ resp = req.get_response(raiser)
+ self.assertEqual(resp.status_int, 404)
+ self.assertTrue('whut?' in resp.body)
diff --git a/nova/tests/api/test_helper.py b/nova/tests/api/test_helper.py
index 8151a4af6..d0a2cc027 100644
--- a/nova/tests/api/test_helper.py
+++ b/nova/tests/api/test_helper.py
@@ -1,4 +1,5 @@
import webob.dec
+from nova import wsgi
class APIStub(object):
"""Class to verify request and mark it was called."""
diff --git a/nova/tests/api/wsgi_test.py b/nova/tests/api/wsgi_test.py
index 786dc1bce..9425b01d0 100644
--- a/nova/tests/api/wsgi_test.py
+++ b/nova/tests/api/wsgi_test.py
@@ -91,6 +91,57 @@ class Test(unittest.TestCase):
result = webob.Request.blank('/test/123').get_response(Router())
self.assertNotEqual(result.body, "123")
- def test_serializer(self):
- # TODO(eday): Placeholder for serializer testing.
- pass
+
+class SerializerTest(unittest.TestCase):
+
+ def match(self, url, accept, expect):
+ input_dict = dict(servers=dict(a=(2,3)))
+ expected_xml = '<servers><a>(2,3)</a></servers>'
+ expected_json = '{"servers":{"a":[2,3]}}'
+ req = webob.Request.blank(url, headers=dict(Accept=accept))
+ result = wsgi.Serializer(req.environ).to_content_type(input_dict)
+ result = result.replace('\n', '').replace(' ', '')
+ if expect == 'xml':
+ self.assertEqual(result, expected_xml)
+ elif expect == 'json':
+ self.assertEqual(result, expected_json)
+ else:
+ raise "Bad expect value"
+
+ def test_basic(self):
+ self.match('/servers/4.json', None, expect='json')
+ self.match('/servers/4', 'application/json', expect='json')
+ self.match('/servers/4', 'application/xml', expect='xml')
+ self.match('/servers/4.xml', None, expect='xml')
+
+ def test_defaults_to_json(self):
+ self.match('/servers/4', None, expect='json')
+ self.match('/servers/4', 'text/html', expect='json')
+
+ def test_suffix_takes_precedence_over_accept_header(self):
+ self.match('/servers/4.xml', 'application/json', expect='xml')
+ self.match('/servers/4.xml.', 'application/json', expect='json')
+
+ def test_deserialize(self):
+ xml = """
+ <a a1="1" a2="2">
+ <bs><b>1</b><b>2</b><b>3</b><b><c c1="1"/></b></bs>
+ <d><e>1</e></d>
+ <f>1</f>
+ </a>
+ """.strip()
+ as_dict = dict(a={
+ 'a1': '1',
+ 'a2': '2',
+ 'bs': ['1', '2', '3', {'c': dict(c1='1')}],
+ 'd': {'e': '1'},
+ 'f': '1'})
+ metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})}
+ serializer = wsgi.Serializer({}, metadata)
+ self.assertEqual(serializer.deserialize(xml), as_dict)
+
+ def test_deserialize_empty_xml(self):
+ xml = """<a></a>"""
+ as_dict = {"a": {}}
+ serializer = wsgi.Serializer({})
+ self.assertEqual(serializer.deserialize(xml), as_dict)
diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py
index 462d1b295..c040cdad3 100644
--- a/nova/tests/api_unittest.py
+++ b/nova/tests/api_unittest.py
@@ -23,60 +23,17 @@ from boto.ec2 import regioninfo
import httplib
import random
import StringIO
-from tornado import httpserver
-from twisted.internet import defer
+import webob
from nova import flags
from nova import test
+from nova import api
+from nova.api.ec2 import cloud
from nova.auth import manager
-from nova.endpoint import api
-from nova.endpoint import cloud
FLAGS = flags.FLAGS
-
-
-# NOTE(termie): These are a bunch of helper methods and classes to short
-# circuit boto calls and feed them into our tornado handlers,
-# it's pretty damn circuitous so apologies if you have to fix
-# a bug in it
-# NOTE(jaypipes) The pylint disables here are for R0913 (too many args) which
-# isn't controllable since boto's HTTPRequest needs that many
-# args, and for the version-differentiated import of tornado's
-# httputil.
-# NOTE(jaypipes): The disable-msg=E1101 and E1103 below is because pylint is
-# unable to introspect the deferred's return value properly
-
-def boto_to_tornado(method, path, headers, data, # pylint: disable-msg=R0913
- host, connection=None):
- """ translate boto requests into tornado requests
-
- connection should be a FakeTornadoHttpConnection instance
- """
- try:
- headers = httpserver.HTTPHeaders()
- except AttributeError:
- from tornado import httputil # pylint: disable-msg=E0611
- headers = httputil.HTTPHeaders()
- for k, v in headers.iteritems():
- headers[k] = v
-
- req = httpserver.HTTPRequest(method=method,
- uri=path,
- headers=headers,
- body=data,
- host=host,
- remote_ip='127.0.0.1',
- connection=connection)
- return req
-
-
-def raw_to_httpresponse(response_string):
- """translate a raw tornado http response into an httplib.HTTPResponse"""
- sock = FakeHttplibSocket(response_string)
- resp = httplib.HTTPResponse(sock)
- resp.begin()
- return resp
+FLAGS.FAKE_subdomain = 'ec2'
class FakeHttplibSocket(object):
@@ -89,85 +46,35 @@ class FakeHttplibSocket(object):
return self._buffer
-class FakeTornadoStream(object):
- """a fake stream to satisfy tornado's assumptions, trivial"""
- def set_close_callback(self, _func):
- """Dummy callback for stream"""
- pass
-
-
-class FakeTornadoConnection(object):
- """A fake connection object for tornado to pass to its handlers
-
- web requests are expected to write to this as they get data and call
- finish when they are done with the request, we buffer the writes and
- kick off a callback when it is done so that we can feed the result back
- into boto.
- """
- def __init__(self, deferred):
- self._deferred = deferred
- self._buffer = StringIO.StringIO()
-
- def write(self, chunk):
- """Writes a chunk of data to the internal buffer"""
- self._buffer.write(chunk)
-
- def finish(self):
- """Finalizes the connection and returns the buffered data via the
- deferred callback.
- """
- data = self._buffer.getvalue()
- self._deferred.callback(data)
-
- xheaders = None
-
- @property
- def stream(self): # pylint: disable-msg=R0201
- """Required property for interfacing with tornado"""
- return FakeTornadoStream()
-
-
class FakeHttplibConnection(object):
"""A fake httplib.HTTPConnection for boto to use
requests made via this connection actually get translated and routed into
- our tornado app, we then wait for the response and turn it back into
+ our WSGI app, we then wait for the response and turn it back into
the httplib.HTTPResponse that boto expects.
"""
def __init__(self, app, host, is_secure=False):
self.app = app
self.host = host
- self.deferred = defer.Deferred()
def request(self, method, path, data, headers):
- """Creates a connection to a fake tornado and sets
- up a deferred request with the supplied data and
- headers"""
- conn = FakeTornadoConnection(self.deferred)
- request = boto_to_tornado(connection=conn,
- method=method,
- path=path,
- headers=headers,
- data=data,
- host=self.host)
- self.app(request)
- self.deferred.addCallback(raw_to_httpresponse)
+ req = webob.Request.blank(path)
+ req.method = method
+ req.body = data
+ req.headers = headers
+ req.headers['Accept'] = 'text/html'
+ req.host = self.host
+ # Call the WSGI app, get the HTTP response
+ resp = str(req.get_response(self.app))
+ # For some reason, the response doesn't have "HTTP/1.0 " prepended; I
+ # guess that's a function the web server usually provides.
+ resp = "HTTP/1.0 %s" % resp
+ sock = FakeHttplibSocket(resp)
+ self.http_response = httplib.HTTPResponse(sock)
+ self.http_response.begin()
def getresponse(self):
- """A bit of deferred magic for catching the response
- from the previously deferred request"""
- @defer.inlineCallbacks
- def _waiter():
- """Callback that simply yields the deferred's
- return value."""
- result = yield self.deferred
- defer.returnValue(result)
- d = _waiter()
- # NOTE(termie): defer.returnValue above should ensure that
- # this deferred has already been called by the time
- # we get here, we are going to cheat and return
- # the result of the callback
- return d.result # pylint: disable-msg=E1101
+ return self.http_response
def close(self):
"""Required for compatibility with boto/tornado"""
@@ -180,17 +87,16 @@ class ApiEc2TestCase(test.BaseTestCase):
super(ApiEc2TestCase, self).setUp()
self.manager = manager.AuthManager()
- self.cloud = cloud.CloudController()
self.host = '127.0.0.1'
- self.app = api.APIServerApplication({'Cloud': self.cloud})
+ self.app = api.API()
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
is_secure=False,
region=regioninfo.RegionInfo(None, 'test', self.host),
- port=FLAGS.cc_port,
+ port=8773,
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
@@ -198,7 +104,7 @@ class ApiEc2TestCase(test.BaseTestCase):
def expect_http(self, host=None, is_secure=False):
"""Returns a new EC2 connection"""
http = FakeHttplibConnection(
- self.app, '%s:%d' % (self.host, FLAGS.cc_port), False)
+ self.app, '%s:8773' % (self.host), False)
# pylint: disable-msg=E1103
self.ec2.new_http_connection(host, is_secure).AndReturn(http)
return http
@@ -224,7 +130,8 @@ class ApiEc2TestCase(test.BaseTestCase):
for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
- self.manager.generate_key_pair(user.id, keyname)
+ # NOTE(vish): create depends on pool, so call helper directly
+ cloud._gen_key(None, user.id, keyname)
rv = self.ec2.get_all_key_pairs()
results = [k for k in rv if k.name == keyname]
diff --git a/nova/tests/auth_unittest.py b/nova/tests/auth_unittest.py
index b54e68274..99f7ab599 100644
--- a/nova/tests/auth_unittest.py
+++ b/nova/tests/auth_unittest.py
@@ -17,8 +17,6 @@
# under the License.
import logging
-from M2Crypto import BIO
-from M2Crypto import RSA
from M2Crypto import X509
import unittest
@@ -26,29 +24,76 @@ from nova import crypto
from nova import flags
from nova import test
from nova.auth import manager
-from nova.endpoint import cloud
+from nova.api.ec2 import cloud
FLAGS = flags.FLAGS
-
-class AuthTestCase(test.BaseTestCase):
+class user_generator(object):
+ def __init__(self, manager, **user_state):
+ if 'name' not in user_state:
+ user_state['name'] = 'test1'
+ self.manager = manager
+ self.user = manager.create_user(**user_state)
+
+ def __enter__(self):
+ return self.user
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_user(self.user)
+
+class project_generator(object):
+ def __init__(self, manager, **project_state):
+ if 'name' not in project_state:
+ project_state['name'] = 'testproj'
+ if 'manager_user' not in project_state:
+ project_state['manager_user'] = 'test1'
+ self.manager = manager
+ self.project = manager.create_project(**project_state)
+
+ def __enter__(self):
+ return self.project
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_project(self.project)
+
+class user_and_project_generator(object):
+ def __init__(self, manager, user_state={}, project_state={}):
+ self.manager = manager
+ if 'name' not in user_state:
+ user_state['name'] = 'test1'
+ if 'name' not in project_state:
+ project_state['name'] = 'testproj'
+ if 'manager_user' not in project_state:
+ project_state['manager_user'] = 'test1'
+ self.user = manager.create_user(**user_state)
+ self.project = manager.create_project(**project_state)
+
+ def __enter__(self):
+ return (self.user, self.project)
+
+ def __exit__(self, value, type, trace):
+ self.manager.delete_user(self.user)
+ self.manager.delete_project(self.project)
+
+class AuthManagerTestCase(object):
def setUp(self):
- super(AuthTestCase, self).setUp()
+ FLAGS.auth_driver = self.auth_driver
+ super(AuthManagerTestCase, self).setUp()
self.flags(connection_type='fake')
self.manager = manager.AuthManager()
- def test_001_can_create_users(self):
- self.manager.create_user('test1', 'access', 'secret')
- self.manager.create_user('test2')
-
- def test_002_can_get_user(self):
- user = self.manager.get_user('test1')
+ def test_create_and_find_user(self):
+ with user_generator(self.manager):
+ self.assert_(self.manager.get_user('test1'))
- def test_003_can_retreive_properties(self):
- user = self.manager.get_user('test1')
- self.assertEqual('test1', user.id)
- self.assertEqual('access', user.access)
- self.assertEqual('secret', user.secret)
+ def test_create_and_find_with_properties(self):
+ with user_generator(self.manager, name="herbert", secret="classified",
+ access="private-party"):
+ u = self.manager.get_user('herbert')
+ self.assertEqual('herbert', u.id)
+ self.assertEqual('herbert', u.name)
+ self.assertEqual('classified', u.secret)
+ self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? ))
@@ -65,156 +110,222 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n')
- def test_006_test_key_storage(self):
- user = self.manager.get_user('test1')
- user.create_key_pair('public', 'key', 'fingerprint')
- key = user.get_key_pair('public')
- self.assertEqual('key', key.public_key)
- self.assertEqual('fingerprint', key.fingerprint)
-
- def test_007_test_key_generation(self):
- user = self.manager.get_user('test1')
- private_key, fingerprint = user.generate_key_pair('public2')
- key = RSA.load_key_string(private_key, callback=lambda: None)
- bio = BIO.MemoryBuffer()
- public_key = user.get_key_pair('public2').public_key
- key.save_pub_key_bio(bio)
- converted = crypto.ssl_pub_to_ssh_pub(bio.read())
- # assert key fields are equal
- self.assertEqual(public_key.split(" ")[1].strip(),
- converted.split(" ")[1].strip())
-
- def test_008_can_list_key_pairs(self):
- keys = self.manager.get_user('test1').get_key_pairs()
- self.assertTrue(filter(lambda k: k.name == 'public', keys))
- self.assertTrue(filter(lambda k: k.name == 'public2', keys))
-
- def test_009_can_delete_key_pair(self):
- self.manager.get_user('test1').delete_key_pair('public')
- keys = self.manager.get_user('test1').get_key_pairs()
- self.assertFalse(filter(lambda k: k.name == 'public', keys))
-
- def test_010_can_list_users(self):
- users = self.manager.get_users()
- logging.warn(users)
- self.assertTrue(filter(lambda u: u.id == 'test1', users))
-
- def test_101_can_add_user_role(self):
- self.assertFalse(self.manager.has_role('test1', 'itsec'))
- self.manager.add_role('test1', 'itsec')
- self.assertTrue(self.manager.has_role('test1', 'itsec'))
-
- def test_199_can_remove_user_role(self):
- self.assertTrue(self.manager.has_role('test1', 'itsec'))
- self.manager.remove_role('test1', 'itsec')
- self.assertFalse(self.manager.has_role('test1', 'itsec'))
-
- def test_201_can_create_project(self):
- project = self.manager.create_project('testproj', 'test1', 'A test project', ['test1'])
- self.assertTrue(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
- self.assertEqual(project.name, 'testproj')
- self.assertEqual(project.description, 'A test project')
- self.assertEqual(project.project_manager_id, 'test1')
- self.assertTrue(project.has_member('test1'))
-
- def test_202_user1_is_project_member(self):
- self.assertTrue(self.manager.get_user('test1').is_project_member('testproj'))
-
- def test_203_user2_is_not_project_member(self):
- self.assertFalse(self.manager.get_user('test2').is_project_member('testproj'))
-
- def test_204_user1_is_project_manager(self):
- self.assertTrue(self.manager.get_user('test1').is_project_manager('testproj'))
-
- def test_205_user2_is_not_project_manager(self):
- self.assertFalse(self.manager.get_user('test2').is_project_manager('testproj'))
-
- def test_206_can_add_user_to_project(self):
- self.manager.add_to_project('test2', 'testproj')
- self.assertTrue(self.manager.get_project('testproj').has_member('test2'))
-
- def test_207_can_remove_user_from_project(self):
- self.manager.remove_from_project('test2', 'testproj')
- self.assertFalse(self.manager.get_project('testproj').has_member('test2'))
-
- def test_208_can_remove_add_user_with_role(self):
- self.manager.add_to_project('test2', 'testproj')
- self.manager.add_role('test2', 'developer', 'testproj')
- self.manager.remove_from_project('test2', 'testproj')
- self.assertFalse(self.manager.has_role('test2', 'developer', 'testproj'))
- self.manager.add_to_project('test2', 'testproj')
- self.manager.remove_from_project('test2', 'testproj')
-
- def test_209_can_generate_x509(self):
- # MUST HAVE RUN CLOUD SETUP BY NOW
- self.cloud = cloud.CloudController()
- self.cloud.setup()
- _key, cert_str = self.manager._generate_x509_cert('test1', 'testproj')
- logging.debug(cert_str)
-
- # Need to verify that it's signed by the right intermediate CA
- full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
- int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
- cloud_cert = crypto.fetch_ca()
- logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
- signed_cert = X509.load_cert_string(cert_str)
- chain_cert = X509.load_cert_string(full_chain)
- int_cert = X509.load_cert_string(int_cert)
- cloud_cert = X509.load_cert_string(cloud_cert)
- self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
- self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
-
- if not FLAGS.use_intermediate_ca:
- self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
- else:
- self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
-
- def test_210_can_add_project_role(self):
- project = self.manager.get_project('testproj')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- self.manager.add_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- project.add_role('test1', 'sysadmin')
- self.assertTrue(project.has_role('test1', 'sysadmin'))
-
- def test_211_can_list_project_roles(self):
- project = self.manager.get_project('testproj')
- user = self.manager.get_user('test1')
- self.manager.add_role(user, 'netadmin', project)
- roles = self.manager.get_user_roles(user)
- self.assertTrue('sysadmin' in roles)
- self.assertFalse('netadmin' in roles)
- project_roles = self.manager.get_user_roles(user, project)
- self.assertTrue('sysadmin' in project_roles)
- self.assertTrue('netadmin' in project_roles)
- # has role should be false because global role is missing
- self.assertFalse(self.manager.has_role(user, 'netadmin', project))
-
-
- def test_212_can_remove_project_role(self):
- project = self.manager.get_project('testproj')
- self.assertTrue(project.has_role('test1', 'sysadmin'))
- project.remove_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
- self.manager.remove_role('test1', 'sysadmin')
- self.assertFalse(project.has_role('test1', 'sysadmin'))
-
- def test_214_can_retrieve_project_by_user(self):
- project = self.manager.create_project('testproj2', 'test2', 'Another test project', ['test2'])
- self.assert_(len(self.manager.get_projects()) > 1)
- self.assertEqual(len(self.manager.get_projects('test2')), 1)
-
- def test_299_can_delete_project(self):
- self.manager.delete_project('testproj')
- self.assertFalse(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
- self.manager.delete_project('testproj2')
-
- def test_999_can_delete_users(self):
+ def test_can_list_users(self):
+ with user_generator(self.manager):
+ with user_generator(self.manager, name="test2"):
+ users = self.manager.get_users()
+ self.assert_(filter(lambda u: u.id == 'test1', users))
+ self.assert_(filter(lambda u: u.id == 'test2', users))
+ self.assert_(not filter(lambda u: u.id == 'test3', users))
+
+ def test_can_add_and_remove_user_role(self):
+ with user_generator(self.manager):
+ self.assertFalse(self.manager.has_role('test1', 'itsec'))
+ self.manager.add_role('test1', 'itsec')
+ self.assertTrue(self.manager.has_role('test1', 'itsec'))
+ self.manager.remove_role('test1', 'itsec')
+ self.assertFalse(self.manager.has_role('test1', 'itsec'))
+
+ def test_can_create_and_get_project(self):
+ with user_and_project_generator(self.manager) as (u,p):
+ self.assert_(self.manager.get_user('test1'))
+ self.assert_(self.manager.get_user('test1'))
+ self.assert_(self.manager.get_project('testproj'))
+
+ def test_can_list_projects(self):
+ with user_and_project_generator(self.manager):
+ with project_generator(self.manager, name="testproj2"):
+ projects = self.manager.get_projects()
+ self.assert_(filter(lambda p: p.name == 'testproj', projects))
+ self.assert_(filter(lambda p: p.name == 'testproj2', projects))
+ self.assert_(not filter(lambda p: p.name == 'testproj3',
+ projects))
+
+ def test_can_create_and_get_project_with_attributes(self):
+ with user_generator(self.manager):
+ with project_generator(self.manager, description='A test project'):
+ project = self.manager.get_project('testproj')
+ self.assertEqual('A test project', project.description)
+
+ def test_can_create_project_with_manager(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertEqual('test1', project.project_manager_id)
+ self.assertTrue(self.manager.is_project_manager(user, project))
+
+ def test_create_project_assigns_manager_to_members(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertTrue(self.manager.is_project_member(user, project))
+
+ def test_no_extra_project_members(self):
+ with user_generator(self.manager, name='test2') as baduser:
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.is_project_member(baduser,
+ project))
+
+ def test_no_extra_project_managers(self):
+ with user_generator(self.manager, name='test2') as baduser:
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.is_project_manager(baduser,
+ project))
+
+ def test_can_add_user_to_project(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertTrue(self.manager.is_project_member(user, project))
+
+ def test_can_remove_user_from_project(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertTrue(self.manager.is_project_member(user, project))
+ self.manager.remove_from_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertFalse(self.manager.is_project_member(user, project))
+
+ def test_can_add_remove_user_with_role(self):
+ with user_generator(self.manager, name='test2') as user:
+ with user_and_project_generator(self.manager) as (_user, project):
+ # NOTE(todd): after modifying users you must reload project
+ self.manager.add_to_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.manager.add_role(user, 'developer', project)
+ self.assertTrue(self.manager.is_project_member(user, project))
+ self.manager.remove_from_project(user, project)
+ project = self.manager.get_project('testproj')
+ self.assertFalse(self.manager.has_role(user, 'developer',
+ project))
+ self.assertFalse(self.manager.is_project_member(user, project))
+
+ def test_can_generate_x509(self):
+ # NOTE(todd): this doesn't assert against the auth manager
+ # so it probably belongs in crypto_unittest
+ # but I'm leaving it where I found it.
+ with user_and_project_generator(self.manager) as (user, project):
+ # NOTE(todd): Should mention why we must setup controller first
+ # (somebody please clue me in)
+ cloud_controller = cloud.CloudController()
+ cloud_controller.setup()
+ _key, cert_str = self.manager._generate_x509_cert('test1',
+ 'testproj')
+ logging.debug(cert_str)
+
+ # Need to verify that it's signed by the right intermediate CA
+ full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
+ int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
+ cloud_cert = crypto.fetch_ca()
+ logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
+ signed_cert = X509.load_cert_string(cert_str)
+ chain_cert = X509.load_cert_string(full_chain)
+ int_cert = X509.load_cert_string(int_cert)
+ cloud_cert = X509.load_cert_string(cloud_cert)
+ self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
+ self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
+ if not FLAGS.use_intermediate_ca:
+ self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
+ else:
+ self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
+
+ def test_adding_role_to_project_is_ignored_unless_added_to_user(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin', project)
+ # NOTE(todd): it will still show up in get_user_roles(u, project)
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin')
+ self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_add_user_role_doesnt_infect_project_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.add_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_can_list_user_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ roles = self.manager.get_user_roles(user)
+ self.assertTrue('sysadmin' in roles)
+ self.assertFalse('netadmin' in roles)
+
+ def test_can_list_project_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.manager.add_role(user, 'netadmin', project)
+ project_roles = self.manager.get_user_roles(user, project)
+ self.assertTrue('sysadmin' in project_roles)
+ self.assertTrue('netadmin' in project_roles)
+ # has role should be false user-level role is missing
+ self.assertFalse(self.manager.has_role(user, 'netadmin', project))
+
+ def test_can_remove_user_roles(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+ self.manager.remove_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin'))
+
+ def test_removing_user_role_hides_it_from_project(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
+ self.manager.remove_role(user, 'sysadmin')
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+
+ def test_can_remove_project_role_but_keep_user_role(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.manager.add_role(user, 'sysadmin')
+ self.manager.add_role(user, 'sysadmin', project)
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+ self.manager.remove_role(user, 'sysadmin', project)
+ self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
+ self.assertTrue(self.manager.has_role(user, 'sysadmin'))
+
+ def test_can_retrieve_project_by_user(self):
+ with user_and_project_generator(self.manager) as (user, project):
+ self.assertEqual(1, len(self.manager.get_projects('test1')))
+
+ def test_can_modify_project(self):
+ with user_and_project_generator(self.manager):
+ with user_generator(self.manager, name='test2'):
+ self.manager.modify_project('testproj', 'test2', 'new desc')
+ project = self.manager.get_project('testproj')
+ self.assertEqual('test2', project.project_manager_id)
+ self.assertEqual('new desc', project.description)
+
+ def test_can_delete_project(self):
+ with user_generator(self.manager):
+ self.manager.create_project('testproj', 'test1')
+ self.assert_(self.manager.get_project('testproj'))
+ self.manager.delete_project('testproj')
+ projectlist = self.manager.get_projects()
+ self.assert_(not filter(lambda p: p.name == 'testproj',
+ projectlist))
+
+ def test_can_delete_user(self):
+ self.manager.create_user('test1')
+ self.assert_(self.manager.get_user('test1'))
self.manager.delete_user('test1')
- users = self.manager.get_users()
- self.assertFalse(filter(lambda u: u.id == 'test1', users))
- self.manager.delete_user('test2')
- self.assertEqual(self.manager.get_user('test2'), None)
+ userlist = self.manager.get_users()
+ self.assert_(not filter(lambda u: u.id == 'test1', userlist))
+
+ def test_can_modify_users(self):
+ with user_generator(self.manager):
+ self.manager.modify_user('test1', 'access', 'secret', True)
+ user = self.manager.get_user('test1')
+ self.assertEqual('access', user.access)
+ self.assertEqual('secret', user.secret)
+ self.assertTrue(user.is_admin())
+
+class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
+ auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
+
+class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
+ auth_driver = 'nova.auth.dbdriver.DbDriver'
if __name__ == "__main__":
diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py
index c36d5a34f..ae7dea1db 100644
--- a/nova/tests/cloud_unittest.py
+++ b/nova/tests/cloud_unittest.py
@@ -16,31 +16,45 @@
# License for the specific language governing permissions and limitations
# under the License.
+import json
import logging
+from M2Crypto import BIO
+from M2Crypto import RSA
+import os
import StringIO
+import tempfile
import time
-from tornado import ioloop
+
from twisted.internet import defer
import unittest
from xml.etree import ElementTree
+from nova import crypto
+from nova import db
from nova import flags
from nova import rpc
from nova import test
from nova import utils
from nova.auth import manager
from nova.compute import power_state
-from nova.endpoint import api
-from nova.endpoint import cloud
+from nova.api.ec2 import context
+from nova.api.ec2 import cloud
+from nova.objectstore import image
FLAGS = flags.FLAGS
-class CloudTestCase(test.BaseTestCase):
+# Temp dirs for working with image attributes through the cloud controller
+# (stole this from objectstore_unittest.py)
+OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-')
+IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images')
+os.makedirs(IMAGES_PATH)
+
+class CloudTestCase(test.TrialTestCase):
def setUp(self):
super(CloudTestCase, self).setUp()
- self.flags(connection_type='fake')
+ self.flags(connection_type='fake', images_path=IMAGES_PATH)
self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG)
@@ -51,20 +65,24 @@ class CloudTestCase(test.BaseTestCase):
# set up a service
self.compute = utils.import_class(FLAGS.compute_manager)
self.compute_consumer = rpc.AdapterConsumer(connection=self.conn,
- topic=FLAGS.compute_topic,
- proxy=self.compute)
- self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
+ topic=FLAGS.compute_topic,
+ proxy=self.compute)
+ self.compute_consumer.attach_to_twisted()
- try:
- manager.AuthManager().create_user('admin', 'admin', 'admin')
- except: pass
- admin = manager.AuthManager().get_user('admin')
- project = manager.AuthManager().create_project('proj', 'admin', 'proj')
- self.context = api.APIRequestContext(handler=None,project=project,user=admin)
+ self.manager = manager.AuthManager()
+ self.user = self.manager.create_user('admin', 'admin', 'admin', True)
+ self.project = self.manager.create_project('proj', 'admin', 'proj')
+ self.context = context.APIRequestContext(user=self.user,
+ project=self.project)
def tearDown(self):
- manager.AuthManager().delete_project('proj')
- manager.AuthManager().delete_user('admin')
+ self.manager.delete_project(self.project)
+ self.manager.delete_user(self.user)
+ super(CloudTestCase, self).tearDown()
+
+ def _create_key(self, name):
+ # NOTE(vish): create depends on pool, so just call helper directly
+ return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self):
if FLAGS.connection_type == 'fake':
@@ -77,6 +95,33 @@ class CloudTestCase(test.BaseTestCase):
self.assert_(output)
rv = yield self.compute.terminate_instance(instance_id)
+
+ def test_key_generation(self):
+ result = self._create_key('test')
+ private_key = result['private_key']
+ key = RSA.load_key_string(private_key, callback=lambda: None)
+ bio = BIO.MemoryBuffer()
+ public_key = db.key_pair_get(self.context,
+ self.context.user.id,
+ 'test')['public_key']
+ key.save_pub_key_bio(bio)
+ converted = crypto.ssl_pub_to_ssh_pub(bio.read())
+ # assert key fields are equal
+ self.assertEqual(public_key.split(" ")[1].strip(),
+ converted.split(" ")[1].strip())
+
+ def test_describe_key_pairs(self):
+ self._create_key('test1')
+ self._create_key('test2')
+ result = self.cloud.describe_key_pairs(self.context)
+ keys = result["keypairsSet"]
+ self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
+ self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
+
+ def test_delete_key_pair(self):
+ self._create_key('test')
+ self.cloud.delete_key_pair(self.context, 'test')
+
def test_run_instances(self):
if FLAGS.connection_type == 'fake':
logging.debug("Can't test instances without a real virtual env.")
@@ -156,3 +201,67 @@ class CloudTestCase(test.BaseTestCase):
#for i in xrange(4):
# data = self.cloud.get_metadata(instance(i)['private_dns_name'])
# self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i)
+
+ @staticmethod
+ def _fake_set_image_description(ctxt, image_id, description):
+ from nova.objectstore import handler
+ class req:
+ pass
+ request = req()
+ request.context = ctxt
+ request.args = {'image_id': [image_id],
+ 'description': [description]}
+
+ resource = handler.ImagesResource()
+ resource.render_POST(request)
+
+ def test_user_editable_image_endpoint(self):
+ pathdir = os.path.join(FLAGS.images_path, 'ami-testing')
+ os.mkdir(pathdir)
+ info = {'isPublic': False}
+ with open(os.path.join(pathdir, 'info.json'), 'w') as f:
+ json.dump(info, f)
+ img = image.Image('ami-testing')
+ # self.cloud.set_image_description(self.context, 'ami-testing',
+ # 'Foo Img')
+ # NOTE(vish): Above won't work unless we start objectstore or create
+ # a fake version of api/ec2/images.py conn that can
+ # call methods directly instead of going through boto.
+ # for now, just cheat and call the method directly
+ self._fake_set_image_description(self.context, 'ami-testing',
+ 'Foo Img')
+ self.assertEqual('Foo Img', img.metadata['description'])
+ self._fake_set_image_description(self.context, 'ami-testing', '')
+ self.assertEqual('', img.metadata['description'])
+
+ def test_update_of_instance_display_fields(self):
+ inst = db.instance_create({}, {})
+ self.cloud.update_instance(self.context, inst['ec2_id'],
+ display_name='c00l 1m4g3')
+ inst = db.instance_get({}, inst['id'])
+ self.assertEqual('c00l 1m4g3', inst['display_name'])
+ db.instance_destroy({}, inst['id'])
+
+ def test_update_of_instance_wont_update_private_fields(self):
+ inst = db.instance_create({}, {})
+ self.cloud.update_instance(self.context, inst['id'],
+ mac_address='DE:AD:BE:EF')
+ inst = db.instance_get({}, inst['id'])
+ self.assertEqual(None, inst['mac_address'])
+ db.instance_destroy({}, inst['id'])
+
+ def test_update_of_volume_display_fields(self):
+ vol = db.volume_create({}, {})
+ self.cloud.update_volume(self.context, vol['id'],
+ display_name='c00l v0lum3')
+ vol = db.volume_get({}, vol['id'])
+ self.assertEqual('c00l v0lum3', vol['display_name'])
+ db.volume_destroy({}, vol['id'])
+
+ def test_update_of_volume_wont_update_private_fields(self):
+ vol = db.volume_create({}, {})
+ self.cloud.update_volume(self.context, vol['id'],
+ mountpoint='/not/here')
+ vol = db.volume_get({}, vol['id'])
+ self.assertEqual(None, vol['mountpoint'])
+ db.volume_destroy({}, vol['id'])
diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py
index f5c0f1c09..1e2bb113b 100644
--- a/nova/tests/compute_unittest.py
+++ b/nova/tests/compute_unittest.py
@@ -30,7 +30,7 @@ from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
-
+from nova.api import context
FLAGS = flags.FLAGS
@@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase):
self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow()
yield self.compute.terminate_instance(self.context, instance_id)
- instance_ref = db.instance_get({'deleted': True}, instance_id)
+ self.context = context.get_admin_context(user=self.user,
+ read_deleted=True)
+ instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] < terminate)
self.assert_(instance_ref['deleted_at'] > terminate)
diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py
index 8f4754650..4bbef8832 100644
--- a/nova/tests/fake_flags.py
+++ b/nova/tests/fake_flags.py
@@ -24,7 +24,7 @@ flags.DECLARE('volume_driver', 'nova.volume.manager')
FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver'
FLAGS.connection_type = 'fake'
FLAGS.fake_rabbit = True
-FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
+FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('fake_network', 'nova.network.manager')
diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py
index dc5277f02..59b0a36e4 100644
--- a/nova/tests/network_unittest.py
+++ b/nova/tests/network_unittest.py
@@ -28,7 +28,7 @@ from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
-from nova.endpoint import api
+from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -49,19 +49,19 @@ class NetworkTestCase(test.TrialTestCase):
self.user = self.manager.create_user('netuser', 'netuser', 'netuser')
self.projects = []
self.network = utils.import_object(FLAGS.network_manager)
- self.context = api.APIRequestContext(None, project=None, user=self.user)
+ self.context = context.APIRequestContext(project=None, user=self.user)
for i in range(5):
name = 'project%s' % i
self.projects.append(self.manager.create_project(name,
'netuser',
name))
# create the necessary network data for the project
- self.network.set_network_host(self.context, self.projects[i].id)
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ user_context = context.get_admin_context(user=self.user)
+
+ self.network.set_network_host(user_context, self.projects[i].id)
+ instance_ref = self._create_instance(0)
self.instance_id = instance_ref['id']
- instance_ref = db.instance_create(None,
- {'mac_address': utils.generate_mac()})
+ instance_ref = self._create_instance(1)
self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
@@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase):
self.manager.delete_project(project)
self.manager.delete_user(self.user)
+ def _create_instance(self, project_num, mac=None):
+ if not mac:
+ mac = utils.generate_mac()
+ project = self.projects[project_num]
+ self.context.project = project
+ return db.instance_create(self.context,
+ {'project_id': project.id,
+ 'mac_address': mac})
+
def _create_address(self, project_num, instance_id=None):
"""Create an address in given project num"""
if instance_id is None:
@@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase):
self.context.project = self.projects[project_num]
return self.network.allocate_fixed_ip(self.context, instance_id)
+ def _deallocate_address(self, project_num, address):
+ self.context.project = self.projects[project_num]
+ self.network.deallocate_fixed_ip(self.context, address)
+
+
def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
# TODO(vish): better way of adding floating ips
+ self.context.project = self.projects[0]
pubnet = IPy.IP(flags.FLAGS.public_range)
address = str(pubnet[0])
try:
@@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase):
address = self._create_address(0)
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
lease_ip(address)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
@@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(address)
lease_ip(address2)
- self.network.deallocate_fixed_ip(self.context, address)
+ self._deallocate_address(0, address)
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
- self.network.deallocate_fixed_ip(self.context, address2)
+ self._deallocate_address(1, address2)
release_ip(address2)
self.assertFalse(is_allocated_in_project(address2,
self.projects[1].id))
@@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(first)
instance_ids = []
for i in range(1, 5):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address2 = self._create_address(i, instance_ref['id'])
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address3 = self._create_address(i, instance_ref['id'])
lease_ip(address)
lease_ip(address2)
lease_ip(address3)
+ self.context.project = self.projects[i]
self.assertFalse(is_allocated_in_project(address,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address2,
@@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase):
for instance_id in instance_ids:
db.instance_destroy(None, instance_id)
release_ip(first)
- self.network.deallocate_fixed_ip(self.context, first)
+ self._deallocate_address(0, first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
@@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase):
addresses = []
instance_ids = []
for i in range(num_available_ips):
- mac = utils.generate_mac()
- instance_ref = db.instance_create(None,
- {'mac_address': mac})
+ instance_ref = self._create_instance(0)
instance_ids.append(instance_ref['id'])
address = self._create_address(0, instance_ref['id'])
addresses.append(address)
diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py
index dece4b5d5..5a599ff3a 100644
--- a/nova/tests/objectstore_unittest.py
+++ b/nova/tests/objectstore_unittest.py
@@ -53,7 +53,7 @@ os.makedirs(os.path.join(OSS_TEMPDIR, 'images'))
os.makedirs(os.path.join(OSS_TEMPDIR, 'buckets'))
-class ObjectStoreTestCase(test.BaseTestCase):
+class ObjectStoreTestCase(test.TrialTestCase):
"""Test objectstore API directly."""
def setUp(self): # pylint: disable-msg=C0103
@@ -164,6 +164,12 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.context.project = self.auth_manager.get_project('proj2')
self.assertFalse(my_img.is_authorized(self.context))
+ # change user-editable fields
+ my_img.update_user_editable_fields({'display_name': 'my cool image'})
+ self.assertEqual('my cool image', my_img.metadata['displayName'])
+ my_img.update_user_editable_fields({'display_name': ''})
+ self.assert_(not my_img.metadata['displayName'])
+
class TestHTTPChannel(http.HTTPChannel):
"""Dummy site required for twisted.web"""
diff --git a/nova/tests/quota_unittest.py b/nova/tests/quota_unittest.py
index cab9f663d..370ccd506 100644
--- a/nova/tests/quota_unittest.py
+++ b/nova/tests/quota_unittest.py
@@ -25,8 +25,8 @@ from nova import quota
from nova import test
from nova import utils
from nova.auth import manager
-from nova.endpoint import cloud
-from nova.endpoint import api
+from nova.api.ec2 import cloud
+from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -48,9 +48,8 @@ class QuotaTestCase(test.TrialTestCase):
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('admin', 'admin', 'admin')
self.network = utils.import_object(FLAGS.network_manager)
- self.context = api.APIRequestContext(handler=None,
- project=self.project,
- user=self.user)
+ self.context = context.APIRequestContext(project=self.project,
+ user=self.user)
def tearDown(self): # pylint: disable-msg=C0103
manager.AuthManager().delete_project(self.project)
@@ -95,11 +94,11 @@ class QuotaTestCase(test.TrialTestCase):
for i in range(FLAGS.quota_instances):
instance_id = self._create_instance()
instance_ids.append(instance_id)
- self.assertFailure(self.cloud.run_instances(self.context,
- min_count=1,
- max_count=1,
- instance_type='m1.small'),
- cloud.QuotaError)
+ self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
+ self.context,
+ min_count=1,
+ max_count=1,
+ instance_type='m1.small')
for instance_id in instance_ids:
db.instance_destroy(self.context, instance_id)
@@ -107,11 +106,11 @@ class QuotaTestCase(test.TrialTestCase):
instance_ids = []
instance_id = self._create_instance(cores=4)
instance_ids.append(instance_id)
- self.assertFailure(self.cloud.run_instances(self.context,
- min_count=1,
- max_count=1,
- instance_type='m1.small'),
- cloud.QuotaError)
+ self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
+ self.context,
+ min_count=1,
+ max_count=1,
+ instance_type='m1.small')
for instance_id in instance_ids:
db.instance_destroy(self.context, instance_id)
@@ -120,10 +119,9 @@ class QuotaTestCase(test.TrialTestCase):
for i in range(FLAGS.quota_volumes):
volume_id = self._create_volume()
volume_ids.append(volume_id)
- self.assertRaises(cloud.QuotaError,
- self.cloud.create_volume,
- self.context,
- size=10)
+ self.assertRaises(cloud.QuotaError, self.cloud.create_volume,
+ self.context,
+ size=10)
for volume_id in volume_ids:
db.volume_destroy(self.context, volume_id)
@@ -151,5 +149,4 @@ class QuotaTestCase(test.TrialTestCase):
# make an rpc.call, the test just finishes with OK. It
# appears to be something in the magic inline callbacks
# that is breaking.
- self.assertFailure(self.cloud.allocate_address(self.context),
- cloud.QuotaError)
+ self.assertRaises(cloud.QuotaError, self.cloud.allocate_address, self.context)
diff --git a/nova/tests/rpc_unittest.py b/nova/tests/rpc_unittest.py
index e12a28fbc..9652841f2 100644
--- a/nova/tests/rpc_unittest.py
+++ b/nova/tests/rpc_unittest.py
@@ -30,7 +30,7 @@ from nova import test
FLAGS = flags.FLAGS
-class RpcTestCase(test.BaseTestCase):
+class RpcTestCase(test.TrialTestCase):
"""Test cases for rpc"""
def setUp(self): # pylint: disable-msg=C0103
super(RpcTestCase, self).setUp()
@@ -39,14 +39,13 @@ class RpcTestCase(test.BaseTestCase):
self.consumer = rpc.AdapterConsumer(connection=self.conn,
topic='test',
proxy=self.receiver)
-
- self.injected.append(self.consumer.attach_to_tornado(self.ioloop))
+ self.consumer.attach_to_twisted()
def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42
- result = yield rpc.call('test', {"method": "echo",
- "args": {"value": value}})
+ result = yield rpc.call_twisted('test', {"method": "echo",
+ "args": {"value": value}})
self.assertEqual(value, result)
def test_call_exception(self):
@@ -57,12 +56,12 @@ class RpcTestCase(test.BaseTestCase):
to an int in the test.
"""
value = 42
- self.assertFailure(rpc.call('test', {"method": "fail",
- "args": {"value": value}}),
+ self.assertFailure(rpc.call_twisted('test', {"method": "fail",
+ "args": {"value": value}}),
rpc.RemoteError)
try:
- yield rpc.call('test', {"method": "fail",
- "args": {"value": value}})
+ yield rpc.call_twisted('test', {"method": "fail",
+ "args": {"value": value}})
self.fail("should have thrown rpc.RemoteError")
except rpc.RemoteError as exc:
self.assertEqual(int(exc.value), value)
diff --git a/nova/tests/scheduler_unittest.py b/nova/tests/scheduler_unittest.py
index fde30f81e..53a8be144 100644
--- a/nova/tests/scheduler_unittest.py
+++ b/nova/tests/scheduler_unittest.py
@@ -117,10 +117,12 @@ class SimpleDriverTestCase(test.TrialTestCase):
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute2.startService()
hosts = self.scheduler.driver.hosts_up(self.context, 'compute')
self.assertEqual(len(hosts), 2)
compute1.kill()
@@ -132,10 +134,12 @@ class SimpleDriverTestCase(test.TrialTestCase):
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute2.startService()
instance_id1 = self._create_instance()
compute1.run_instance(self.context, instance_id1)
instance_id2 = self._create_instance()
@@ -153,10 +157,12 @@ class SimpleDriverTestCase(test.TrialTestCase):
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
+ compute2.startService()
instance_ids1 = []
instance_ids2 = []
for index in xrange(FLAGS.max_cores):
@@ -184,10 +190,12 @@ class SimpleDriverTestCase(test.TrialTestCase):
'nova-volume',
'volume',
FLAGS.volume_manager)
+ volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
+ volume2.startService()
volume_id1 = self._create_volume()
volume1.create_volume(self.context, volume_id1)
volume_id2 = self._create_volume()
@@ -205,10 +213,12 @@ class SimpleDriverTestCase(test.TrialTestCase):
'nova-volume',
'volume',
FLAGS.volume_manager)
+ volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
+ volume2.startService()
volume_ids1 = []
volume_ids2 = []
for index in xrange(FLAGS.max_gigabytes):
diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py
index 01da0eb8a..6afeec377 100644
--- a/nova/tests/service_unittest.py
+++ b/nova/tests/service_unittest.py
@@ -22,6 +22,8 @@ Unit Tests for remote procedure calls using queue
import mox
+from twisted.application.app import startApplication
+
from nova import exception
from nova import flags
from nova import rpc
@@ -65,15 +67,20 @@ class ServiceTestCase(test.BaseTestCase):
proxy=mox.IsA(service.Service)).AndReturn(
rpc.AdapterConsumer)
+ rpc.AdapterConsumer.attach_to_twisted()
+ rpc.AdapterConsumer.attach_to_twisted()
+
# Stub out looping call a bit needlessly since we don't have an easy
# way to cancel it (yet) when the tests finishes
service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
service.task.LoopingCall)
service.task.LoopingCall.start(interval=mox.IgnoreArg(),
now=mox.IgnoreArg())
+ service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
+ service.task.LoopingCall)
+ service.task.LoopingCall.start(interval=mox.IgnoreArg(),
+ now=mox.IgnoreArg())
- rpc.AdapterConsumer.attach_to_twisted()
- rpc.AdapterConsumer.attach_to_twisted()
service_create = {'host': host,
'binary': binary,
'topic': topic,
@@ -91,6 +98,7 @@ class ServiceTestCase(test.BaseTestCase):
self.mox.ReplayAll()
app = service.Service.create(host=host, binary=binary)
+ startApplication(app, False)
self.assert_(app)
# We're testing sort of weird behavior in how report_state decides
diff --git a/nova/utils.py b/nova/utils.py
index 8939043e6..d18dd9843 100644
--- a/nova/utils.py
+++ b/nova/utils.py
@@ -39,17 +39,6 @@ from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
-class ProcessExecutionError(IOError):
- def __init__( self, stdout=None, stderr=None, exit_code=None, cmd=None,
- description=None):
- if description is None:
- description = "Unexpected error while running command."
- if exit_code is None:
- exit_code = '-'
- message = "%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r" % (
- description, cmd, exit_code, stdout, stderr)
- IOError.__init__(self, message)
-
def import_class(import_str):
"""Returns a class from a string including module and class"""
mod_str, _sep, class_str = import_str.rpartition('.')
diff --git a/nova/virt/xenapi.py b/nova/virt/xenapi.py
index 5fdd2b9fc..04e830b64 100644
--- a/nova/virt/xenapi.py
+++ b/nova/virt/xenapi.py
@@ -42,10 +42,12 @@ from twisted.internet import defer
from twisted.internet import reactor
from twisted.internet import task
+from nova import db
from nova import flags
from nova import process
from nova import utils
from nova.auth.manager import AuthManager
+from nova.compute import instance_types
from nova.compute import power_state
from nova.virt import images
@@ -103,8 +105,8 @@ class XenAPIConnection(object):
self._conn.login_with_password(user, pw)
def list_instances(self):
- result = [self._conn.xenapi.VM.get_name_label(vm) \
- for vm in self._conn.xenapi.VM.get_all()]
+ return [self._conn.xenapi.VM.get_name_label(vm) \
+ for vm in self._conn.xenapi.VM.get_all()]
@defer.inlineCallbacks
def spawn(self, instance):
@@ -113,32 +115,24 @@ class XenAPIConnection(object):
raise Exception('Attempted to create non-unique name %s' %
instance.name)
- if 'bridge_name' in instance.datamodel:
- network_ref = \
- yield self._find_network_with_bridge(
- instance.datamodel['bridge_name'])
- else:
- network_ref = None
-
- if 'mac_address' in instance.datamodel:
- mac_address = instance.datamodel['mac_address']
- else:
- mac_address = ''
+ network = db.project_get_network(None, instance.project_id)
+ network_ref = \
+ yield self._find_network_with_bridge(network.bridge)
- user = AuthManager().get_user(instance.datamodel['user_id'])
- project = AuthManager().get_project(instance.datamodel['project_id'])
+ user = AuthManager().get_user(instance.user_id)
+ project = AuthManager().get_project(instance.project_id)
vdi_uuid = yield self._fetch_image(
- instance.datamodel['image_id'], user, project, True)
+ instance.image_id, user, project, True)
kernel = yield self._fetch_image(
- instance.datamodel['kernel_id'], user, project, False)
+ instance.kernel_id, user, project, False)
ramdisk = yield self._fetch_image(
- instance.datamodel['ramdisk_id'], user, project, False)
+ instance.ramdisk_id, user, project, False)
vdi_ref = yield self._call_xenapi('VDI.get_by_uuid', vdi_uuid)
vm_ref = yield self._create_vm(instance, kernel, ramdisk)
yield self._create_vbd(vm_ref, vdi_ref, 0, True)
if network_ref:
- yield self._create_vif(vm_ref, network_ref, mac_address)
+ yield self._create_vif(vm_ref, network_ref, instance.mac_address)
logging.debug('Starting VM %s...', vm_ref)
yield self._call_xenapi('VM.start', vm_ref, False, False)
logging.info('Spawning VM %s created %s.', instance.name, vm_ref)
@@ -148,8 +142,9 @@ class XenAPIConnection(object):
"""Create a VM record. Returns a Deferred that gives the new
VM reference."""
- mem = str(long(instance.datamodel['memory_kb']) * 1024)
- vcpus = str(instance.datamodel['vcpus'])
+ instance_type = instance_types.INSTANCE_TYPES[instance.instance_type]
+ mem = str(long(instance_type['memory_mb']) * 1024 * 1024)
+ vcpus = str(instance_type['vcpus'])
rec = {
'name_label': instance.name,
'name_description': '',
diff --git a/nova/volume/manager.py b/nova/volume/manager.py
index 034763512..8508f27b2 100644
--- a/nova/volume/manager.py
+++ b/nova/volume/manager.py
@@ -77,7 +77,7 @@ class AOEManager(manager.Manager):
size = volume_ref['size']
logging.debug("volume %s: creating lv of size %sG", volume_id, size)
- yield self.driver.create_volume(volume_ref['str_id'], size)
+ yield self.driver.create_volume(volume_ref['ec2_id'], size)
logging.debug("volume %s: allocating shelf & blade", volume_id)
self._ensure_blades(context)
@@ -87,7 +87,7 @@ class AOEManager(manager.Manager):
logging.debug("volume %s: exporting shelf %s & blade %s", volume_id,
shelf_id, blade_id)
- yield self.driver.create_export(volume_ref['str_id'],
+ yield self.driver.create_export(volume_ref['ec2_id'],
shelf_id,
blade_id)
@@ -111,10 +111,10 @@ class AOEManager(manager.Manager):
raise exception.Error("Volume is not local to this node")
shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context,
volume_id)
- yield self.driver.remove_export(volume_ref['str_id'],
+ yield self.driver.remove_export(volume_ref['ec2_id'],
shelf_id,
blade_id)
- yield self.driver.delete_volume(volume_ref['str_id'])
+ yield self.driver.delete_volume(volume_ref['ec2_id'])
self.db.volume_destroy(context, volume_id)
defer.returnValue(True)
@@ -125,7 +125,7 @@ class AOEManager(manager.Manager):
Returns path to device.
"""
volume_ref = self.db.volume_get(context, volume_id)
- yield self.driver.discover_volume(volume_ref['str_id'])
+ yield self.driver.discover_volume(volume_ref['ec2_id'])
shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context,
volume_id)
defer.returnValue("/dev/etherd/e%s.%s" % (shelf_id, blade_id))
diff --git a/nova/wsgi.py b/nova/wsgi.py
index 8a4e2a9f4..b91d91121 100644
--- a/nova/wsgi.py
+++ b/nova/wsgi.py
@@ -21,14 +21,17 @@
Utility methods for working with WSGI servers
"""
+import json
import logging
import sys
+from xml.dom import minidom
import eventlet
import eventlet.wsgi
eventlet.patcher.monkey_patch(all=False, socket=True)
import routes
import routes.middleware
+import webob
import webob.dec
import webob.exc
@@ -227,10 +230,19 @@ class Controller(object):
serializer = Serializer(request.environ, _metadata)
return serializer.to_content_type(data)
+ def _deserialize(self, data, request):
+ """
+ Deserialize the request body to the response type requested in request.
+ Uses self._serialization_metadata if it exists, which is a dict mapping
+ MIME types to information needed to serialize to that type.
+ """
+ _metadata = getattr(type(self), "_serialization_metadata", {})
+ serializer = Serializer(request.environ, _metadata)
+ return serializer.deserialize(data)
class Serializer(object):
"""
- Serializes a dictionary to a Content Type specified by a WSGI environment.
+ Serializes and deserializes dictionaries to certain MIME types.
"""
def __init__(self, environ, metadata=None):
@@ -239,31 +251,77 @@ class Serializer(object):
'metadata' is an optional dict mapping MIME types to information
needed to serialize a dictionary to that type.
"""
- self.environ = environ
self.metadata = metadata or {}
- self._methods = {
- 'application/json': self._to_json,
- 'application/xml': self._to_xml}
+ req = webob.Request(environ)
+ suffix = req.path_info.split('.')[-1].lower()
+ if suffix == 'json':
+ self.handler = self._to_json
+ elif suffix == 'xml':
+ self.handler = self._to_xml
+ elif 'application/json' in req.accept:
+ self.handler = self._to_json
+ elif 'application/xml' in req.accept:
+ self.handler = self._to_xml
+ else:
+ self.handler = self._to_json # default
def to_content_type(self, data):
"""
- Serialize a dictionary into a string. The format of the string
- will be decided based on the Content Type requested in self.environ:
- by Accept: header, or by URL suffix.
+ Serialize a dictionary into a string.
+
+ The format of the string will be decided based on the Content Type
+ requested in self.environ: by Accept: header, or by URL suffix.
"""
- mimetype = 'application/xml'
- # TODO(gundlach): determine mimetype from request
- return self._methods.get(mimetype, repr)(data)
+ return self.handler(data)
+
+ def deserialize(self, datastring):
+ """
+ Deserialize a string to a dictionary.
+
+ The string must be in the format of a supported MIME type.
+ """
+ datastring = datastring.strip()
+ try:
+ is_xml = (datastring[0] == '<')
+ if not is_xml:
+ return json.loads(datastring)
+ return self._from_xml(datastring)
+ except:
+ return None
+
+ def _from_xml(self, datastring):
+ xmldata = self.metadata.get('application/xml', {})
+ plurals = set(xmldata.get('plurals', {}))
+ node = minidom.parseString(datastring).childNodes[0]
+ return {node.nodeName: self._from_xml_node(node, plurals)}
+
+ def _from_xml_node(self, node, listnames):
+ """
+ Convert a minidom node to a simple Python type.
+
+ listnames is a collection of names of XML nodes whose subnodes should
+ be considered list items.
+ """
+ if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3:
+ return node.childNodes[0].nodeValue
+ elif node.nodeName in listnames:
+ return [self._from_xml_node(n, listnames) for n in node.childNodes]
+ else:
+ result = dict()
+ for attr in node.attributes.keys():
+ result[attr] = node.attributes[attr].nodeValue
+ for child in node.childNodes:
+ if child.nodeType != node.TEXT_NODE:
+ result[child.nodeName] = self._from_xml_node(child, listnames)
+ return result
def _to_json(self, data):
- import json
return json.dumps(data)
def _to_xml(self, data):
metadata = self.metadata.get('application/xml', {})
# We expect data to contain a single key which is the XML root.
root_key = data.keys()[0]
- from xml.dom import minidom
doc = minidom.Document()
node = self._to_xml_node(doc, metadata, root_key, data[root_key])
return node.toprettyxml(indent=' ')
diff --git a/pylintrc b/pylintrc
index 6702ca895..024802835 100644
--- a/pylintrc
+++ b/pylintrc
@@ -1,7 +1,8 @@
[Messages Control]
# W0511: TODOs in code comments are fine.
# W0142: *args and **kwargs are fine.
-disable-msg=W0511,W0142
+# W0622: Redefining id is fine.
+disable-msg=W0511,W0142,W0622
[Basic]
# Variable names can be 1 to 31 characters long, with lowercase and underscores
diff --git a/setup.py b/setup.py
index 1767b00f4..d420d3559 100644
--- a/setup.py
+++ b/setup.py
@@ -54,5 +54,5 @@ setup(name='nova',
'bin/nova-manage',
'bin/nova-network',
'bin/nova-objectstore',
- 'bin/nova-api-new',
+ 'bin/nova-scheduler',
'bin/nova-volume'])
diff --git a/tools/install_venv.py b/tools/install_venv.py
index 5d2369a96..32c372352 100644
--- a/tools/install_venv.py
+++ b/tools/install_venv.py
@@ -88,6 +88,10 @@ def create_virtualenv(venv=VENV):
def install_dependencies(venv=VENV):
print 'Installing dependencies with pip (this can take a while)...'
+ # Install greenlet by hand - just listing it in the requires file does not
+ # get it in stalled in the right order
+ run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, 'greenlet'],
+ redirect_output=False)
run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, '-r', PIP_REQUIRES],
redirect_output=False)
run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, TWISTED_NOVA],
diff --git a/tools/pip-requires b/tools/pip-requires
index dd69708ce..1e2707be7 100644
--- a/tools/pip-requires
+++ b/tools/pip-requires
@@ -7,15 +7,16 @@ amqplib==0.6.1
anyjson==0.2.4
boto==2.0b1
carrot==0.10.5
-eventlet==0.9.10
+eventlet==0.9.12
lockfile==0.8
python-daemon==1.5.5
python-gflags==1.3
redis==2.0.0
routes==1.12.3
tornado==1.0
-webob==0.9.8
+WebOb==0.9.8
wsgiref==0.1.2
zope.interface==3.6.1
mox==0.5.0
-f http://pymox.googlecode.com/files/mox-0.5.0.tar.gz
+greenlet==0.3.1
diff --git a/tools/setup_iptables.sh b/tools/setup_iptables.sh
new file mode 100755
index 000000000..673353eb4
--- /dev/null
+++ b/tools/setup_iptables.sh
@@ -0,0 +1,158 @@
+#!/usr/bin/env bash
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+# NOTE(vish): This script sets up some reasonable defaults for iptables and
+# creates nova-specific chains. If you use this script you should
+# run nova-network and nova-compute with --use_nova_chains=True
+
+# NOTE(vish): If you run nova-api on a different port, make sure to change
+# the port here
+API_PORT=${API_PORT:-"8773"}
+if [ -n "$1" ]; then
+ CMD=$1
+else
+ CMD="all"
+fi
+
+if [ -n "$2" ]; then
+ IP=$2
+else
+ # NOTE(vish): This will just get the first ip in the list, so if you
+ # have more than one eth device set up, this will fail, and
+ # you should explicitly pass in the ip of the instance
+ IP=`ifconfig | grep -m 1 'inet addr:'| cut -d: -f2 | awk '{print $1}'`
+fi
+
+if [ -n "$3" ]; then
+ PRIVATE_RANGE=$3
+else
+ PRIVATE_RANGE="10.0.0.0/12"
+fi
+
+
+if [ -n "$4" ]; then
+ # NOTE(vish): Management IP is the ip over which to allow ssh traffic. It
+ # will also allow traffic to nova-api
+ MGMT_IP=$4
+else
+ MGMT_IP="$IP"
+fi
+if [ "$CMD" == "clear" ]; then
+ iptables -P INPUT ACCEPT
+ iptables -P FORWARD ACCEPT
+ iptables -P OUTPUT ACCEPT
+ iptables -F
+ iptables -t nat -F
+ iptables -F nova_input
+ iptables -F nova_output
+ iptables -F nova_forward
+ iptables -t nat -F nova_input
+ iptables -t nat -F nova_output
+ iptables -t nat -F nova_forward
+ iptables -t nat -X
+ iptables -X
+fi
+
+if [ "$CMD" == "base" ] || [ "$CMD" == "all" ]; then
+ iptables -P INPUT DROP
+ iptables -A INPUT -m state --state INVALID -j DROP
+ iptables -A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
+ iptables -A INPUT -m tcp -p tcp -d $MGMT_IP --dport 22 -j ACCEPT
+ iptables -A INPUT -m udp -p udp --dport 123 -j ACCEPT
+ iptables -N nova_input
+ iptables -A INPUT -j nova_input
+ iptables -A INPUT -p icmp -j ACCEPT
+ iptables -A INPUT -p tcp -j REJECT --reject-with tcp-reset
+ iptables -A INPUT -j REJECT --reject-with icmp-port-unreachable
+
+ iptables -P FORWARD DROP
+ iptables -A FORWARD -m state --state INVALID -j DROP
+ iptables -A FORWARD -m state --state RELATED,ESTABLISHED -j ACCEPT
+ iptables -A FORWARD -p tcp -m tcp --tcp-flags SYN,RST SYN -j TCPMSS --clamp-mss-to-pmtu
+ iptables -N nova_forward
+ iptables -A FORWARD -j nova_forward
+
+ # NOTE(vish): DROP on output is too restrictive for now. We need to add
+ # in a bunch of more specific output rules to use it.
+ # iptables -P OUTPUT DROP
+ iptables -A OUTPUT -m state --state INVALID -j DROP
+ iptables -A OUTPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
+ iptables -N nova_output
+ iptables -A OUTPUT -j nova_output
+
+ iptables -t nat -N nova_prerouting
+ iptables -t nat -A PREROUTING -j nova_prerouting
+
+ iptables -t nat -N nova_postrouting
+ iptables -t nat -A POSTROUTING -j nova_postrouting
+
+ iptables -t nat -N nova_output
+ iptables -t nat -A OUTPUT -j nova_output
+fi
+
+if [ "$CMD" == "ganglia" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 8649 -j ACCEPT
+ iptables -A nova_input -m udp -p udp -d $IP --dport 8649 -j ACCEPT
+fi
+
+if [ "$CMD" == "web" ] || [ "$CMD" == "all" ]; then
+ # NOTE(vish): This opens up ports for web access, allowing web-based
+ # dashboards to work.
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 80 -j ACCEPT
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 443 -j ACCEPT
+fi
+
+if [ "$CMD" == "objectstore" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 3333 -j ACCEPT
+fi
+
+if [ "$CMD" == "api" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport $API_PORT -j ACCEPT
+ if [ "$IP" != "$MGMT_IP" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $MGMT_IP --dport $API_PORT -j ACCEPT
+ fi
+fi
+
+if [ "$CMD" == "redis" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 6379 -j ACCEPT
+fi
+
+if [ "$CMD" == "mysql" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 3306 -j ACCEPT
+fi
+
+if [ "$CMD" == "rabbitmq" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 4369 -j ACCEPT
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 5672 -j ACCEPT
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 53284 -j ACCEPT
+fi
+
+if [ "$CMD" == "dnsmasq" ] || [ "$CMD" == "all" ]; then
+ # NOTE(vish): this could theoretically be setup per network
+ # for each host, but it seems like overkill
+ iptables -A nova_input -m tcp -p tcp -s $PRIVATE_RANGE --dport 53 -j ACCEPT
+ iptables -A nova_input -m udp -p udp -s $PRIVATE_RANGE --dport 53 -j ACCEPT
+ iptables -A nova_input -m udp -p udp --dport 67 -j ACCEPT
+fi
+
+if [ "$CMD" == "ldap" ] || [ "$CMD" == "all" ]; then
+ iptables -A nova_input -m tcp -p tcp -d $IP --dport 389 -j ACCEPT
+fi
+
+