diff options
85 files changed, 5310 insertions, 1977 deletions
diff --git a/bin/nova-api b/bin/nova-api index ede09d38c..a5027700b 100755 --- a/bin/nova-api +++ b/bin/nova-api @@ -1,31 +1,28 @@ #!/usr/bin/env python +# pylint: disable-msg=C0103 # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ -Tornado daemon for the main API endpoint. +Nova API daemon. """ -import logging import os import sys -from tornado import httpserver -from tornado import ioloop # If ../nova/__init__.py exists, add ../ to Python search path, so that # it will override what happens to be installed in /usr/(local/)lib/python... @@ -36,28 +33,16 @@ if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')): sys.path.insert(0, possible_topdir) from nova import flags -from nova import server from nova import utils -from nova.endpoint import admin -from nova.endpoint import api -from nova.endpoint import cloud +from nova import server FLAGS = flags.FLAGS +flags.DEFINE_integer('api_port', 8773, 'API port') - -def main(_argv): - """Load the controllers and start the tornado I/O loop.""" - controllers = { - 'Cloud': cloud.CloudController(), - 'Admin': admin.AdminController()} - _app = api.APIServerApplication(controllers) - - io_inst = ioloop.IOLoop.instance() - http_server = httpserver.HTTPServer(_app) - http_server.listen(FLAGS.cc_port) - logging.debug('Started HTTP server on %s', FLAGS.cc_port) - io_inst.start() - +def main(_args): + from nova import api + from nova import wsgi + wsgi.run_server(api.API(), FLAGS.api_port) if __name__ == '__main__': utils.default_flagfile() diff --git a/bin/nova-api-new b/bin/nova-api-new deleted file mode 100755 index 8625c487f..000000000 --- a/bin/nova-api-new +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -# pylint: disable-msg=C0103 -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Nova API daemon. -""" - -import os -import sys - -# If ../nova/__init__.py exists, add ../ to Python search path, so that -# it will override what happens to be installed in /usr/(local/)lib/python... -possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), - os.pardir, - os.pardir)) -if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')): - sys.path.insert(0, possible_topdir) - -from nova import api -from nova import flags -from nova import utils -from nova import wsgi - -FLAGS = flags.FLAGS -flags.DEFINE_integer('api_port', 8773, 'API port') - -if __name__ == '__main__': - utils.default_flagfile() - wsgi.run_server(api.API(), FLAGS.api_port) diff --git a/bin/nova-manage b/bin/nova-manage index 325245ac4..5b72c170f 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -50,9 +50,9 @@ """ CLI interface for nova management. - Connects to the running ADMIN api in the api daemon. """ +import logging import os import sys import time @@ -68,11 +68,12 @@ if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')): sys.path.insert(0, possible_topdir) from nova import db +from nova import exception from nova import flags +from nova import quota from nova import utils from nova.auth import manager from nova.cloudpipe import pipelib -from nova.endpoint import cloud FLAGS = flags.FLAGS @@ -83,7 +84,7 @@ class VpnCommands(object): def __init__(self): self.manager = manager.AuthManager() - self.pipe = pipelib.CloudPipe(cloud.CloudController()) + self.pipe = pipelib.CloudPipe() def list(self): """Print a listing of the VPNs for all projects.""" @@ -114,7 +115,7 @@ class VpnCommands(object): def _vpn_for(self, project_id): """Get the VPN instance for a project ID.""" - for instance in db.instance_get_all(): + for instance in db.instance_get_all(None): if (instance['image_id'] == FLAGS.vpn_image_id and not instance['state_description'] in ['shutting_down', 'shutdown'] @@ -135,15 +136,48 @@ class VpnCommands(object): class ShellCommands(object): - def run(self): - "Runs a Python interactive interpreter. Tries to use IPython, if it's available." - try: - import IPython - # Explicitly pass an empty list as arguments, because otherwise IPython - # would use sys.argv from this script. - shell = IPython.Shell.IPShell(argv=[]) - shell.mainloop() - except ImportError: + def bpython(self): + """Runs a bpython shell. + + Falls back to Ipython/python shell if unavailable""" + self.run('bpython') + + def ipython(self): + """Runs an Ipython shell. + + Falls back to Python shell if unavailable""" + self.run('ipython') + + def python(self): + """Runs a python shell. + + Falls back to Python shell if unavailable""" + self.run('python') + + def run(self, shell=None): + """Runs a Python interactive interpreter. + + args: [shell=bpython]""" + if not shell: + shell = 'bpython' + + if shell == 'bpython': + try: + import bpython + bpython.embed() + except ImportError: + shell = 'ipython' + if shell == 'ipython': + try: + import IPython + # Explicitly pass an empty list as arguments, because otherwise IPython + # would use sys.argv from this script. + shell = IPython.Shell.IPShell(argv=[]) + shell.mainloop() + except ImportError: + shell = 'python' + + if shell == 'python': import code try: # Try activating rlcompleter, because it's handy. import readline @@ -156,6 +190,11 @@ class ShellCommands(object): readline.parse_and_bind("tab:complete") code.interact() + def script(self, path): + """Runs the script from the specifed path with flags set properly. + arguments: path""" + exec(compile(open(path).read(), path, 'exec'), locals(), globals()) + class RoleCommands(object): """Class for managing roles.""" @@ -186,6 +225,13 @@ class RoleCommands(object): class UserCommands(object): """Class for managing users.""" + @staticmethod + def _print_export(user): + """Print export variables to use with API.""" + print 'export EC2_ACCESS_KEY=%s' % user.access + print 'export EC2_SECRET_KEY=%s' % user.secret + + def __init__(self): self.manager = manager.AuthManager() @@ -193,13 +239,13 @@ class UserCommands(object): """creates a new admin and prints exports arguments: name [access] [secret]""" user = self.manager.create_user(name, access, secret, True) - print_export(user) + self._print_export(user) def create(self, name, access=None, secret=None): """creates a new user and prints exports arguments: name [access] [secret]""" user = self.manager.create_user(name, access, secret, False) - print_export(user) + self._print_export(user) def delete(self, name): """deletes an existing user @@ -211,7 +257,7 @@ class UserCommands(object): arguments: name""" user = self.manager.get_user(name) if user: - print_export(user) + self._print_export(user) else: print "User %s doesn't exist" % name @@ -221,12 +267,18 @@ class UserCommands(object): for user in self.manager.get_users(): print user.name - -def print_export(user): - """Print export variables to use with API.""" - print 'export EC2_ACCESS_KEY=%s' % user.access - print 'export EC2_SECRET_KEY=%s' % user.secret - + def modify(self, name, access_key, secret_key, is_admin): + """update a users keys & admin flag + arguments: accesskey secretkey admin + leave any field blank to ignore it, admin should be 'T', 'F', or blank + """ + if not is_admin: + is_admin = None + elif is_admin.upper()[0] == 'T': + is_admin = True + else: + is_admin = False + self.manager.modify_user(name, access_key, secret_key, is_admin) class ProjectCommands(object): """Class for managing projects.""" @@ -252,7 +304,7 @@ class ProjectCommands(object): def environment(self, project_id, user_id, filename='novarc'): """Exports environment variables to an sourcable file arguments: project_id user_id [filename='novarc]""" - rc = self.manager.get_environment_rc(project_id, user_id) + rc = self.manager.get_environment_rc(user_id, project_id) with open(filename, 'w') as f: f.write(rc) @@ -262,6 +314,19 @@ class ProjectCommands(object): for project in self.manager.get_projects(): print project.name + def quota(self, project_id, key=None, value=None): + """Set or display quotas for project + arguments: project_id [key] [value]""" + if key: + quo = {'project_id': project_id, key: value} + try: + db.quota_update(None, project_id, quo) + except exception.NotFound: + db.quota_create(None, quo) + project_quota = quota.get_quota(None, project_id) + for key, value in project_quota.iteritems(): + print '%s: %s' % (key, value) + def remove(self, project, user): """Removes user from project arguments: project user""" @@ -274,6 +339,7 @@ class ProjectCommands(object): with open(filename, 'w') as f: f.write(zip_file) + class FloatingIpCommands(object): """Class for managing floating ip.""" @@ -301,11 +367,12 @@ class FloatingIpCommands(object): for floating_ip in floating_ips: instance = None if floating_ip['fixed_ip']: - instance = floating_ip['fixed_ip']['instance']['str_id'] + instance = floating_ip['fixed_ip']['instance']['ec2_id'] print "%s\t%s\t%s" % (floating_ip['host'], floating_ip['address'], instance) + CATEGORIES = [ ('user', UserCommands), ('project', ProjectCommands), @@ -351,6 +418,10 @@ def main(): """Parse options and call the appropriate class/method.""" utils.default_flagfile('/etc/nova/nova-manage.conf') argv = FLAGS(sys.argv) + + if FLAGS.verbose: + logging.getLogger().setLevel(logging.DEBUG) + script_name = argv.pop(0) if len(argv) < 1: print script_name + " category action [<args>]" diff --git a/doc/source/auth.rst b/doc/source/auth.rst index 70aca704a..3fcb309cd 100644 --- a/doc/source/auth.rst +++ b/doc/source/auth.rst @@ -172,14 +172,6 @@ Further Challenges -The :mod:`rbac` Module --------------------------- - -.. automodule:: nova.auth.rbac - :members: - :undoc-members: - :show-inheritance: - The :mod:`signer` Module ------------------------ diff --git a/nova/adminclient.py b/nova/adminclient.py index 0ca32b1e5..fc9fcfde0 100644 --- a/nova/adminclient.py +++ b/nova/adminclient.py @@ -20,11 +20,17 @@ Nova User API client library. """ import base64 - import boto +import httplib from boto.ec2.regioninfo import RegionInfo +DEFAULT_CLC_URL='http://127.0.0.1:8773' +DEFAULT_REGION='nova' +DEFAULT_ACCESS_KEY='admin' +DEFAULT_SECRET_KEY='admin' + + class UserInfo(object): """ Information about a Nova user, as parsed through SAX @@ -68,13 +74,13 @@ class UserRole(object): def __init__(self, connection=None): self.connection = connection self.role = None - + def __repr__(self): return 'UserRole:%s' % self.role def startElement(self, name, attrs, connection): return None - + def endElement(self, name, value, connection): if name == 'role': self.role = value @@ -128,20 +134,20 @@ class ProjectMember(object): def __init__(self, connection=None): self.connection = connection self.memberId = None - + def __repr__(self): return 'ProjectMember:%s' % self.memberId def startElement(self, name, attrs, connection): return None - + def endElement(self, name, value, connection): if name == 'member': self.memberId = value else: setattr(self, name, str(value)) - + class HostInfo(object): """ Information about a Nova Host, as parsed through SAX: @@ -171,35 +177,56 @@ class HostInfo(object): class NovaAdminClient(object): - def __init__(self, clc_ip='127.0.0.1', region='nova', access_key='admin', - secret_key='admin', **kwargs): - self.clc_ip = clc_ip + def __init__(self, clc_url=DEFAULT_CLC_URL, region=DEFAULT_REGION, + access_key=DEFAULT_ACCESS_KEY, secret_key=DEFAULT_SECRET_KEY, + **kwargs): + parts = self.split_clc_url(clc_url) + + self.clc_url = clc_url self.region = region self.access = access_key self.secret = secret_key self.apiconn = boto.connect_ec2(aws_access_key_id=access_key, aws_secret_access_key=secret_key, - is_secure=False, - region=RegionInfo(None, region, clc_ip), - port=8773, + is_secure=parts['is_secure'], + region=RegionInfo(None, + region, + parts['ip']), + port=parts['port'], path='/services/Admin', **kwargs) self.apiconn.APIVersion = 'nova' - def connection_for(self, username, project, **kwargs): + def connection_for(self, username, project, clc_url=None, region=None, + **kwargs): """ Returns a boto ec2 connection for the given username. """ + if not clc_url: + clc_url = self.clc_url + if not region: + region = self.region + parts = self.split_clc_url(clc_url) user = self.get_user(username) access_key = '%s:%s' % (user.accesskey, project) - return boto.connect_ec2( - aws_access_key_id=access_key, - aws_secret_access_key=user.secretkey, - is_secure=False, - region=RegionInfo(None, self.region, self.clc_ip), - port=8773, - path='/services/Cloud', - **kwargs) + return boto.connect_ec2(aws_access_key_id=access_key, + aws_secret_access_key=user.secretkey, + is_secure=parts['is_secure'], + region=RegionInfo(None, + self.region, + parts['ip']), + port=parts['port'], + path='/services/Cloud', + **kwargs) + + def split_clc_url(self, clc_url): + """ + Splits a cloud controller endpoint url. + """ + parts = httplib.urlsplit(clc_url) + is_secure = parts.scheme == 'https' + ip, port = parts.netloc.split(':') + return {'ip': ip, 'port': int(port), 'is_secure': is_secure} def get_users(self): """ grabs the list of all users """ @@ -289,7 +316,7 @@ class NovaAdminClient(object): if project.projectname != None: return project - + def create_project(self, projectname, manager_user, description=None, member_users=None): """ @@ -322,7 +349,7 @@ class NovaAdminClient(object): Adds a user to a project. """ return self.modify_project_member(user, project, operation='add') - + def remove_project_member(self, user, project): """ Removes a user from a project. diff --git a/nova/api/__init__.py b/nova/api/__init__.py index b9b9e3988..744abd621 100644 --- a/nova/api/__init__.py +++ b/nova/api/__init__.py @@ -21,17 +21,91 @@ Root WSGI middleware for all API controllers. """ import routes +import webob.dec +from nova import flags from nova import wsgi +from nova.api import cloudpipe from nova.api import ec2 from nova.api import rackspace +from nova.api.ec2 import metadatarequesthandler + + +flags.DEFINE_string('rsapi_subdomain', 'rs', + 'subdomain running the RS API') +flags.DEFINE_string('ec2api_subdomain', 'ec2', + 'subdomain running the EC2 API') +flags.DEFINE_string('FAKE_subdomain', None, + 'set to rs or ec2 to fake the subdomain of the host for testing') +FLAGS = flags.FLAGS class API(wsgi.Router): """Routes top-level requests to the appropriate controller.""" def __init__(self): + rsdomain = {'sub_domain': [FLAGS.rsapi_subdomain]} + ec2domain = {'sub_domain': [FLAGS.ec2api_subdomain]} + # If someone wants to pretend they're hitting the RS subdomain + # on their local box, they can set FAKE_subdomain to 'rs', which + # removes subdomain restrictions from the RS routes below. + if FLAGS.FAKE_subdomain == 'rs': + rsdomain = {} + elif FLAGS.FAKE_subdomain == 'ec2': + ec2domain = {} mapper = routes.Mapper() - mapper.connect("/v1.0/{path_info:.*}", controller=rackspace.API()) - mapper.connect("/ec2/{path_info:.*}", controller=ec2.API()) + mapper.sub_domains = True + mapper.connect("/", controller=self.rsapi_versions, + conditions=rsdomain) + mapper.connect("/v1.0/{path_info:.*}", controller=rackspace.API(), + conditions=rsdomain) + + mapper.connect("/", controller=self.ec2api_versions, + conditions=ec2domain) + mapper.connect("/services/{path_info:.*}", controller=ec2.API(), + conditions=ec2domain) + mapper.connect("/cloudpipe/{path_info:.*}", controller=cloudpipe.API()) + mrh = metadatarequesthandler.MetadataRequestHandler() + for s in ['/latest', + '/2009-04-04', + '/2008-09-01', + '/2008-02-01', + '/2007-12-15', + '/2007-10-10', + '/2007-08-29', + '/2007-03-01', + '/2007-01-19', + '/1.0']: + mapper.connect('%s/{path_info:.*}' % s, controller=mrh, + conditions=ec2domain) super(API, self).__init__(mapper) + + @webob.dec.wsgify + def rsapi_versions(self, req): + """Respond to a request for all OpenStack API versions.""" + response = { + "versions": [ + dict(status="CURRENT", id="v1.0")]} + metadata = { + "application/xml": { + "attributes": dict(version=["status", "id"])}} + return wsgi.Serializer(req.environ, metadata).to_content_type(response) + + @webob.dec.wsgify + def ec2api_versions(self, req): + """Respond to a request for all EC2 versions.""" + # available api versions + versions = [ + '1.0', + '2007-01-19', + '2007-03-01', + '2007-08-29', + '2007-10-10', + '2007-12-15', + '2008-02-01', + '2008-09-01', + '2009-04-04', + ] + return ''.join('%s\n' % v for v in versions) + + diff --git a/nova/api/cloud.py b/nova/api/cloud.py new file mode 100644 index 000000000..345677d4f --- /dev/null +++ b/nova/api/cloud.py @@ -0,0 +1,42 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Methods for API calls to control instances via AMQP. +""" + + +from nova import db +from nova import flags +from nova import rpc + +FLAGS = flags.FLAGS + + +def reboot(instance_id, context=None): + """Reboot the given instance. + + #TODO(gundlach) not actually sure what context is used for by ec2 here + -- I think we can just remove it and use None all the time. + """ + instance_ref = db.instance_get_by_ec2_id(None, instance_id) + host = instance_ref['host'] + rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), + {"method": "reboot_instance", + "args": {"context": None, + "instance_id": instance_ref['id']}}) diff --git a/nova/api/cloudpipe/__init__.py b/nova/api/cloudpipe/__init__.py new file mode 100644 index 000000000..6d40990a8 --- /dev/null +++ b/nova/api/cloudpipe/__init__.py @@ -0,0 +1,69 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +REST API Request Handlers for CloudPipe +""" + +import logging +import urllib +import webob +import webob.dec +import webob.exc + +from nova import crypto +from nova import wsgi +from nova.auth import manager +from nova.api.ec2 import cloud + + +_log = logging.getLogger("api") +_log.setLevel(logging.DEBUG) + + +class API(wsgi.Application): + + def __init__(self): + self.controller = cloud.CloudController() + + @webob.dec.wsgify + def __call__(self, req): + if req.method == 'POST': + return self.sign_csr(req) + _log.debug("Cloudpipe path is %s" % req.path_info) + if req.path_info.endswith("/getca/"): + return self.send_root_ca(req) + return webob.exc.HTTPNotFound() + + def get_project_id_from_ip(self, ip): + # TODO(eday): This was removed with the ORM branch, fix! + instance = self.controller.get_instance_by_ip(ip) + return instance['project_id'] + + def send_root_ca(self, req): + _log.debug("Getting root ca") + project_id = self.get_project_id_from_ip(req.remote_addr) + res = webob.Response() + res.headers["Content-Type"] = "text/plain" + res.body = crypto.fetch_ca(project_id) + return res + + def sign_csr(self, req): + project_id = self.get_project_id_from_ip(req.remote_addr) + cert = self.str_params['cert'] + return crypto.sign_csr(urllib.unquote(cert), project_id) diff --git a/nova/api/context.py b/nova/api/context.py new file mode 100644 index 000000000..b66cfe468 --- /dev/null +++ b/nova/api/context.py @@ -0,0 +1,46 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + + +class APIRequestContext(object): + def __init__(self, user, project): + self.user = user + self.project = project + self.request_id = ''.join( + [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') + for x in xrange(20)] + ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 6eec0abf7..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -1,6 +1,7 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright 2010 OpenStack LLC. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -15,28 +16,225 @@ # License for the specific language governing permissions and limitations # under the License. -""" -WSGI middleware for EC2 API controllers. -""" +"""Starting point for routing EC2 requests""" +import logging import routes +import webob import webob.dec +import webob.exc +from nova import exception +from nova import flags from nova import wsgi +from nova.api import context +from nova.api.ec2 import apirequest +from nova.api.ec2 import admin +from nova.api.ec2 import cloud +from nova.auth import manager -class API(wsgi.Router): - """Routes EC2 requests to the appropriate controller.""" +FLAGS = flags.FLAGS +_log = logging.getLogger("api") +_log.setLevel(logging.DEBUG) + + +class API(wsgi.Middleware): + + """Routing for all EC2 API requests.""" def __init__(self): - mapper = routes.Mapper() - mapper.connect(None, "{all:.*}", controller=self.dummy) - super(API, self).__init__(mapper) + self.application = Authenticate(Router(Authorizer(Executor()))) + + +class Authenticate(wsgi.Middleware): + + """Authenticate an EC2 request and add 'ec2.context' to WSGI environ.""" + + @webob.dec.wsgify + def __call__(self, req): + # Read request signature and access id. + try: + signature = req.params['Signature'] + access = req.params['AWSAccessKeyId'] + except: + raise webob.exc.HTTPBadRequest() + + # Make a copy of args for authentication and signature verification. + auth_params = dict(req.params) + auth_params.pop('Signature') # not part of authentication args + + # Authenticate the request. + try: + (user, project) = manager.AuthManager().authenticate( + access, + signature, + auth_params, + req.method, + req.host, + req.path) + except exception.Error, ex: + logging.debug("Authentication Failure: %s" % ex) + raise webob.exc.HTTPForbidden() + + # Authenticated! + req.environ['ec2.context'] = context.APIRequestContext(user, project) + return self.application + + +class Router(wsgi.Middleware): + + """Add ec2.'controller', .'action', and .'action_args' to WSGI environ.""" + + def __init__(self, application): + super(Router, self).__init__(application) + self.map = routes.Mapper() + self.map.connect("/{controller_name}/") + self.controllers = dict(Cloud=cloud.CloudController(), + Admin=admin.AdminController()) - @staticmethod @webob.dec.wsgify - def dummy(req): - """Temporary dummy controller.""" - msg = "dummy response -- please hook up __init__() to cloud.py instead" - return repr({'dummy': msg, - 'kwargs': repr(req.environ['wsgiorg.routing_args'][1])}) + def __call__(self, req): + # Obtain the appropriate controller and action for this request. + try: + match = self.map.match(req.path_info) + controller_name = match['controller_name'] + controller = self.controllers[controller_name] + except: + raise webob.exc.HTTPNotFound() + non_args = ['Action', 'Signature', 'AWSAccessKeyId', 'SignatureMethod', + 'SignatureVersion', 'Version', 'Timestamp'] + args = dict(req.params) + try: + action = req.params['Action'] # raise KeyError if omitted + for non_arg in non_args: + args.pop(non_arg) # remove, but raise KeyError if omitted + except: + raise webob.exc.HTTPBadRequest() + + _log.debug('action: %s' % action) + for key, value in args.items(): + _log.debug('arg: %s\t\tval: %s' % (key, value)) + + # Success! + req.environ['ec2.controller'] = controller + req.environ['ec2.action'] = action + req.environ['ec2.action_args'] = args + return self.application + + +class Authorizer(wsgi.Middleware): + + """Authorize an EC2 API request. + + Return a 401 if ec2.controller and ec2.action in WSGI environ may not be + executed in ec2.context. + """ + + def __init__(self, application): + super(Authorizer, self).__init__(application) + self.action_roles = { + 'CloudController': { + 'DescribeAvailabilityzones': ['all'], + 'DescribeRegions': ['all'], + 'DescribeSnapshots': ['all'], + 'DescribeKeyPairs': ['all'], + 'CreateKeyPair': ['all'], + 'DeleteKeyPair': ['all'], + 'DescribeSecurityGroups': ['all'], + 'CreateSecurityGroup': ['netadmin'], + 'DeleteSecurityGroup': ['netadmin'], + 'GetConsoleOutput': ['projectmanager', 'sysadmin'], + 'DescribeVolumes': ['projectmanager', 'sysadmin'], + 'CreateVolume': ['projectmanager', 'sysadmin'], + 'AttachVolume': ['projectmanager', 'sysadmin'], + 'DetachVolume': ['projectmanager', 'sysadmin'], + 'DescribeInstances': ['all'], + 'DescribeAddresses': ['all'], + 'AllocateAddress': ['netadmin'], + 'ReleaseAddress': ['netadmin'], + 'AssociateAddress': ['netadmin'], + 'DisassociateAddress': ['netadmin'], + 'RunInstances': ['projectmanager', 'sysadmin'], + 'TerminateInstances': ['projectmanager', 'sysadmin'], + 'RebootInstances': ['projectmanager', 'sysadmin'], + 'UpdateInstance': ['projectmanager', 'sysadmin'], + 'DeleteVolume': ['projectmanager', 'sysadmin'], + 'DescribeImages': ['all'], + 'DeregisterImage': ['projectmanager', 'sysadmin'], + 'RegisterImage': ['projectmanager', 'sysadmin'], + 'DescribeImageAttribute': ['all'], + 'ModifyImageAttribute': ['projectmanager', 'sysadmin'], + 'UpdateImage': ['projectmanager', 'sysadmin'], + }, + 'AdminController': { + # All actions have the same permission: ['none'] (the default) + # superusers will be allowed to run them + # all others will get HTTPUnauthorized. + }, + } + + @webob.dec.wsgify + def __call__(self, req): + context = req.environ['ec2.context'] + controller_name = req.environ['ec2.controller'].__class__.__name__ + action = req.environ['ec2.action'] + allowed_roles = self.action_roles[controller_name].get(action, ['none']) + if self._matches_any_role(context, allowed_roles): + return self.application + else: + raise webob.exc.HTTPUnauthorized() + + def _matches_any_role(self, context, roles): + """Return True if any role in roles is allowed in context.""" + if context.user.is_superuser(): + return True + if 'all' in roles: + return True + if 'none' in roles: + return False + return any(context.project.has_role(context.user.id, role) + for role in roles) + + +class Executor(wsgi.Application): + + """Execute an EC2 API request. + + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + 'ec2.action_args' (all variables in WSGI environ.) Returns an XML + response, or a 400 upon failure. + """ + + @webob.dec.wsgify + def __call__(self, req): + context = req.environ['ec2.context'] + controller = req.environ['ec2.controller'] + action = req.environ['ec2.action'] + args = req.environ['ec2.action_args'] + + api_request = apirequest.APIRequest(controller, action) + try: + result = api_request.send(context, **args) + req.headers['Content-Type'] = 'text/xml' + return result + except exception.ApiError as ex: + + if ex.code: + return self._error(req, ex.code, ex.message) + else: + return self._error(req, type(ex).__name__, ex.message) + # TODO(vish): do something more useful with unknown exceptions + except Exception as ex: + return self._error(req, type(ex).__name__, str(ex)) + + def _error(self, req, code, message): + resp = webob.Response() + resp.status = 400 + resp.headers['Content-Type'] = 'text/xml' + resp.body = ('<?xml version="1.0"?>\n' + '<Response><Errors><Error><Code>%s</Code>' + '<Message>%s</Message></Error></Errors>' + '<RequestID>?</RequestID></Response>') % (code, message) + return resp + diff --git a/nova/endpoint/admin.py b/nova/api/ec2/admin.py index c6dcb5320..36feae451 100644 --- a/nova/endpoint/admin.py +++ b/nova/api/ec2/admin.py @@ -58,46 +58,27 @@ def host_dict(host): return {} -def admin_only(target): - """Decorator for admin-only API calls""" - def wrapper(*args, **kwargs): - """Internal wrapper method for admin-only API calls""" - context = args[1] - if context.user.is_admin(): - return target(*args, **kwargs) - else: - return {} - - return wrapper - - class AdminController(object): """ API Controller for users, hosts, nodes, and workers. - Trivial admin_only wrapper will be replaced with RBAC, - allowing project managers to administer project users. """ def __str__(self): return 'AdminController' - @admin_only def describe_user(self, _context, name, **_kwargs): """Returns user data, including access and secret keys.""" return user_dict(manager.AuthManager().get_user(name)) - @admin_only def describe_users(self, _context, **_kwargs): """Returns all users - should be changed to deal with a list.""" return {'userSet': [user_dict(u) for u in manager.AuthManager().get_users()] } - @admin_only def register_user(self, _context, name, **_kwargs): """Creates a new user, and returns generated credentials.""" return user_dict(manager.AuthManager().create_user(name)) - @admin_only def deregister_user(self, _context, name, **_kwargs): """Deletes a single user (NOT undoable.) Should throw an exception if the user has instances, @@ -107,13 +88,11 @@ class AdminController(object): return True - @admin_only def describe_roles(self, context, project_roles=True, **kwargs): """Returns a list of allowed roles.""" roles = manager.AuthManager().get_roles(project_roles) return { 'roles': [{'role': r} for r in roles]} - @admin_only def describe_user_roles(self, context, user, project=None, **kwargs): """Returns a list of roles for the given user. Omitting project will return any global roles that the user has. @@ -122,7 +101,6 @@ class AdminController(object): roles = manager.AuthManager().get_user_roles(user, project=project) return { 'roles': [{'role': r} for r in roles]} - @admin_only def modify_user_role(self, context, user, role, project=None, operation='add', **kwargs): """Add or remove a role for a user and project.""" @@ -135,7 +113,6 @@ class AdminController(object): return True - @admin_only def generate_x509_for_user(self, _context, name, project=None, **kwargs): """Generates and returns an x509 certificate for a single user. Is usually called from a client that will wrap this with @@ -147,19 +124,16 @@ class AdminController(object): user = manager.AuthManager().get_user(name) return user_dict(user, base64.b64encode(project.get_credentials(user))) - @admin_only def describe_project(self, context, name, **kwargs): """Returns project data, including member ids.""" return project_dict(manager.AuthManager().get_project(name)) - @admin_only def describe_projects(self, context, user=None, **kwargs): """Returns all projects - should be changed to deal with a list.""" return {'projectSet': [project_dict(u) for u in manager.AuthManager().get_projects(user=user)]} - @admin_only def register_project(self, context, name, manager_user, description=None, member_users=None, **kwargs): """Creates a new project""" @@ -170,20 +144,17 @@ class AdminController(object): description=None, member_users=None)) - @admin_only def deregister_project(self, context, name): """Permanently deletes a project.""" manager.AuthManager().delete_project(name) return True - @admin_only def describe_project_members(self, context, name, **kwargs): project = manager.AuthManager().get_project(name) result = { 'members': [{'member': m} for m in project.member_ids]} return result - @admin_only def modify_project_member(self, context, user, project, operation, **kwargs): """Add or remove a user from a project.""" if operation =='add': @@ -196,7 +167,6 @@ class AdminController(object): # FIXME(vish): these host commands don't work yet, perhaps some of the # required data can be retrieved from service objects? - @admin_only def describe_hosts(self, _context, **_kwargs): """Returns status info for all nodes. Includes: * Disk Space @@ -208,7 +178,6 @@ class AdminController(object): """ return {'hostSet': [host_dict(h) for h in db.host_get_all()]} - @admin_only def describe_host(self, _context, name, **_kwargs): """Returns status info for single node.""" return host_dict(db.host_get(name)) diff --git a/nova/api/ec2/apirequest.py b/nova/api/ec2/apirequest.py new file mode 100644 index 000000000..a87c21fb3 --- /dev/null +++ b/nova/api/ec2/apirequest.py @@ -0,0 +1,131 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequest class +""" + +import logging +import re +# TODO(termie): replace minidom with etree +from xml.dom import minidom + +_log = logging.getLogger("api") +_log.setLevel(logging.DEBUG) + + +_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))') + + +def _camelcase_to_underscore(str): + return _c2u.sub(r'_\1', str).lower().strip('_') + + +def _underscore_to_camelcase(str): + return ''.join([x[:1].upper() + x[1:] for x in str.split('_')]) + + +def _underscore_to_xmlcase(str): + res = _underscore_to_camelcase(str) + return res[:1].lower() + res[1:] + + +class APIRequest(object): + def __init__(self, controller, action): + self.controller = controller + self.action = action + + def send(self, context, **kwargs): + try: + method = getattr(self.controller, + _camelcase_to_underscore(self.action)) + except AttributeError: + _error = ('Unsupported API request: controller = %s,' + 'action = %s') % (self.controller, self.action) + _log.warning(_error) + # TODO: Raise custom exception, trap in apiserver, + # and reraise as 400 error. + raise Exception(_error) + + args = {} + for key, value in kwargs.items(): + parts = key.split(".") + key = _camelcase_to_underscore(parts[0]) + if len(parts) > 1: + d = args.get(key, {}) + d[parts[1]] = value + value = d + args[key] = value + + for key in args.keys(): + if isinstance(args[key], dict): + if args[key] != {} and args[key].keys()[0].isdigit(): + s = args[key].items() + s.sort() + args[key] = [v for k, v in s] + + result = method(context, **args) + return self._render_response(result, context.request_id) + + def _render_response(self, response_data, request_id): + xml = minidom.Document() + + response_el = xml.createElement(self.action + 'Response') + response_el.setAttribute('xmlns', + 'http://ec2.amazonaws.com/doc/2009-11-30/') + request_id_el = xml.createElement('requestId') + request_id_el.appendChild(xml.createTextNode(request_id)) + response_el.appendChild(request_id_el) + if(response_data == True): + self._render_dict(xml, response_el, {'return': 'true'}) + else: + self._render_dict(xml, response_el, response_data) + + xml.appendChild(response_el) + + response = xml.toxml() + xml.unlink() + _log.debug(response) + return response + + def _render_dict(self, xml, el, data): + try: + for key in data.keys(): + val = data[key] + el.appendChild(self._render_data(xml, key, val)) + except: + _log.debug(data) + raise + + def _render_data(self, xml, el_name, data): + el_name = _underscore_to_xmlcase(el_name) + data_el = xml.createElement(el_name) + + if isinstance(data, list): + for item in data: + data_el.appendChild(self._render_data(xml, 'item', item)) + elif isinstance(data, dict): + self._render_dict(xml, data_el, data) + elif hasattr(data, '__dict__'): + self._render_dict(xml, data_el, data.__dict__) + elif isinstance(data, bool): + data_el.appendChild(xml.createTextNode(str(data).lower())) + elif data != None: + data_el.appendChild(xml.createTextNode(str(data))) + + return data_el diff --git a/nova/endpoint/cloud.py b/nova/api/ec2/cloud.py index 339a67b0e..1f01731ae 100644 --- a/nova/endpoint/cloud.py +++ b/nova/api/ec2/cloud.py @@ -28,18 +28,16 @@ import logging import os import time -from twisted.internet import defer - +from nova import crypto from nova import db from nova import exception from nova import flags from nova import quota from nova import rpc from nova import utils -from nova.auth import rbac -from nova.auth import manager from nova.compute.instance_types import INSTANCE_TYPES -from nova.endpoint import images +from nova.api import cloud +from nova.api.ec2 import images FLAGS = flags.FLAGS @@ -51,13 +49,26 @@ class QuotaError(exception.ApiError): pass -def _gen_key(user_id, key_name): - """ Tuck this into AuthManager """ +def _gen_key(context, user_id, key_name): + """Generate a key + + This is a module level method because it is slow and we need to defer + it into a process pool.""" + # NOTE(vish): generating key pair is slow so check for legal + # creation before creating key_pair try: - mgr = manager.AuthManager() - private_key, fingerprint = mgr.generate_key_pair(user_id, key_name) - except Exception as ex: - return {'exception': ex} + db.key_pair_get(context, user_id, key_name) + raise exception.Duplicate("The key_pair %s already exists" + % key_name) + except exception.NotFound: + pass + private_key, public_key, fingerprint = crypto.generate_key_pair() + key = {} + key['user_id'] = user_id + key['name'] = key_name + key['public_key'] = public_key + key['fingerprint'] = fingerprint + db.key_pair_create(context, key) return {'private_key': private_key, 'fingerprint': fingerprint} @@ -91,14 +102,15 @@ class CloudController(object): def _get_mpi_data(self, project_id): result = {} - for instance in db.instance_get_by_project(None, project_id): + for instance in db.instance_get_all_by_project(None, project_id): if instance['fixed_ip']: - line = '%s slots=%d' % (instance['fixed_ip']['str_id'], + line = '%s slots=%d' % (instance['fixed_ip']['address'], INSTANCE_TYPES[instance['instance_type']]['vcpus']) - if instance['key_name'] in result: - result[instance['key_name']].append(line) + key = str(instance['key_name']) + if key in result: + result[key].append(line) else: - result[instance['key_name']] = [line] + result[key] = [line] return result def get_metadata(self, address): @@ -132,7 +144,7 @@ class CloudController(object): }, 'hostname': hostname, 'instance-action': 'none', - 'instance-id': instance_ref['str_id'], + 'instance-id': instance_ref['ec2_id'], 'instance-type': instance_ref['instance_type'], 'local-hostname': hostname, 'local-ipv4': address, @@ -155,18 +167,24 @@ class CloudController(object): data['product-codes'] = [] return data - @rbac.allow('all') def describe_availability_zones(self, context, **kwargs): return {'availabilityZoneInfo': [{'zoneName': 'nova', 'zoneState': 'available'}]} - @rbac.allow('all') def describe_regions(self, context, region_name=None, **kwargs): - # TODO(vish): region_name is an array. Support filtering - return {'regionInfo': [{'regionName': 'nova', - 'regionUrl': FLAGS.ec2_url}]} + if FLAGS.region_list: + regions = [] + for region in FLAGS.region_list: + name, _sep, url = region.partition('=') + regions.append({'regionName': name, + 'regionEndpoint': url}) + else: + regions = [{'regionName': 'nova', + 'regionEndpoint': FLAGS.ec2_url}] + if region_name: + regions = [r for r in regions if r['regionName'] in region_name] + return {'regionInfo': regions } - @rbac.allow('all') def describe_snapshots(self, context, snapshot_id=None, @@ -182,64 +200,53 @@ class CloudController(object): 'volumeSize': 0, 'description': 'fixme'}]} - @rbac.allow('all') def describe_key_pairs(self, context, key_name=None, **kwargs): - key_pairs = context.user.get_key_pairs() + key_pairs = db.key_pair_get_all_by_user(context, context.user.id) if not key_name is None: - key_pairs = [x for x in key_pairs if x.name in key_name] + key_pairs = [x for x in key_pairs if x['name'] in key_name] result = [] for key_pair in key_pairs: # filter out the vpn keys suffix = FLAGS.vpn_key_suffix - if context.user.is_admin() or not key_pair.name.endswith(suffix): + if context.user.is_admin() or not key_pair['name'].endswith(suffix): result.append({ - 'keyName': key_pair.name, - 'keyFingerprint': key_pair.fingerprint, + 'keyName': key_pair['name'], + 'keyFingerprint': key_pair['fingerprint'], }) return {'keypairsSet': result} - @rbac.allow('all') def create_key_pair(self, context, key_name, **kwargs): - dcall = defer.Deferred() - pool = context.handler.application.settings.get('pool') - def _complete(kwargs): - if 'exception' in kwargs: - dcall.errback(kwargs['exception']) - return - dcall.callback({'keyName': key_name, - 'keyFingerprint': kwargs['fingerprint'], - 'keyMaterial': kwargs['private_key']}) - pool.apply_async(_gen_key, [context.user.id, key_name], - callback=_complete) - return dcall - - @rbac.allow('all') + data = _gen_key(None, context.user.id, key_name) + return {'keyName': key_name, + 'keyFingerprint': data['fingerprint'], + 'keyMaterial': data['private_key']} + # TODO(vish): when context is no longer an object, pass it here + def delete_key_pair(self, context, key_name, **kwargs): - context.user.delete_key_pair(key_name) - # aws returns true even if the key doens't exist + try: + db.key_pair_destroy(context, context.user.id, key_name) + except exception.NotFound: + # aws returns true even if the key doesn't exist + pass return True - @rbac.allow('all') def describe_security_groups(self, context, group_names, **kwargs): groups = {'securityGroupSet': []} # Stubbed for now to unblock other things. return groups - @rbac.allow('netadmin') def create_security_group(self, context, group_name, **kwargs): return True - @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): return True - @rbac.allow('projectmanager', 'sysadmin') def get_console_output(self, context, instance_id, **kwargs): # instance_id is passed in as a list of instances - instance_ref = db.instance_get_by_str(context, instance_id[0]) + instance_ref = db.instance_get_by_ec2_id(context, instance_id[0]) d = rpc.call('%s.%s' % (FLAGS.compute_topic, instance_ref['host']), { "method" : "get_console_output", @@ -251,12 +258,11 @@ class CloudController(object): "output": base64.b64encode(output)}) return d - @rbac.allow('projectmanager', 'sysadmin') def describe_volumes(self, context, **kwargs): if context.user.is_admin(): volumes = db.volume_get_all(context) else: - volumes = db.volume_get_by_project(context, context.project.id) + volumes = db.volume_get_all_by_project(context, context.project.id) volumes = [self._format_volume(context, v) for v in volumes] @@ -264,7 +270,7 @@ class CloudController(object): def _format_volume(self, context, volume): v = {} - v['volumeId'] = volume['str_id'] + v['volumeId'] = volume['ec2_id'] v['status'] = volume['status'] v['size'] = volume['size'] v['availabilityZone'] = volume['availability_zone'] @@ -282,15 +288,16 @@ class CloudController(object): 'device': volume['mountpoint'], 'instanceId': volume['instance_id'], 'status': 'attached', - 'volume_id': volume['str_id']}] + 'volume_id': volume['ec2_id']}] else: v['attachmentSet'] = [{}] + + v['display_name'] = volume['display_name'] + v['display_description'] = volume['display_description'] return v - @rbac.allow('projectmanager', 'sysadmin') def create_volume(self, context, size, **kwargs): # check quota - size = int(size) if quota.allowed_volumes(context, 1, size) < 1: logging.warn("Quota exceeeded for %s, tried to create %sG volume", context.project.id, size) @@ -304,6 +311,8 @@ class CloudController(object): vol['availability_zone'] = FLAGS.storage_availability_zone vol['status'] = "creating" vol['attach_status'] = "detached" + vol['display_name'] = kwargs.get('display_name') + vol['display_description'] = kwargs.get('display_description') volume_ref = db.volume_create(context, vol) rpc.cast(FLAGS.scheduler_topic, @@ -315,15 +324,14 @@ class CloudController(object): return {'volumeSet': [self._format_volume(context, volume_ref)]} - @rbac.allow('projectmanager', 'sysadmin') def attach_volume(self, context, volume_id, instance_id, device, **kwargs): - volume_ref = db.volume_get_by_str(context, volume_id) + volume_ref = db.volume_get_by_ec2_id(context, volume_id) # TODO(vish): abstract status checking? if volume_ref['status'] != "available": raise exception.ApiError("Volume status must be available") if volume_ref['attach_status'] == "attached": raise exception.ApiError("Volume is already attached") - instance_ref = db.instance_get_by_str(context, instance_id) + instance_ref = db.instance_get_by_ec2_id(context, instance_id) host = instance_ref['host'] rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), {"method": "attach_volume", @@ -331,16 +339,15 @@ class CloudController(object): "volume_id": volume_ref['id'], "instance_id": instance_ref['id'], "mountpoint": device}}) - return defer.succeed({'attachTime': volume_ref['attach_time'], - 'device': volume_ref['mountpoint'], - 'instanceId': instance_ref['id'], - 'requestId': context.request_id, - 'status': volume_ref['attach_status'], - 'volumeId': volume_ref['id']}) - - @rbac.allow('projectmanager', 'sysadmin') + return {'attachTime': volume_ref['attach_time'], + 'device': volume_ref['mountpoint'], + 'instanceId': instance_ref['id'], + 'requestId': context.request_id, + 'status': volume_ref['attach_status'], + 'volumeId': volume_ref['id']} + def detach_volume(self, context, volume_id, **kwargs): - volume_ref = db.volume_get_by_str(context, volume_id) + volume_ref = db.volume_get_by_ec2_id(context, volume_id) instance_ref = db.volume_get_instance(context, volume_ref['id']) if not instance_ref: raise exception.ApiError("Volume isn't attached to anything!") @@ -358,12 +365,12 @@ class CloudController(object): # If the instance doesn't exist anymore, # then we need to call detach blind db.volume_detached(context) - return defer.succeed({'attachTime': volume_ref['attach_time'], - 'device': volume_ref['mountpoint'], - 'instanceId': instance_ref['str_id'], - 'requestId': context.request_id, - 'status': volume_ref['attach_status'], - 'volumeId': volume_ref['id']}) + return {'attachTime': volume_ref['attach_time'], + 'device': volume_ref['mountpoint'], + 'instanceId': instance_ref['ec2_id'], + 'requestId': context.request_id, + 'status': volume_ref['attach_status'], + 'volumeId': volume_ref['id']} def _convert_to_set(self, lst, label): if lst == None or lst == []: @@ -372,9 +379,18 @@ class CloudController(object): lst = [lst] return [{label: x} for x in lst] - @rbac.allow('all') + def update_volume(self, context, volume_id, **kwargs): + updatable_fields = ['display_name', 'display_description'] + changes = {} + for field in updatable_fields: + if field in kwargs: + changes[field] = kwargs[field] + if changes: + db.volume_update(context, volume_id, kwargs) + return True + def describe_instances(self, context, **kwargs): - return defer.succeed(self._format_describe_instances(context)) + return self._format_describe_instances(context) def _format_describe_instances(self, context): return { 'reservationSet': self._format_instances(context) } @@ -387,20 +403,20 @@ class CloudController(object): def _format_instances(self, context, reservation_id=None): reservations = {} if reservation_id: - instances = db.instance_get_by_reservation(context, - reservation_id) + instances = db.instance_get_all_by_reservation(context, + reservation_id) else: if context.user.is_admin(): instances = db.instance_get_all(context) else: - instances = db.instance_get_by_project(context, - context.project.id) + instances = db.instance_get_all_by_project(context, + context.project.id) for instance in instances: if not context.user.is_admin(): if instance['image_id'] == FLAGS.vpn_image_id: continue i = {} - i['instanceId'] = instance['str_id'] + i['instanceId'] = instance['ec2_id'] i['imageId'] = instance['image_id'] i['instanceState'] = { 'code': instance['state'], @@ -409,10 +425,10 @@ class CloudController(object): fixed_addr = None floating_addr = None if instance['fixed_ip']: - fixed_addr = instance['fixed_ip']['str_id'] + fixed_addr = instance['fixed_ip']['address'] if instance['fixed_ip']['floating_ips']: fixed = instance['fixed_ip'] - floating_addr = fixed['floating_ips'][0]['str_id'] + floating_addr = fixed['floating_ips'][0]['address'] i['privateDnsName'] = fixed_addr i['publicDnsName'] = floating_addr i['dnsName'] = i['publicDnsName'] or i['privateDnsName'] @@ -425,6 +441,8 @@ class CloudController(object): i['instanceType'] = instance['instance_type'] i['launchTime'] = instance['created_at'] i['amiLaunchIndex'] = instance['launch_index'] + i['displayName'] = instance['display_name'] + i['displayDescription'] = instance['display_description'] if not reservations.has_key(instance['reservation_id']): r = {} r['reservationId'] = instance['reservation_id'] @@ -436,7 +454,6 @@ class CloudController(object): return list(reservations.values()) - @rbac.allow('all') def describe_addresses(self, context, **kwargs): return self.format_addresses(context) @@ -445,14 +462,14 @@ class CloudController(object): if context.user.is_admin(): iterator = db.floating_ip_get_all(context) else: - iterator = db.floating_ip_get_by_project(context, - context.project.id) + iterator = db.floating_ip_get_all_by_project(context, + context.project.id) for floating_ip_ref in iterator: - address = floating_ip_ref['str_id'] + address = floating_ip_ref['address'] instance_id = None if (floating_ip_ref['fixed_ip'] and floating_ip_ref['fixed_ip']['instance']): - instance_id = floating_ip_ref['fixed_ip']['instance']['str_id'] + instance_id = floating_ip_ref['fixed_ip']['instance']['ec2_id'] address_rv = {'public_ip': address, 'instance_id': instance_id} if context.user.is_admin(): @@ -462,8 +479,6 @@ class CloudController(object): addresses.append(address_rv) return {'addressesSet': addresses} - @rbac.allow('netadmin') - @defer.inlineCallbacks def allocate_address(self, context, **kwargs): # check quota if quota.allowed_floating_ips(context, 1) < 1: @@ -471,64 +486,56 @@ class CloudController(object): context.project.id) raise QuotaError("Address quota exceeded. You cannot " "allocate any more addresses") - network_topic = yield self._get_network_topic(context) - public_ip = yield rpc.call(network_topic, + network_topic = self._get_network_topic(context) + public_ip = rpc.call(network_topic, {"method": "allocate_floating_ip", "args": {"context": None, "project_id": context.project.id}}) - defer.returnValue({'addressSet': [{'publicIp': public_ip}]}) + return {'addressSet': [{'publicIp': public_ip}]} - @rbac.allow('netadmin') - @defer.inlineCallbacks def release_address(self, context, public_ip, **kwargs): # NOTE(vish): Should we make sure this works? floating_ip_ref = db.floating_ip_get_by_address(context, public_ip) - network_topic = yield self._get_network_topic(context) + network_topic = self._get_network_topic(context) rpc.cast(network_topic, {"method": "deallocate_floating_ip", "args": {"context": None, - "floating_address": floating_ip_ref['str_id']}}) - defer.returnValue({'releaseResponse': ["Address released."]}) + "floating_address": floating_ip_ref['address']}}) + return {'releaseResponse': ["Address released."]} - @rbac.allow('netadmin') - @defer.inlineCallbacks def associate_address(self, context, instance_id, public_ip, **kwargs): - instance_ref = db.instance_get_by_str(context, instance_id) - fixed_ip_ref = db.fixed_ip_get_by_instance(context, instance_ref['id']) + instance_ref = db.instance_get_by_ec2_id(context, instance_id) + fixed_address = db.instance_get_fixed_address(context, + instance_ref['id']) floating_ip_ref = db.floating_ip_get_by_address(context, public_ip) - network_topic = yield self._get_network_topic(context) + network_topic = self._get_network_topic(context) rpc.cast(network_topic, {"method": "associate_floating_ip", "args": {"context": None, - "floating_address": floating_ip_ref['str_id'], - "fixed_address": fixed_ip_ref['str_id']}}) - defer.returnValue({'associateResponse': ["Address associated."]}) + "floating_address": floating_ip_ref['address'], + "fixed_address": fixed_address}}) + return {'associateResponse': ["Address associated."]} - @rbac.allow('netadmin') - @defer.inlineCallbacks def disassociate_address(self, context, public_ip, **kwargs): floating_ip_ref = db.floating_ip_get_by_address(context, public_ip) - network_topic = yield self._get_network_topic(context) + network_topic = self._get_network_topic(context) rpc.cast(network_topic, {"method": "disassociate_floating_ip", "args": {"context": None, - "floating_address": floating_ip_ref['str_id']}}) - defer.returnValue({'disassociateResponse': ["Address disassociated."]}) + "floating_address": floating_ip_ref['address']}}) + return {'disassociateResponse': ["Address disassociated."]} - @defer.inlineCallbacks def _get_network_topic(self, context): """Retrieves the network host for a project""" network_ref = db.project_get_network(context, context.project.id) host = network_ref['host'] if not host: - host = yield rpc.call(FLAGS.network_topic, + host = rpc.call(FLAGS.network_topic, {"method": "set_network_host", "args": {"context": None, "project_id": context.project.id}}) - defer.returnValue(db.queue_get_for(context, FLAGS.network_topic, host)) + return db.queue_get_for(context, FLAGS.network_topic, host) - @rbac.allow('projectmanager', 'sysadmin') - @defer.inlineCallbacks def run_instances(self, context, **kwargs): instance_type = kwargs.get('instance_type', 'm1.small') if instance_type not in INSTANCE_TYPES: @@ -571,11 +578,10 @@ class CloudController(object): launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) key_data = None if kwargs.has_key('key_name'): - key_pair = context.user.get_key_pair(kwargs['key_name']) - if not key_pair: - raise exception.ApiError('Key Pair %s not found' % - kwargs['key_name']) - key_data = key_pair.public_key + key_pair_ref = db.key_pair_get(context, + context.user.id, + kwargs['key_name']) + key_data = key_pair_ref['public_key'] # TODO: Get the real security group of launch in here security_group = "default" @@ -594,6 +600,8 @@ class CloudController(object): base_options['user_data'] = kwargs.get('user_data', '') base_options['security_group'] = security_group base_options['instance_type'] = instance_type + base_options['display_name'] = kwargs.get('display_name') + base_options['display_description'] = kwargs.get('display_description') type_data = INSTANCE_TYPES[instance_type] base_options['memory_mb'] = type_data['memory_mb'] @@ -607,7 +615,7 @@ class CloudController(object): inst = {} inst['mac_address'] = utils.generate_mac() inst['launch_index'] = num - inst['hostname'] = instance_ref['str_id'] + inst['hostname'] = instance_ref['ec2_id'] db.instance_update(context, inst_id, inst) address = self.network_manager.allocate_fixed_ip(context, inst_id, @@ -615,7 +623,7 @@ class CloudController(object): # TODO(vish): This probably should be done in the scheduler # network is setup when host is assigned - network_topic = yield self._get_network_topic(context) + network_topic = self._get_network_topic(context) rpc.call(network_topic, {"method": "setup_fixed_ip", "args": {"context": None, @@ -628,18 +636,15 @@ class CloudController(object): "instance_id": inst_id}}) logging.debug("Casting to scheduler for %s/%s's instance %s" % (context.project.name, context.user.name, inst_id)) - defer.returnValue(self._format_run_instances(context, - reservation_id)) + return self._format_run_instances(context, reservation_id) - @rbac.allow('projectmanager', 'sysadmin') - @defer.inlineCallbacks def terminate_instances(self, context, instance_id, **kwargs): logging.debug("Going to start terminating instances") for id_str in instance_id: logging.debug("Going to try and terminate %s" % id_str) try: - instance_ref = db.instance_get_by_str(context, id_str) + instance_ref = db.instance_get_by_ec2_id(context, id_str) except exception.NotFound: logging.warning("Instance %s was not found during terminate" % id_str) @@ -657,11 +662,11 @@ class CloudController(object): # NOTE(vish): Right now we don't really care if the ip is # disassociated. We may need to worry about # checking this later. Perhaps in the scheduler? - network_topic = yield self._get_network_topic(context) + network_topic = self._get_network_topic(context) rpc.cast(network_topic, {"method": "disassociate_floating_ip", "args": {"context": None, - "address": address}}) + "floating_address": address}}) address = db.instance_get_fixed_address(context, instance_ref['id']) @@ -680,24 +685,29 @@ class CloudController(object): "instance_id": instance_ref['id']}}) else: db.instance_destroy(context, instance_ref['id']) - defer.returnValue(True) + return True - @rbac.allow('projectmanager', 'sysadmin') def reboot_instances(self, context, instance_id, **kwargs): """instance_id is a list of instance ids""" for id_str in instance_id: - instance_ref = db.instance_get_by_str(context, id_str) - host = instance_ref['host'] - rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), - {"method": "reboot_instance", - "args": {"context": None, - "instance_id": instance_ref['id']}}) - return defer.succeed(True) + cloud.reboot(id_str, context=context) + return True + + def update_instance(self, context, instance_id, **kwargs): + updatable_fields = ['display_name', 'display_description'] + changes = {} + for field in updatable_fields: + if field in kwargs: + changes[field] = kwargs[field] + if changes: + db_context = {} + inst = db.instance_get_by_ec2_id(db_context, instance_id) + db.instance_update(db_context, inst['id'], kwargs) + return True - @rbac.allow('projectmanager', 'sysadmin') def delete_volume(self, context, volume_id, **kwargs): # TODO: return error if not authorized - volume_ref = db.volume_get_by_str(context, volume_id) + volume_ref = db.volume_get_by_ec2_id(context, volume_id) if volume_ref['status'] != "available": raise exception.ApiError("Volume status must be available") now = datetime.datetime.utcnow() @@ -707,31 +717,26 @@ class CloudController(object): {"method": "delete_volume", "args": {"context": None, "volume_id": volume_ref['id']}}) - return defer.succeed(True) + return True - @rbac.allow('all') def describe_images(self, context, image_id=None, **kwargs): # The objectstore does its own authorization for describe imageSet = images.list(context, image_id) - return defer.succeed({'imagesSet': imageSet}) + return {'imagesSet': imageSet} - @rbac.allow('projectmanager', 'sysadmin') def deregister_image(self, context, image_id, **kwargs): # FIXME: should the objectstore be doing these authorization checks? images.deregister(context, image_id) - return defer.succeed({'imageId': image_id}) + return {'imageId': image_id} - @rbac.allow('projectmanager', 'sysadmin') def register_image(self, context, image_location=None, **kwargs): # FIXME: should the objectstore be doing these authorization checks? if image_location is None and kwargs.has_key('name'): image_location = kwargs['name'] image_id = images.register(context, image_location) logging.debug("Registered %s as %s" % (image_location, image_id)) + return {'imageId': image_id} - return defer.succeed({'imageId': image_id}) - - @rbac.allow('all') def describe_image_attribute(self, context, image_id, attribute, **kwargs): if attribute != 'launchPermission': raise exception.ApiError('attribute not supported: %s' % attribute) @@ -742,9 +747,8 @@ class CloudController(object): result = {'image_id': image_id, 'launchPermission': []} if image['isPublic']: result['launchPermission'].append({'group': 'all'}) - return defer.succeed(result) + return result - @rbac.allow('projectmanager', 'sysadmin') def modify_image_attribute(self, context, image_id, attribute, operation_type, **kwargs): # TODO(devcamcar): Support users and groups other than 'all'. if attribute != 'launchPermission': @@ -755,5 +759,8 @@ class CloudController(object): raise exception.ApiError('only group "all" is supported') if not operation_type in ['add', 'remove']: raise exception.ApiError('operation_type must be add or remove') - result = images.modify(context, image_id, operation_type) - return defer.succeed(result) + return images.modify(context, image_id, operation_type) + + def update_image(self, context, image_id, **kwargs): + result = images.update(context, image_id, dict(kwargs)) + return result diff --git a/nova/endpoint/images.py b/nova/api/ec2/images.py index 4579cd81a..cb54cdda2 100644 --- a/nova/endpoint/images.py +++ b/nova/api/ec2/images.py @@ -43,6 +43,14 @@ def modify(context, image_id, operation): return True +def update(context, image_id, attributes): + """update an image's attributes / info.json""" + attributes.update({"image_id": image_id}) + conn(context).make_request( + method='POST', + bucket='_images', + query_args=qs(attributes)) + return True def register(context, image_location): """ rpc call to register a new image based from a manifest """ diff --git a/nova/api/ec2/metadatarequesthandler.py b/nova/api/ec2/metadatarequesthandler.py new file mode 100644 index 000000000..08a8040ca --- /dev/null +++ b/nova/api/ec2/metadatarequesthandler.py @@ -0,0 +1,73 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Metadata request handler.""" + +import logging + +import webob.dec +import webob.exc + +from nova.api.ec2 import cloud + + +class MetadataRequestHandler(object): + + """Serve metadata from the EC2 API.""" + + def print_data(self, data): + if isinstance(data, dict): + output = '' + for key in data: + if key == '_name': + continue + output += key + if isinstance(data[key], dict): + if '_name' in data[key]: + output += '=' + str(data[key]['_name']) + else: + output += '/' + output += '\n' + return output[:-1] # cut off last \n + elif isinstance(data, list): + return '\n'.join(data) + else: + return str(data) + + def lookup(self, path, data): + items = path.split('/') + for item in items: + if item: + if not isinstance(data, dict): + return data + if not item in data: + return None + data = data[item] + return data + + @webob.dec.wsgify + def __call__(self, req): + cc = cloud.CloudController() + meta_data = cc.get_metadata(req.remote_addr) + if meta_data is None: + logging.error('Failed to get metadata for ip: %s' % req.remote_addr) + raise webob.exc.HTTPNotFound() + data = self.lookup(req.path_info, meta_data) + if data is None: + raise webob.exc.HTTPNotFound() + return self.print_data(data) diff --git a/nova/api/rackspace/__init__.py b/nova/api/rackspace/__init__.py index b4d666d63..89a4693ad 100644 --- a/nova/api/rackspace/__init__.py +++ b/nova/api/rackspace/__init__.py @@ -26,44 +26,122 @@ import time import routes import webob.dec import webob.exc +import webob from nova import flags +from nova import utils from nova import wsgi +from nova.api.rackspace import faults +from nova.api.rackspace import backup_schedules from nova.api.rackspace import flavors from nova.api.rackspace import images +from nova.api.rackspace import ratelimiting from nova.api.rackspace import servers from nova.api.rackspace import sharedipgroups from nova.auth import manager +FLAGS = flags.FLAGS +flags.DEFINE_string('nova_api_auth', + 'nova.api.rackspace.auth.BasicApiAuthManager', + 'The auth mechanism to use for the Rackspace API implemenation') + class API(wsgi.Middleware): """WSGI entry point for all Rackspace API requests.""" def __init__(self): - app = AuthMiddleware(APIRouter()) + app = AuthMiddleware(RateLimitingMiddleware(APIRouter())) super(API, self).__init__(app) - class AuthMiddleware(wsgi.Middleware): """Authorize the rackspace API request or return an HTTP Forbidden.""" - #TODO(gundlach): isn't this the old Nova API's auth? Should it be replaced - #with correct RS API auth? + def __init__(self, application): + self.auth_driver = utils.import_class(FLAGS.nova_api_auth)() + super(AuthMiddleware, self).__init__(application) + + @webob.dec.wsgify + def __call__(self, req): + if not req.headers.has_key("X-Auth-Token"): + return self.auth_driver.authenticate(req) + + user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"]) + + if not user: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + if not req.environ.has_key('nova.context'): + req.environ['nova.context'] = {} + req.environ['nova.context']['user'] = user + return self.application + +class RateLimitingMiddleware(wsgi.Middleware): + """Rate limit incoming requests according to the OpenStack rate limits.""" + + def __init__(self, application, service_host=None): + """Create a rate limiting middleware that wraps the given application. + + By default, rate counters are stored in memory. If service_host is + specified, the middleware instead relies on the ratelimiting.WSGIApp + at the given host+port to keep rate counters. + """ + super(RateLimitingMiddleware, self).__init__(application) + if not service_host: + #TODO(gundlach): These limits were based on limitations of Cloud + #Servers. We should revisit them in Nova. + self.limiter = ratelimiting.Limiter(limits={ + 'DELETE': (100, ratelimiting.PER_MINUTE), + 'PUT': (10, ratelimiting.PER_MINUTE), + 'POST': (10, ratelimiting.PER_MINUTE), + 'POST servers': (50, ratelimiting.PER_DAY), + 'GET changes-since': (3, ratelimiting.PER_MINUTE), + }) + else: + self.limiter = ratelimiting.WSGIAppProxy(service_host) @webob.dec.wsgify def __call__(self, req): - context = {} - if "HTTP_X_AUTH_TOKEN" in req.environ: - context['user'] = manager.AuthManager().get_user_from_access_key( - req.environ['HTTP_X_AUTH_TOKEN']) - if context['user']: - context['project'] = manager.AuthManager().get_project( - context['user'].name) - if "user" not in context: - return webob.exc.HTTPForbidden() - req.environ['nova.context'] = context + """Rate limit the request. + + If the request should be rate limited, return a 413 status with a + Retry-After header giving the time when the request would succeed. + """ + username = req.headers['X-Auth-User'] + action_name = self.get_action_name(req) + if not action_name: # not rate limited + return self.application + delay = self.get_delay(action_name, username) + if delay: + # TODO(gundlach): Get the retry-after format correct. + exc = webob.exc.HTTPRequestEntityTooLarge( + explanation='Too many requests.', + headers={'Retry-After': time.time() + delay}) + raise faults.Fault(exc) return self.application + def get_delay(self, action_name, username): + """Return the delay for the given action and username, or None if + the action would not be rate limited. + """ + if action_name == 'POST servers': + # "POST servers" is a POST, so it counts against "POST" too. + # Attempt the "POST" first, lest we are rate limited by "POST" but + # use up a precious "POST servers" call. + delay = self.limiter.perform("POST", username=username) + if delay: + return delay + return self.limiter.perform(action_name, username=username) + + def get_action_name(self, req): + """Return the action name for this request.""" + if req.method == 'GET' and 'changes-since' in req.GET: + return 'GET changes-since' + if req.method == 'POST' and req.path_info.startswith('/servers'): + return 'POST servers' + if req.method in ['PUT', 'POST', 'DELETE']: + return req.method + return None + class APIRouter(wsgi.Router): """ @@ -73,11 +151,40 @@ class APIRouter(wsgi.Router): def __init__(self): mapper = routes.Mapper() - mapper.resource("server", "servers", controller=servers.Controller()) + mapper.resource("server", "servers", controller=servers.Controller(), + collection={ 'detail': 'GET'}, + member={'action':'POST'}) + + mapper.resource("backup_schedule", "backup_schedules", + controller=backup_schedules.Controller(), + parent_resource=dict(member_name='server', + collection_name = 'servers')) + mapper.resource("image", "images", controller=images.Controller(), collection={'detail': 'GET'}) mapper.resource("flavor", "flavors", controller=flavors.Controller(), collection={'detail': 'GET'}) mapper.resource("sharedipgroup", "sharedipgroups", controller=sharedipgroups.Controller()) + super(APIRouter, self).__init__(mapper) + + +def limited(items, req): + """Return a slice of items according to requested offset and limit. + + items - a sliceable + req - wobob.Request possibly containing offset and limit GET variables. + offset is where to start in the list, and limit is the maximum number + of items to return. + + If limit is not specified, 0, or > 1000, defaults to 1000. + """ + offset = int(req.GET.get('offset', 0)) + limit = int(req.GET.get('limit', 0)) + if not limit: + limit = 1000 + limit = min(1000, limit) + range_end = offset + limit + return items[offset:range_end] + diff --git a/nova/api/rackspace/_id_translator.py b/nova/api/rackspace/_id_translator.py index aec5fb6a5..333aa8434 100644 --- a/nova/api/rackspace/_id_translator.py +++ b/nova/api/rackspace/_id_translator.py @@ -37,6 +37,6 @@ class RackspaceAPIIdTranslator(object): # every int id be used.) return int(self._store.hget(self._fwd_key, str(opaque_id))) - def from_rs_id(self, strategy_name, rs_id): + def from_rs_id(self, rs_id): """Convert a Rackspace id to a strategy-specific one.""" return self._store.hget(self._rev_key, rs_id) diff --git a/nova/api/rackspace/auth.py b/nova/api/rackspace/auth.py new file mode 100644 index 000000000..c45156ebd --- /dev/null +++ b/nova/api/rackspace/auth.py @@ -0,0 +1,101 @@ +import datetime +import hashlib +import json +import time + +import webob.exc +import webob.dec + +from nova import auth +from nova import db +from nova import flags +from nova import manager +from nova import utils +from nova.api.rackspace import faults + +FLAGS = flags.FLAGS + +class Context(object): + pass + +class BasicApiAuthManager(object): + """ Implements a somewhat rudimentary version of Rackspace Auth""" + + def __init__(self, host=None, db_driver=None): + if not host: + host = FLAGS.host + self.host = host + if not db_driver: + db_driver = FLAGS.db_driver + self.db = utils.import_object(db_driver) + self.auth = auth.manager.AuthManager() + self.context = Context() + super(BasicApiAuthManager, self).__init__() + + def authenticate(self, req): + # Unless the request is explicitly made against /<version>/ don't + # honor it + path_info = req.path_info + if len(path_info) > 1: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + try: + username, key = req.headers['X-Auth-User'], \ + req.headers['X-Auth-Key'] + except KeyError: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + username, key = req.headers['X-Auth-User'], req.headers['X-Auth-Key'] + token, user = self._authorize_user(username, key) + if user and token: + res = webob.Response() + res.headers['X-Auth-Token'] = token['token_hash'] + res.headers['X-Server-Management-Url'] = \ + token['server_management_url'] + res.headers['X-Storage-Url'] = token['storage_url'] + res.headers['X-CDN-Management-Url'] = token['cdn_management_url'] + res.content_type = 'text/plain' + res.status = '204' + return res + else: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + def authorize_token(self, token_hash): + """ retrieves user information from the datastore given a token + + If the token has expired, returns None + If the token is not found, returns None + Otherwise returns the token + + This method will also remove the token if the timestamp is older than + 2 days ago. + """ + token = self.db.auth_get_token(self.context, token_hash) + if token: + delta = datetime.datetime.now() - token['created_at'] + if delta.days >= 2: + self.db.auth_destroy_token(self.context, token) + else: + user = self.auth.get_user(token['user_id']) + return { 'id':user['uid'] } + return None + + def _authorize_user(self, username, key): + """ Generates a new token and assigns it to a user """ + user = self.auth.get_user_from_access_key(key) + if user and user['name'] == username: + token_hash = hashlib.sha1('%s%s%f' % (username, key, + time.time())).hexdigest() + token = {} + token['token_hash'] = token_hash + token['cdn_management_url'] = '' + token['server_management_url'] = self._get_server_mgmt_url() + token['storage_url'] = '' + token['user_id'] = user['uid'] + self.db.auth_create_token(self.context, token) + return token, user + return None, None + + def _get_server_mgmt_url(self): + return 'https://%s/v1.0/' % self.host + diff --git a/nova/api/rackspace/backup_schedules.py b/nova/api/rackspace/backup_schedules.py new file mode 100644 index 000000000..cb83023bc --- /dev/null +++ b/nova/api/rackspace/backup_schedules.py @@ -0,0 +1,39 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import time +from webob import exc + +from nova import wsgi +from nova.api.rackspace import _id_translator +from nova.api.rackspace import faults +import nova.image.service + +class Controller(wsgi.Controller): + def __init__(self): + pass + + def index(self, req, server_id): + return faults.Fault(exc.HTTPNotFound()) + + def create(self, req, server_id): + """ No actual update method required, since the existing API allows + both create and update through a POST """ + return faults.Fault(exc.HTTPNotFound()) + + def delete(self, req, server_id): + return faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/base.py b/nova/api/rackspace/context.py index dd2c6543c..77394615b 100644 --- a/nova/api/rackspace/base.py +++ b/nova/api/rackspace/context.py @@ -15,16 +15,19 @@ # License for the specific language governing permissions and limitations # under the License. -from nova import wsgi +""" +APIRequestContext +""" +import random -class Controller(wsgi.Controller): - """TODO(eday): Base controller for all rackspace controllers. What is this - for? Is this just Rackspace specific? """ - - @classmethod - def render(cls, instance): - if isinstance(instance, list): - return {cls.entity_name: cls.render(instance)} - else: - return {"TODO": "TODO"} +class Project(object): + def __init__(self, user_id): + self.id = user_id + +class APIRequestContext(object): + """ This is an adapter class to get around all of the assumptions made in + the FlatNetworking """ + def __init__(self, user_id): + self.user_id = user_id + self.project = Project(user_id) diff --git a/nova/api/rackspace/faults.py b/nova/api/rackspace/faults.py new file mode 100644 index 000000000..32e5c866f --- /dev/null +++ b/nova/api/rackspace/faults.py @@ -0,0 +1,62 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import webob.dec +import webob.exc + +from nova import wsgi + + +class Fault(webob.exc.HTTPException): + + """An RS API fault response.""" + + _fault_names = { + 400: "badRequest", + 401: "unauthorized", + 403: "resizeNotAllowed", + 404: "itemNotFound", + 405: "badMethod", + 409: "inProgress", + 413: "overLimit", + 415: "badMediaType", + 501: "notImplemented", + 503: "serviceUnavailable"} + + def __init__(self, exception): + """Create a Fault for the given webob.exc.exception.""" + self.wrapped_exc = exception + + @webob.dec.wsgify + def __call__(self, req): + """Generate a WSGI response based on the exception passed to ctor.""" + # Replace the body with fault details. + code = self.wrapped_exc.status_int + fault_name = self._fault_names.get(code, "cloudServersFault") + fault_data = { + fault_name: { + 'code': code, + 'message': self.wrapped_exc.explanation}} + if code == 413: + retry = self.wrapped_exc.headers['Retry-After'] + fault_data[fault_name]['retryAfter'] = retry + # 'code' is an attribute on the fault tag itself + metadata = {'application/xml': {'attributes': {fault_name: 'code'}}} + serializer = wsgi.Serializer(req.environ, metadata) + self.wrapped_exc.body = serializer.to_content_type(fault_data) + return self.wrapped_exc diff --git a/nova/api/rackspace/flavors.py b/nova/api/rackspace/flavors.py index 60b35c939..916449854 100644 --- a/nova/api/rackspace/flavors.py +++ b/nova/api/rackspace/flavors.py @@ -15,11 +15,14 @@ # License for the specific language governing permissions and limitations # under the License. -from nova.api.rackspace import base -from nova.compute import instance_types from webob import exc -class Controller(base.Controller): +from nova.api.rackspace import faults +from nova.compute import instance_types +from nova import wsgi +import nova.api.rackspace + +class Controller(wsgi.Controller): """Flavor controller for the Rackspace API.""" _serialization_metadata = { @@ -38,6 +41,7 @@ class Controller(base.Controller): def detail(self, req): """Return all flavors in detail.""" items = [self.show(req, id)['flavor'] for id in self._all_ids()] + items = nova.api.rackspace.limited(items, req) return dict(flavors=items) def show(self, req, id): @@ -47,7 +51,7 @@ class Controller(base.Controller): item = dict(ram=val['memory_mb'], disk=val['local_gb'], id=val['flavorid'], name=name) return dict(flavor=item) - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def _all_ids(self): """Return the list of all flavorids.""" diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py index 2f3e928b9..4a7dd489c 100644 --- a/nova/api/rackspace/images.py +++ b/nova/api/rackspace/images.py @@ -15,12 +15,15 @@ # License for the specific language governing permissions and limitations # under the License. -import nova.image.service -from nova.api.rackspace import base -from nova.api.rackspace import _id_translator from webob import exc -class Controller(base.Controller): +from nova import wsgi +from nova.api.rackspace import _id_translator +import nova.api.rackspace +import nova.image.service +from nova.api.rackspace import faults + +class Controller(wsgi.Controller): _serialization_metadata = { 'application/xml': { @@ -44,6 +47,7 @@ class Controller(base.Controller): def detail(self, req): """Return all public images in detail.""" data = self._service.index() + data = nova.api.rackspace.limited(data, req) for img in data: img['id'] = self._id_translator.to_rs_id(img['id']) return dict(images=data) @@ -57,14 +61,14 @@ class Controller(base.Controller): def delete(self, req, id): # Only public images are supported for now. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def create(self, req): # Only public images are supported for now, so a request to # make a backup of a server cannot be supproted. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) def update(self, req, id): # Users may not modify public images, and that's all that # we support for now. - raise exc.HTTPNotFound() + raise faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/ratelimiting/__init__.py b/nova/api/rackspace/ratelimiting/__init__.py new file mode 100644 index 000000000..f843bac0f --- /dev/null +++ b/nova/api/rackspace/ratelimiting/__init__.py @@ -0,0 +1,122 @@ +"""Rate limiting of arbitrary actions.""" + +import httplib +import time +import urllib +import webob.dec +import webob.exc + + +# Convenience constants for the limits dictionary passed to Limiter(). +PER_SECOND = 1 +PER_MINUTE = 60 +PER_HOUR = 60 * 60 +PER_DAY = 60 * 60 * 24 + +class Limiter(object): + + """Class providing rate limiting of arbitrary actions.""" + + def __init__(self, limits): + """Create a rate limiter. + + limits: a dict mapping from action name to a tuple. The tuple contains + the number of times the action may be performed, and the time period + (in seconds) during which the number must not be exceeded for this + action. Example: dict(reboot=(10, ratelimiting.PER_MINUTE)) would + allow 10 'reboot' actions per minute. + """ + self.limits = limits + self._levels = {} + + def perform(self, action_name, username='nobody'): + """Attempt to perform an action by the given username. + + action_name: the string name of the action to perform. This must + be a key in the limits dict passed to the ctor. + + username: an optional string name of the user performing the action. + Each user has her own set of rate limiting counters. Defaults to + 'nobody' (so that if you never specify a username when calling + perform(), a single set of counters will be used.) + + Return None if the action may proceed. If the action may not proceed + because it has been rate limited, return the float number of seconds + until the action would succeed. + """ + # Think of rate limiting as a bucket leaking water at 1cc/second. The + # bucket can hold as many ccs as there are seconds in the rate + # limiting period (e.g. 3600 for per-hour ratelimits), and if you can + # perform N actions in that time, each action fills the bucket by + # 1/Nth of its volume. You may only perform an action if the bucket + # would not overflow. + now = time.time() + key = '%s:%s' % (username, action_name) + last_time_performed, water_level = self._levels.get(key, (now, 0)) + # The bucket leaks 1cc/second. + water_level -= (now - last_time_performed) + if water_level < 0: + water_level = 0 + num_allowed_per_period, period_in_secs = self.limits[action_name] + # Fill the bucket by 1/Nth its capacity, and hope it doesn't overflow. + capacity = period_in_secs + new_level = water_level + (capacity * 1.0 / num_allowed_per_period) + if new_level > capacity: + # Delay this many seconds. + return new_level - capacity + self._levels[key] = (now, new_level) + return None + + +# If one instance of this WSGIApps is unable to handle your load, put a +# sharding app in front that shards by username to one of many backends. + +class WSGIApp(object): + + """Application that tracks rate limits in memory. Send requests to it of + this form: + + POST /limiter/<username>/<urlencoded action> + + and receive a 200 OK, or a 403 Forbidden with an X-Wait-Seconds header + containing the number of seconds to wait before the action would succeed. + """ + + def __init__(self, limiter): + """Create the WSGI application using the given Limiter instance.""" + self.limiter = limiter + + @webob.dec.wsgify + def __call__(self, req): + parts = req.path_info.split('/') + # format: /limiter/<username>/<urlencoded action> + if req.method != 'POST': + raise webob.exc.HTTPMethodNotAllowed() + if len(parts) != 4 or parts[1] != 'limiter': + raise webob.exc.HTTPNotFound() + username = parts[2] + action_name = urllib.unquote(parts[3]) + delay = self.limiter.perform(action_name, username) + if delay: + return webob.exc.HTTPForbidden( + headers={'X-Wait-Seconds': "%.2f" % delay}) + else: + return '' # 200 OK + + +class WSGIAppProxy(object): + + """Limiter lookalike that proxies to a ratelimiting.WSGIApp.""" + + def __init__(self, service_host): + """Creates a proxy pointing to a ratelimiting.WSGIApp at the given + host.""" + self.service_host = service_host + + def perform(self, action, username='nobody'): + conn = httplib.HTTPConnection(self.service_host) + conn.request('POST', '/limiter/%s/%s' % (username, action)) + resp = conn.getresponse() + if resp.status == 200: + return None # no delay + return float(resp.getheader('X-Wait-Seconds')) diff --git a/nova/api/rackspace/ratelimiting/tests.py b/nova/api/rackspace/ratelimiting/tests.py new file mode 100644 index 000000000..4c9510917 --- /dev/null +++ b/nova/api/rackspace/ratelimiting/tests.py @@ -0,0 +1,237 @@ +import httplib +import StringIO +import time +import unittest +import webob + +import nova.api.rackspace.ratelimiting as ratelimiting + +class LimiterTest(unittest.TestCase): + + def setUp(self): + self.limits = { + 'a': (5, ratelimiting.PER_SECOND), + 'b': (5, ratelimiting.PER_MINUTE), + 'c': (5, ratelimiting.PER_HOUR), + 'd': (1, ratelimiting.PER_SECOND), + 'e': (100, ratelimiting.PER_SECOND)} + self.rl = ratelimiting.Limiter(self.limits) + + def exhaust(self, action, times_until_exhausted, **kwargs): + for i in range(times_until_exhausted): + when = self.rl.perform(action, **kwargs) + self.assertEqual(when, None) + num, period = self.limits[action] + delay = period * 1.0 / num + # Verify that we are now thoroughly delayed + for i in range(10): + when = self.rl.perform(action, **kwargs) + self.assertAlmostEqual(when, delay, 2) + + def test_second(self): + self.exhaust('a', 5) + time.sleep(0.2) + self.exhaust('a', 1) + time.sleep(1) + self.exhaust('a', 5) + + def test_minute(self): + self.exhaust('b', 5) + + def test_one_per_period(self): + def allow_once_and_deny_once(): + when = self.rl.perform('d') + self.assertEqual(when, None) + when = self.rl.perform('d') + self.assertAlmostEqual(when, 1, 2) + return when + time.sleep(allow_once_and_deny_once()) + time.sleep(allow_once_and_deny_once()) + allow_once_and_deny_once() + + def test_we_can_go_indefinitely_if_we_spread_out_requests(self): + for i in range(200): + when = self.rl.perform('e') + self.assertEqual(when, None) + time.sleep(0.01) + + def test_users_get_separate_buckets(self): + self.exhaust('c', 5, username='alice') + self.exhaust('c', 5, username='bob') + self.exhaust('c', 5, username='chuck') + self.exhaust('c', 0, username='chuck') + self.exhaust('c', 0, username='bob') + self.exhaust('c', 0, username='alice') + + +class FakeLimiter(object): + """Fake Limiter class that you can tell how to behave.""" + def __init__(self, test): + self._action = self._username = self._delay = None + self.test = test + def mock(self, action, username, delay): + self._action = action + self._username = username + self._delay = delay + def perform(self, action, username): + self.test.assertEqual(action, self._action) + self.test.assertEqual(username, self._username) + return self._delay + + +class WSGIAppTest(unittest.TestCase): + + def setUp(self): + self.limiter = FakeLimiter(self) + self.app = ratelimiting.WSGIApp(self.limiter) + + def test_invalid_methods(self): + requests = [] + for method in ['GET', 'PUT', 'DELETE']: + req = webob.Request.blank('/limits/michael/breakdance', + dict(REQUEST_METHOD=method)) + requests.append(req) + for req in requests: + self.assertEqual(req.get_response(self.app).status_int, 405) + + def test_invalid_urls(self): + requests = [] + for prefix in ['limit', '', 'limiter2', 'limiter/limits', 'limiter/1']: + req = webob.Request.blank('/%s/michael/breakdance' % prefix, + dict(REQUEST_METHOD='POST')) + requests.append(req) + for req in requests: + self.assertEqual(req.get_response(self.app).status_int, 404) + + def verify(self, url, username, action, delay=None): + """Make sure that POSTing to the given url causes the given username + to perform the given action. Make the internal rate limiter return + delay and make sure that the WSGI app returns the correct response. + """ + req = webob.Request.blank(url, dict(REQUEST_METHOD='POST')) + self.limiter.mock(action, username, delay) + resp = req.get_response(self.app) + if not delay: + self.assertEqual(resp.status_int, 200) + else: + self.assertEqual(resp.status_int, 403) + self.assertEqual(resp.headers['X-Wait-Seconds'], "%.2f" % delay) + + def test_good_urls(self): + self.verify('/limiter/michael/hoot', 'michael', 'hoot') + + def test_escaping(self): + self.verify('/limiter/michael/jump%20up', 'michael', 'jump up') + + def test_response_to_delays(self): + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1) + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1.56) + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1000) + + +class FakeHttplibSocket(object): + """a fake socket implementation for httplib.HTTPResponse, trivial""" + + def __init__(self, response_string): + self._buffer = StringIO.StringIO(response_string) + + def makefile(self, _mode, _other): + """Returns the socket's internal buffer""" + return self._buffer + + +class FakeHttplibConnection(object): + """A fake httplib.HTTPConnection + + Requests made via this connection actually get translated and routed into + our WSGI app, we then wait for the response and turn it back into + an httplib.HTTPResponse. + """ + def __init__(self, app, host, is_secure=False): + self.app = app + self.host = host + + def request(self, method, path, data='', headers={}): + req = webob.Request.blank(path) + req.method = method + req.body = data + req.headers = headers + req.host = self.host + # Call the WSGI app, get the HTTP response + resp = str(req.get_response(self.app)) + # For some reason, the response doesn't have "HTTP/1.0 " prepended; I + # guess that's a function the web server usually provides. + resp = "HTTP/1.0 %s" % resp + sock = FakeHttplibSocket(resp) + self.http_response = httplib.HTTPResponse(sock) + self.http_response.begin() + + def getresponse(self): + return self.http_response + + +def wire_HTTPConnection_to_WSGI(host, app): + """Monkeypatches HTTPConnection so that if you try to connect to host, you + are instead routed straight to the given WSGI app. + + After calling this method, when any code calls + + httplib.HTTPConnection(host) + + the connection object will be a fake. Its requests will be sent directly + to the given WSGI app rather than through a socket. + + Code connecting to hosts other than host will not be affected. + + This method may be called multiple times to map different hosts to + different apps. + """ + class HTTPConnectionDecorator(object): + """Wraps the real HTTPConnection class so that when you instantiate + the class you might instead get a fake instance.""" + def __init__(self, wrapped): + self.wrapped = wrapped + def __call__(self, connection_host, *args, **kwargs): + if connection_host == host: + return FakeHttplibConnection(app, host) + else: + return self.wrapped(connection_host, *args, **kwargs) + httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) + + +class WSGIAppProxyTest(unittest.TestCase): + + def setUp(self): + """Our WSGIAppProxy is going to call across an HTTPConnection to a + WSGIApp running a limiter. The proxy will send input, and the proxy + should receive that same input, pass it to the limiter who gives a + result, and send the expected result back. + + The HTTPConnection isn't real -- it's monkeypatched to point straight + at the WSGIApp. And the limiter isn't real -- it's a fake that + behaves the way we tell it to. + """ + self.limiter = FakeLimiter(self) + app = ratelimiting.WSGIApp(self.limiter) + wire_HTTPConnection_to_WSGI('100.100.100.100:80', app) + self.proxy = ratelimiting.WSGIAppProxy('100.100.100.100:80') + + def test_200(self): + self.limiter.mock('conquer', 'caesar', None) + when = self.proxy.perform('conquer', 'caesar') + self.assertEqual(when, None) + + def test_403(self): + self.limiter.mock('grumble', 'proletariat', 1.5) + when = self.proxy.perform('grumble', 'proletariat') + self.assertEqual(when, 1.5) + + def test_failure(self): + def shouldRaise(): + self.limiter.mock('murder', 'brutus', None) + self.proxy.perform('stab', 'brutus') + self.assertRaises(AssertionError, shouldRaise) + + +if __name__ == '__main__': + unittest.main() diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index 1815f7523..11efd8aef 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -14,67 +14,286 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + import time -from nova import db +import webob +from webob import exc + from nova import flags from nova import rpc from nova import utils -from nova.api.rackspace import base +from nova import wsgi +from nova.api import cloud +from nova.api.rackspace import _id_translator +from nova.api.rackspace import context +from nova.api.rackspace import faults +from nova.compute import instance_types +from nova.compute import power_state +import nova.api.rackspace +import nova.image.service FLAGS = flags.FLAGS -class Controller(base.Controller): - entity_name = 'servers' +flags.DEFINE_string('rs_network_manager', 'nova.network.manager.FlatManager', + 'Networking for rackspace') + +def _instance_id_translator(): + """ Helper method for initializing an id translator for Rackspace instance + ids """ + return _id_translator.RackspaceAPIIdTranslator( "instance", 'nova') + +def _image_service(): + """ Helper method for initializing the image id translator """ + service = nova.image.service.ImageService.load() + return (service, _id_translator.RackspaceAPIIdTranslator( + "image", service.__class__.__name__)) + +def _filter_params(inst_dict): + """ Extracts all updatable parameters for a server update request """ + keys = dict(name='name', admin_pass='adminPass') + new_attrs = {} + for k, v in keys.items(): + if inst_dict.has_key(v): + new_attrs[k] = inst_dict[v] + return new_attrs + +def _entity_list(entities): + """ Coerces a list of servers into proper dictionary format """ + return dict(servers=entities) + +def _entity_detail(inst): + """ Maps everything to Rackspace-like attributes for return""" + power_mapping = { + power_state.NOSTATE: 'build', + power_state.RUNNING: 'active', + power_state.BLOCKED: 'active', + power_state.PAUSED: 'suspended', + power_state.SHUTDOWN: 'active', + power_state.SHUTOFF: 'active', + power_state.CRASHED: 'error' + } + inst_dict = {} + + mapped_keys = dict(status='state', imageId='image_id', + flavorId='instance_type', name='server_name', id='id') + + for k, v in mapped_keys.iteritems(): + inst_dict[k] = inst[v] + + inst_dict['status'] = power_mapping[inst_dict['status']] + inst_dict['addresses'] = dict(public=[], private=[]) + inst_dict['metadata'] = {} + inst_dict['hostId'] = '' + + return dict(server=inst_dict) + +def _entity_inst(inst): + """ Filters all model attributes save for id and name """ + return dict(server=dict(id=inst['id'], name=inst['server_name'])) - def index(self, **kwargs): - instances = [] - for inst in db.instance_get_all(None): - instances.append(instance_details(inst)) +class Controller(wsgi.Controller): + """ The Server API controller for the Openstack API """ - def show(self, **kwargs): - instance_id = kwargs['id'] - return db.instance_get(None, instance_id) + _serialization_metadata = { + 'application/xml': { + "attributes": { + "server": [ "id", "imageId", "name", "flavorId", "hostId", + "status", "progress", "progress" ] + } + } + } - def delete(self, **kwargs): - instance_id = kwargs['id'] - instance = db.instance_get(None, instance_id) - if not instance: - raise ServerNotFound("The requested server was not found") - instance.destroy() - return True + def __init__(self, db_driver=None): + if not db_driver: + db_driver = FLAGS.db_driver + self.db_driver = utils.import_object(db_driver) + super(Controller, self).__init__() + + def index(self, req): + """ Returns a list of server names and ids for a given user """ + return self._items(req, entity_maker=_entity_inst) + + def detail(self, req): + """ Returns a list of server details for a given user """ + return self._items(req, entity_maker=_entity_detail) + + def _items(self, req, entity_maker): + """Returns a list of servers for a given user. + + entity_maker - either _entity_detail or _entity_inst + """ + user_id = req.environ['nova.context']['user']['id'] + instance_list = self.db_driver.instance_get_all_by_user(None, user_id) + limited_list = nova.api.rackspace.limited(instance_list, req) + res = [entity_maker(inst)['server'] for inst in limited_list] + return _entity_list(res) + + def show(self, req, id): + """ Returns server details by server id """ + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + + user_id = req.environ['nova.context']['user']['id'] + inst = self.db_driver.instance_get_by_ec2_id(None, inst_id) + if inst: + if inst.user_id == user_id: + return _entity_detail(inst) + raise faults.Fault(exc.HTTPNotFound()) + + def delete(self, req, id): + """ Destroys a server """ + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + + user_id = req.environ['nova.context']['user']['id'] + instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) + if instance and instance['user_id'] == user_id: + self.db_driver.instance_destroy(None, id) + return faults.Fault(exc.HTTPAccepted()) + return faults.Fault(exc.HTTPNotFound()) + + def create(self, req): + """ Creates a new server for a given user """ + + env = self._deserialize(req.body, req) + if not env: + return faults.Fault(exc.HTTPUnprocessableEntity()) + + try: + inst = self._build_server_instance(req, env) + except Exception, e: + return faults.Fault(exc.HTTPUnprocessableEntity()) - def create(self, **kwargs): - inst = self.build_server_instance(kwargs['server']) rpc.cast( FLAGS.compute_topic, { "method": "run_instance", "args": {"instance_id": inst['id']}}) + return _entity_inst(inst) + + def update(self, req, id): + """ Updates the server name or password """ + inst_id_trans = _instance_id_translator() + inst_id = inst_id_trans.from_rs_id(id) + user_id = req.environ['nova.context']['user']['id'] - def update(self, **kwargs): - instance_id = kwargs['id'] - instance = db.instance_get(None, instance_id) - if not instance: - raise ServerNotFound("The requested server was not found") - instance.update(kwargs['server']) - instance.save() + inst_dict = self._deserialize(req.body, req) + + if not inst_dict: + return faults.Fault(exc.HTTPUnprocessableEntity()) - def build_server_instance(self, env): + instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) + if not instance or instance.user_id != user_id: + return faults.Fault(exc.HTTPNotFound()) + + self.db_driver.instance_update(None, id, + _filter_params(inst_dict['server'])) + return faults.Fault(exc.HTTPNoContent()) + + def action(self, req, id): + """ multi-purpose method used to reboot, rebuild, and + resize a server """ + input_dict = self._deserialize(req.body, req) + try: + reboot_type = input_dict['reboot']['type'] + except Exception: + raise faults.Fault(webob.exc.HTTPNotImplemented()) + opaque_id = _instance_id_translator().from_rs_id(id) + cloud.reboot(opaque_id) + + def _build_server_instance(self, req, env): """Build instance data structure and save it to the data store.""" - reservation = utils.generate_uid('r') ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) inst = {} - inst['name'] = env['server']['name'] - inst['image_id'] = env['server']['imageId'] - inst['instance_type'] = env['server']['flavorId'] - inst['user_id'] = env['user']['id'] - inst['project_id'] = env['project']['id'] - inst['reservation_id'] = reservation + + inst_id_trans = _instance_id_translator() + + user_id = req.environ['nova.context']['user']['id'] + + flavor_id = env['server']['flavorId'] + + instance_type, flavor = [(k, v) for k, v in + instance_types.INSTANCE_TYPES.iteritems() + if v['flavorid'] == flavor_id][0] + + image_id = env['server']['imageId'] + + img_service, image_id_trans = _image_service() + + opaque_image_id = image_id_trans.to_rs_id(image_id) + image = img_service.show(opaque_image_id) + + if not image: + raise Exception, "Image not found" + + inst['server_name'] = env['server']['name'] + inst['image_id'] = opaque_image_id + inst['user_id'] = user_id inst['launch_time'] = ltime inst['mac_address'] = utils.generate_mac() - inst_id = db.instance_create(None, inst)['id'] - address = self.network_manager.allocate_fixed_ip(None, inst_id) - # key_data, key_name, ami_launch_index - # TODO(todd): key data or root password - inst.save() + inst['project_id'] = user_id + + inst['state_description'] = 'scheduling' + inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel) + inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk) + inst['reservation_id'] = utils.generate_uid('r') + + inst['display_name'] = env['server']['name'] + inst['display_description'] = env['server']['name'] + + #TODO(dietz) this may be ill advised + key_pair_ref = self.db_driver.key_pair_get_all_by_user( + None, user_id)[0] + + inst['key_data'] = key_pair_ref['public_key'] + inst['key_name'] = key_pair_ref['name'] + + #TODO(dietz) stolen from ec2 api, see TODO there + inst['security_group'] = 'default' + + # Flavor related attributes + inst['instance_type'] = instance_type + inst['memory_mb'] = flavor['memory_mb'] + inst['vcpus'] = flavor['vcpus'] + inst['local_gb'] = flavor['local_gb'] + + ref = self.db_driver.instance_create(None, inst) + inst['id'] = inst_id_trans.to_rs_id(ref.ec2_id) + + # TODO(dietz): this isn't explicitly necessary, but the networking + # calls depend on an object with a project_id property, and therefore + # should be cleaned up later + api_context = context.APIRequestContext(user_id) + + inst['mac_address'] = utils.generate_mac() + + #TODO(dietz) is this necessary? + inst['launch_index'] = 0 + + inst['hostname'] = ref.ec2_id + self.db_driver.instance_update(None, inst['id'], inst) + + network_manager = utils.import_object(FLAGS.rs_network_manager) + address = network_manager.allocate_fixed_ip(api_context, + inst['id']) + + # TODO(vish): This probably should be done in the scheduler + # network is setup when host is assigned + network_topic = self._get_network_topic(user_id) + rpc.call(network_topic, + {"method": "setup_fixed_ip", + "args": {"context": None, + "address": address}}) return inst + + def _get_network_topic(self, user_id): + """Retrieves the network host for a project""" + network_ref = self.db_driver.project_get_network(None, + user_id) + host = network_ref['host'] + if not host: + host = rpc.call(FLAGS.network_topic, + {"method": "set_network_host", + "args": {"context": None, + "project_id": user_id}}) + return self.db_driver.queue_get_for(None, FLAGS.network_topic, host) diff --git a/nova/api/rackspace/sharedipgroups.py b/nova/api/rackspace/sharedipgroups.py index 986f11434..4d2d0ede1 100644 --- a/nova/api/rackspace/sharedipgroups.py +++ b/nova/api/rackspace/sharedipgroups.py @@ -15,4 +15,6 @@ # License for the specific language governing permissions and limitations # under the License. -class Controller(object): pass +from nova import wsgi + +class Controller(wsgi.Controller): pass diff --git a/nova/auth/dbdriver.py b/nova/auth/dbdriver.py new file mode 100644 index 000000000..09d15018b --- /dev/null +++ b/nova/auth/dbdriver.py @@ -0,0 +1,236 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Auth driver using the DB as its backend. +""" + +import logging +import sys + +from nova import exception +from nova import db + + +class DbDriver(object): + """DB Auth driver + + Defines enter and exit and therefore supports the with/as syntax. + """ + + def __init__(self): + """Imports the LDAP module""" + pass + db + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def get_user(self, uid): + """Retrieve user by id""" + return self._db_user_to_auth_user(db.user_get({}, uid)) + + def get_user_from_access_key(self, access): + """Retrieve user by access key""" + return self._db_user_to_auth_user(db.user_get_by_access_key({}, access)) + + def get_project(self, pid): + """Retrieve project by id""" + return self._db_project_to_auth_projectuser(db.project_get({}, pid)) + + def get_users(self): + """Retrieve list of users""" + return [self._db_user_to_auth_user(user) for user in db.user_get_all({})] + + def get_projects(self, uid=None): + """Retrieve list of projects""" + if uid: + result = db.project_get_by_user({}, uid) + else: + result = db.project_get_all({}) + return [self._db_project_to_auth_projectuser(proj) for proj in result] + + def create_user(self, name, access_key, secret_key, is_admin): + """Create a user""" + values = { 'id' : name, + 'access_key' : access_key, + 'secret_key' : secret_key, + 'is_admin' : is_admin + } + try: + user_ref = db.user_create({}, values) + return self._db_user_to_auth_user(user_ref) + except exception.Duplicate, e: + raise exception.Duplicate('User %s already exists' % name) + + def _db_user_to_auth_user(self, user_ref): + return { 'id' : user_ref['id'], + 'name' : user_ref['id'], + 'access' : user_ref['access_key'], + 'secret' : user_ref['secret_key'], + 'admin' : user_ref['is_admin'] } + + def _db_project_to_auth_projectuser(self, project_ref): + return { 'id' : project_ref['id'], + 'name' : project_ref['name'], + 'project_manager_id' : project_ref['project_manager'], + 'description' : project_ref['description'], + 'member_ids' : [member['id'] for member in project_ref['members']] } + + def create_project(self, name, manager_uid, + description=None, member_uids=None): + """Create a project""" + manager = db.user_get({}, manager_uid) + if not manager: + raise exception.NotFound("Project can't be created because " + "manager %s doesn't exist" % manager_uid) + + # description is a required attribute + if description is None: + description = name + + # First, we ensure that all the given users exist before we go + # on to create the project. This way we won't have to destroy + # the project again because a user turns out to be invalid. + members = set([manager]) + if member_uids != None: + for member_uid in member_uids: + member = db.user_get({}, member_uid) + if not member: + raise exception.NotFound("Project can't be created " + "because user %s doesn't exist" + % member_uid) + members.add(member) + + values = { 'id' : name, + 'name' : name, + 'project_manager' : manager['id'], + 'description': description } + + try: + project = db.project_create({}, values) + except exception.Duplicate: + raise exception.Duplicate("Project can't be created because " + "project %s already exists" % name) + + for member in members: + db.project_add_member({}, project['id'], member['id']) + + # This looks silly, but ensures that the members element has been + # correctly populated + project_ref = db.project_get({}, project['id']) + return self._db_project_to_auth_projectuser(project_ref) + + def modify_project(self, project_id, manager_uid=None, description=None): + """Modify an existing project""" + if not manager_uid and not description: + return + values = {} + if manager_uid: + manager = db.user_get({}, manager_uid) + if not manager: + raise exception.NotFound("Project can't be modified because " + "manager %s doesn't exist" % + manager_uid) + values['project_manager'] = manager['id'] + if description: + values['description'] = description + + db.project_update({}, project_id, values) + + def add_to_project(self, uid, project_id): + """Add user to project""" + user, project = self._validate_user_and_project(uid, project_id) + db.project_add_member({}, project['id'], user['id']) + + def remove_from_project(self, uid, project_id): + """Remove user from project""" + user, project = self._validate_user_and_project(uid, project_id) + db.project_remove_member({}, project['id'], user['id']) + + def is_in_project(self, uid, project_id): + """Check if user is in project""" + user, project = self._validate_user_and_project(uid, project_id) + return user in project.members + + def has_role(self, uid, role, project_id=None): + """Check if user has role + + If project is specified, it checks for local role, otherwise it + checks for global role + """ + + return role in self.get_user_roles(uid, project_id) + + def add_role(self, uid, role, project_id=None): + """Add role for user (or user and project)""" + if not project_id: + db.user_add_role({}, uid, role) + return + db.user_add_project_role({}, uid, project_id, role) + + def remove_role(self, uid, role, project_id=None): + """Remove role for user (or user and project)""" + if not project_id: + db.user_remove_role({}, uid, role) + return + db.user_remove_project_role({}, uid, project_id, role) + + def get_user_roles(self, uid, project_id=None): + """Retrieve list of roles for user (or user and project)""" + if project_id is None: + roles = db.user_get_roles({}, uid) + return roles + else: + roles = db.user_get_roles_for_project({}, uid, project_id) + return roles + + def delete_user(self, id): + """Delete a user""" + user = db.user_get({}, id) + db.user_delete({}, user['id']) + + def delete_project(self, project_id): + """Delete a project""" + db.project_delete({}, project_id) + + def modify_user(self, uid, access_key=None, secret_key=None, admin=None): + """Modify an existing user""" + if not access_key and not secret_key and admin is None: + return + values = {} + if access_key: + values['access_key'] = access_key + if secret_key: + values['secret_key'] = secret_key + if admin is not None: + values['is_admin'] = admin + db.user_update({}, uid, values) + + def _validate_user_and_project(self, user_id, project_id): + user = db.user_get({}, user_id) + if not user: + raise exception.NotFound('User "%s" not found' % user_id) + project = db.project_get({}, project_id) + if not project: + raise exception.NotFound('Project "%s" not found' % project_id) + return user, project + diff --git a/nova/auth/fakeldap.py b/nova/auth/fakeldap.py index bfc3433c5..2791dfde6 100644 --- a/nova/auth/fakeldap.py +++ b/nova/auth/fakeldap.py @@ -33,6 +33,7 @@ SCOPE_ONELEVEL = 1 # not implemented SCOPE_SUBTREE = 2 MOD_ADD = 0 MOD_DELETE = 1 +MOD_REPLACE = 2 class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103 @@ -175,7 +176,7 @@ class FakeLDAP(object): Args: dn -- a dn attrs -- a list of tuples in the following form: - ([MOD_ADD | MOD_DELETE], attribute, value) + ([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value) """ redis = datastore.Redis.instance() @@ -185,6 +186,8 @@ class FakeLDAP(object): values = _from_json(redis.hget(key, k)) if cmd == MOD_ADD: values.append(v) + elif cmd == MOD_REPLACE: + values = [v] else: values.remove(v) values = redis.hset(key, k, _to_json(values)) diff --git a/nova/auth/ldapdriver.py b/nova/auth/ldapdriver.py index 74ba011b5..640ea169e 100644 --- a/nova/auth/ldapdriver.py +++ b/nova/auth/ldapdriver.py @@ -99,13 +99,6 @@ class LdapDriver(object): dn = FLAGS.ldap_user_subtree return self.__to_user(self.__find_object(dn, query)) - def get_key_pair(self, uid, key_name): - """Retrieve key pair by uid and key name""" - dn = 'cn=%s,%s' % (key_name, - self.__uid_to_dn(uid)) - attr = self.__find_object(dn, '(objectclass=novaKeyPair)') - return self.__to_key_pair(uid, attr) - def get_project(self, pid): """Retrieve project by id""" dn = 'cn=%s,%s' % (pid, @@ -119,12 +112,6 @@ class LdapDriver(object): '(objectclass=novaUser)') return [self.__to_user(attr) for attr in attrs] - def get_key_pairs(self, uid): - """Retrieve list of key pairs""" - attrs = self.__find_objects(self.__uid_to_dn(uid), - '(objectclass=novaKeyPair)') - return [self.__to_key_pair(uid, attr) for attr in attrs] - def get_projects(self, uid=None): """Retrieve list of projects""" pattern = '(objectclass=novaProject)' @@ -154,21 +141,6 @@ class LdapDriver(object): self.conn.add_s(self.__uid_to_dn(name), attr) return self.__to_user(dict(attr)) - def create_key_pair(self, uid, key_name, public_key, fingerprint): - """Create a key pair""" - # TODO(vish): possibly refactor this to store keys in their own ou - # and put dn reference in the user object - attr = [ - ('objectclass', ['novaKeyPair']), - ('cn', [key_name]), - ('sshPublicKey', [public_key]), - ('keyFingerprint', [fingerprint]), - ] - self.conn.add_s('cn=%s,%s' % (key_name, - self.__uid_to_dn(uid)), - attr) - return self.__to_key_pair(uid, dict(attr)) - def create_project(self, name, manager_uid, description=None, member_uids=None): """Create a project""" @@ -202,6 +174,24 @@ class LdapDriver(object): self.conn.add_s('cn=%s,%s' % (name, FLAGS.ldap_project_subtree), attr) return self.__to_project(dict(attr)) + def modify_project(self, project_id, manager_uid=None, description=None): + """Modify an existing project""" + if not manager_uid and not description: + return + attr = [] + if manager_uid: + if not self.__user_exists(manager_uid): + raise exception.NotFound("Project can't be modified because " + "manager %s doesn't exist" % + manager_uid) + manager_dn = self.__uid_to_dn(manager_uid) + attr.append((self.ldap.MOD_REPLACE, 'projectManager', manager_dn)) + if description: + attr.append((self.ldap.MOD_REPLACE, 'description', description)) + self.conn.modify_s('cn=%s,%s' % (project_id, + FLAGS.ldap_project_subtree), + attr) + def add_to_project(self, uid, project_id): """Add user to project""" dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree) @@ -265,18 +255,8 @@ class LdapDriver(object): """Delete a user""" if not self.__user_exists(uid): raise exception.NotFound("User %s doesn't exist" % uid) - self.__delete_key_pairs(uid) self.__remove_from_all(uid) - self.conn.delete_s('uid=%s,%s' % (uid, - FLAGS.ldap_user_subtree)) - - def delete_key_pair(self, uid, key_name): - """Delete a key pair""" - if not self.__key_pair_exists(uid, key_name): - raise exception.NotFound("Key Pair %s doesn't exist for user %s" % - (key_name, uid)) - self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid, - FLAGS.ldap_user_subtree)) + self.conn.delete_s(self.__uid_to_dn(uid)) def delete_project(self, project_id): """Delete a project""" @@ -284,14 +264,23 @@ class LdapDriver(object): self.__delete_roles(project_dn) self.__delete_group(project_dn) + def modify_user(self, uid, access_key=None, secret_key=None, admin=None): + """Modify an existing project""" + if not access_key and not secret_key and admin is None: + return + attr = [] + if access_key: + attr.append((self.ldap.MOD_REPLACE, 'accessKey', access_key)) + if secret_key: + attr.append((self.ldap.MOD_REPLACE, 'secretKey', secret_key)) + if admin is not None: + attr.append((self.ldap.MOD_REPLACE, 'isAdmin', str(admin).upper())) + self.conn.modify_s(self.__uid_to_dn(uid), attr) + def __user_exists(self, uid): """Check if user exists""" return self.get_user(uid) != None - def __key_pair_exists(self, uid, key_name): - """Check if key pair exists""" - return self.get_key_pair(uid, key_name) != None - def __project_exists(self, project_id): """Check if project exists""" return self.get_project(project_id) != None @@ -341,13 +330,6 @@ class LdapDriver(object): """Check if group exists""" return self.__find_object(dn, '(objectclass=groupOfNames)') != None - def __delete_key_pairs(self, uid): - """Delete all key pairs for user""" - keys = self.get_key_pairs(uid) - if keys != None: - for key in keys: - self.delete_key_pair(uid, key['name']) - @staticmethod def __role_to_dn(role, project_id=None): """Convert role to corresponding dn""" @@ -472,18 +454,6 @@ class LdapDriver(object): 'secret': attr['secretKey'][0], 'admin': (attr['isAdmin'][0] == 'TRUE')} - @staticmethod - def __to_key_pair(owner, attr): - """Convert ldap attributes to KeyPair object""" - if attr == None: - return None - return { - 'id': attr['cn'][0], - 'name': attr['cn'][0], - 'owner_id': owner, - 'public_key': attr['sshPublicKey'][0], - 'fingerprint': attr['keyFingerprint'][0]} - def __to_project(self, attr): """Convert ldap attributes to Project object""" if attr == None: diff --git a/nova/auth/manager.py b/nova/auth/manager.py index d5fbec7c5..ce8a294df 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -44,7 +44,7 @@ flags.DEFINE_list('allowed_roles', # NOTE(vish): a user with one of these roles will be a superuser and # have access to all api commands flags.DEFINE_list('superuser_roles', ['cloudadmin'], - 'Roles that ignore rbac checking completely') + 'Roles that ignore authorization checking completely') # NOTE(vish): a user with one of these roles will have it for every # project, even if he or she is not a member of the project @@ -69,7 +69,7 @@ flags.DEFINE_string('credential_cert_subject', '/C=US/ST=California/L=MountainView/O=AnsoLabs/' 'OU=NovaDev/CN=%s-%s', 'Subject for certificate for users') -flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver', +flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver', 'Driver that auth manager uses') @@ -128,24 +128,6 @@ class User(AuthBase): def is_project_manager(self, project): return AuthManager().is_project_manager(self, project) - def generate_key_pair(self, name): - return AuthManager().generate_key_pair(self.id, name) - - def create_key_pair(self, name, public_key, fingerprint): - return AuthManager().create_key_pair(self.id, - name, - public_key, - fingerprint) - - def get_key_pair(self, name): - return AuthManager().get_key_pair(self.id, name) - - def delete_key_pair(self, name): - return AuthManager().delete_key_pair(self.id, name) - - def get_key_pairs(self): - return AuthManager().get_key_pairs(self.id) - def __repr__(self): return "User('%s', '%s', '%s', '%s', %s)" % (self.id, self.name, @@ -154,29 +136,6 @@ class User(AuthBase): self.admin) -class KeyPair(AuthBase): - """Represents an ssh key returned from the datastore - - Even though this object is named KeyPair, only the public key and - fingerprint is stored. The user's private key is not saved. - """ - - def __init__(self, id, name, owner_id, public_key, fingerprint): - AuthBase.__init__(self) - self.id = id - self.name = name - self.owner_id = owner_id - self.public_key = public_key - self.fingerprint = fingerprint - - def __repr__(self): - return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id, - self.name, - self.owner_id, - self.public_key, - self.fingerprint) - - class Project(AuthBase): """Represents a Project returned from the datastore""" @@ -307,7 +266,7 @@ class AuthManager(object): # NOTE(vish): if we stop using project name as id we need better # logic to find a default project for user - if project_id is '': + if project_id == '': project_id = user.name project = self.get_project(project_id) @@ -345,7 +304,7 @@ class AuthManager(object): return "%s:%s" % (user.access, Project.safe_id(project)) def is_superuser(self, user): - """Checks for superuser status, allowing user to bypass rbac + """Checks for superuser status, allowing user to bypass authorization @type user: User or uid @param user: User to check. @@ -533,6 +492,26 @@ class AuthManager(object): raise return project + def modify_project(self, project, manager_user=None, description=None): + """Modify a project + + @type name: Project or project_id + @param project: The project to modify. + + @type manager_user: User or uid + @param manager_user: This user will be the new project manager. + + @type description: str + @param project: This will be the new description of the project. + + """ + if manager_user: + manager_user = User.safe_id(manager_user) + with self.driver() as drv: + drv.modify_project(Project.safe_id(project), + manager_user, + description) + def add_to_project(self, user, project): """Add user to project""" with self.driver() as drv: @@ -643,67 +622,19 @@ class AuthManager(object): return User(**user_dict) def delete_user(self, user): - """Deletes a user""" - with self.driver() as drv: - drv.delete_user(User.safe_id(user)) - - def generate_key_pair(self, user, key_name): - """Generates a key pair for a user - - Generates a public and private key, stores the public key using the - key_name, and returns the private key and fingerprint. - - @type user: User or uid - @param user: User for which to create key pair. - - @type key_name: str - @param key_name: Name to use for the generated KeyPair. + """Deletes a user - @rtype: tuple (private_key, fingerprint) - @return: A tuple containing the private_key and fingerprint. - """ - # NOTE(vish): generating key pair is slow so check for legal - # creation before creating keypair + Additionally deletes all users key_pairs""" uid = User.safe_id(user) + db.key_pair_destroy_all_by_user(None, uid) with self.driver() as drv: - if not drv.get_user(uid): - raise exception.NotFound("User %s doesn't exist" % user) - if drv.get_key_pair(uid, key_name): - raise exception.Duplicate("The keypair %s already exists" - % key_name) - private_key, public_key, fingerprint = crypto.generate_key_pair() - self.create_key_pair(uid, key_name, public_key, fingerprint) - return private_key, fingerprint - - def create_key_pair(self, user, key_name, public_key, fingerprint): - """Creates a key pair for user""" - with self.driver() as drv: - kp_dict = drv.create_key_pair(User.safe_id(user), - key_name, - public_key, - fingerprint) - if kp_dict: - return KeyPair(**kp_dict) - - def get_key_pair(self, user, key_name): - """Retrieves a key pair for user""" - with self.driver() as drv: - kp_dict = drv.get_key_pair(User.safe_id(user), key_name) - if kp_dict: - return KeyPair(**kp_dict) + drv.delete_user(uid) - def get_key_pairs(self, user): - """Retrieves all key pairs for user""" - with self.driver() as drv: - kp_list = drv.get_key_pairs(User.safe_id(user)) - if not kp_list: - return [] - return [KeyPair(**kp_dict) for kp_dict in kp_list] - - def delete_key_pair(self, user, key_name): - """Deletes a key pair for user""" + def modify_user(self, user, access_key=None, secret_key=None, admin=None): + """Modify credentials for a user""" + uid = User.safe_id(user) with self.driver() as drv: - drv.delete_key_pair(User.safe_id(user), key_name) + drv.modify_user(uid, access_key, secret_key, admin) def get_credentials(self, user, project=None): """Get credential zip for user in project""" diff --git a/nova/auth/rbac.py b/nova/auth/rbac.py deleted file mode 100644 index d157f44b3..000000000 --- a/nova/auth/rbac.py +++ /dev/null @@ -1,69 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -"""Role-based access control decorators to use fpr wrapping other -methods with.""" - -from nova import exception - - -def allow(*roles): - """Allow the given roles access the wrapped function.""" - - def wrap(func): # pylint: disable-msg=C0111 - - def wrapped_func(self, context, *args, - **kwargs): # pylint: disable-msg=C0111 - if context.user.is_superuser(): - return func(self, context, *args, **kwargs) - for role in roles: - if __matches_role(context, role): - return func(self, context, *args, **kwargs) - raise exception.NotAuthorized() - - return wrapped_func - - return wrap - - -def deny(*roles): - """Deny the given roles access the wrapped function.""" - - def wrap(func): # pylint: disable-msg=C0111 - - def wrapped_func(self, context, *args, - **kwargs): # pylint: disable-msg=C0111 - if context.user.is_superuser(): - return func(self, context, *args, **kwargs) - for role in roles: - if __matches_role(context, role): - raise exception.NotAuthorized() - return func(self, context, *args, **kwargs) - - return wrapped_func - - return wrap - - -def __matches_role(context, role): - """Check if a role is allowed.""" - if role == 'all': - return True - if role == 'none': - return False - return context.project.has_role(context.user.id, role) diff --git a/nova/cloudpipe/api.py b/nova/cloudpipe/api.py deleted file mode 100644 index 56aa89834..000000000 --- a/nova/cloudpipe/api.py +++ /dev/null @@ -1,59 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Tornado REST API Request Handlers for CloudPipe -""" - -import logging -import urllib - -import tornado.web - -from nova import crypto -from nova.auth import manager - - -_log = logging.getLogger("api") -_log.setLevel(logging.DEBUG) - - -class CloudPipeRequestHandler(tornado.web.RequestHandler): - def get(self, path): - path = self.request.path - _log.debug( "Cloudpipe path is %s" % path) - if path.endswith("/getca/"): - self.send_root_ca() - self.finish() - - def get_project_id_from_ip(self, ip): - cc = self.application.controllers['Cloud'] - instance = cc.get_instance_by_ip(ip) - instance['project_id'] - - def send_root_ca(self): - _log.debug( "Getting root ca") - project_id = self.get_project_id_from_ip(self.request.remote_ip) - self.set_header("Content-Type", "text/plain") - self.write(crypto.fetch_ca(project_id)) - - def post(self, *args, **kwargs): - project_id = self.get_project_id_from_ip(self.request.remote_ip) - cert = self.get_argument('cert', '') - self.write(crypto.sign_csr(urllib.unquote(cert), project_id)) - self.finish() diff --git a/nova/cloudpipe/pipelib.py b/nova/cloudpipe/pipelib.py index 2867bcb21..706a175d9 100644 --- a/nova/cloudpipe/pipelib.py +++ b/nova/cloudpipe/pipelib.py @@ -32,7 +32,9 @@ from nova import exception from nova import flags from nova import utils from nova.auth import manager -from nova.endpoint import api +# TODO(eday): Eventually changes these to something not ec2-specific +from nova.api.ec2 import cloud +from nova.api.ec2 import context FLAGS = flags.FLAGS @@ -42,8 +44,8 @@ flags.DEFINE_string('boot_script_template', class CloudPipe(object): - def __init__(self, cloud_controller): - self.controller = cloud_controller + def __init__(self): + self.controller = cloud.CloudController() self.manager = manager.AuthManager() def launch_vpn_instance(self, project_id): @@ -58,9 +60,9 @@ class CloudPipe(object): z.write(FLAGS.boot_script_template,'autorun.sh') z.close() - key_name = self.setup_keypair(project.project_manager_id, project_id) + key_name = self.setup_key_pair(project.project_manager_id, project_id) zippy = open(zippath, "r") - context = api.APIRequestContext(handler=None, user=project.project_manager, project=project) + context = context.APIRequestContext(user=project.project_manager, project=project) reservation = self.controller.run_instances(context, # run instances expects encoded userdata, it is decoded in the get_metadata_call @@ -74,7 +76,7 @@ class CloudPipe(object): security_groups=["vpn-secgroup"]) zippy.close() - def setup_keypair(self, user_id, project_id): + def setup_key_pair(self, user_id, project_id): key_name = '%s%s' % (project_id, FLAGS.vpn_key_suffix) try: private_key, fingerprint = self.manager.generate_key_pair(user_id, key_name) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index cb6434694..8b98a2b0e 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -67,7 +67,7 @@ class ComputeManager(manager.Manager): def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" instance_ref = self.db.instance_get(context, instance_id) - if instance_ref['str_id'] in self.driver.list_instances(): + if instance_ref['ec2_id'] in self.driver.list_instances(): raise exception.Error("Instance has already been created") logging.debug("instance %s: starting...", instance_id) project_id = instance_ref['project_id'] @@ -129,7 +129,7 @@ class ComputeManager(manager.Manager): raise exception.Error( 'trying to reboot a non-running' 'instance: %s (state: %s excepted: %s)' % - (instance_ref['str_id'], + (instance_ref['ec2_id'], instance_ref['state'], power_state.RUNNING)) @@ -158,7 +158,7 @@ class ComputeManager(manager.Manager): instance_ref = self.db.instance_get(context, instance_id) dev_path = yield self.volume_manager.setup_compute_volume(context, volume_id) - yield self.driver.attach_volume(instance_ref['str_id'], + yield self.driver.attach_volume(instance_ref['ec2_id'], dev_path, mountpoint) self.db.volume_attached(context, volume_id, instance_id, mountpoint) @@ -173,7 +173,7 @@ class ComputeManager(manager.Manager): volume_id) instance_ref = self.db.instance_get(context, instance_id) volume_ref = self.db.volume_get(context, volume_id) - self.driver.detach_volume(instance_ref['str_id'], - volume_ref['mountpoint']) + yield self.driver.detach_volume(instance_ref['ec2_id'], + volume_ref['mountpoint']) self.db.volume_detached(context, volume_id) defer.returnValue(True) diff --git a/nova/crypto.py b/nova/crypto.py index b05548ea1..1c6fe57ad 100644 --- a/nova/crypto.py +++ b/nova/crypto.py @@ -18,7 +18,7 @@ """ Wrappers around standard crypto, including root and intermediate CAs, -SSH keypairs and x509 certificates. +SSH key_pairs and x509 certificates. """ import base64 diff --git a/nova/db/api.py b/nova/db/api.py index 9f6ff99c3..a6d1f405a 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -161,20 +161,20 @@ def floating_ip_get_all(context): def floating_ip_get_all_by_host(context, host): - """Get all floating ips.""" + """Get all floating ips by host.""" return IMPL.floating_ip_get_all_by_host(context, host) +def floating_ip_get_all_by_project(context, project_id): + """Get all floating ips by project.""" + return IMPL.floating_ip_get_all_by_project(context, project_id) + + def floating_ip_get_by_address(context, address): """Get a floating ip by address or raise if it doesn't exist.""" return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### @@ -204,6 +204,11 @@ def fixed_ip_disassociate(context, address): return IMPL.fixed_ip_disassociate(context, address) +def fixed_ip_disassociate_all_by_timeout(context, host, time): + """Disassociate old fixed ips from host""" + return IMPL.fixed_ip_disassociate_all_by_timeout(context, host, time) + + def fixed_ip_get_by_address(context, address): """Get a fixed ip by address or raise if it does not exist.""" return IMPL.fixed_ip_get_by_address(context, address) @@ -251,15 +256,18 @@ def instance_get_all(context): """Get all instances.""" return IMPL.instance_get_all(context) +def instance_get_all_by_user(context, user_id): + """Get all instances.""" + return IMPL.instance_get_all(context, user_id) -def instance_get_by_project(context, project_id): +def instance_get_all_by_project(context, project_id): """Get all instance belonging to a project.""" - return IMPL.instance_get_by_project(context, project_id) + return IMPL.instance_get_all_by_project(context, project_id) -def instance_get_by_reservation(context, reservation_id): +def instance_get_all_by_reservation(context, reservation_id): """Get all instance belonging to a reservation.""" - return IMPL.instance_get_by_reservation(context, reservation_id) + return IMPL.instance_get_all_by_reservation(context, reservation_id) def instance_get_fixed_address(context, instance_id): @@ -272,9 +280,9 @@ def instance_get_floating_address(context, instance_id): return IMPL.instance_get_floating_address(context, instance_id) -def instance_get_by_str(context, str_id): - """Get an instance by string id.""" - return IMPL.instance_get_by_str(context, str_id) +def instance_get_by_ec2_id(context, ec2_id): + """Get an instance by ec2 id.""" + return IMPL.instance_get_by_ec2_id(context, ec2_id) def instance_is_vpn(context, instance_id): @@ -296,6 +304,34 @@ def instance_update(context, instance_id, values): return IMPL.instance_update(context, instance_id, values) +################### + + +def key_pair_create(context, values): + """Create a key_pair from the values dictionary.""" + return IMPL.key_pair_create(context, values) + + +def key_pair_destroy(context, user_id, name): + """Destroy the key_pair or raise if it does not exist.""" + return IMPL.key_pair_destroy(context, user_id, name) + + +def key_pair_destroy_all_by_user(context, user_id): + """Destroy all key_pairs by user.""" + return IMPL.key_pair_destroy_all_by_user(context, user_id) + + +def key_pair_get(context, user_id, name): + """Get a key_pair or raise if it does not exist.""" + return IMPL.key_pair_get(context, user_id, name) + + +def key_pair_get_all_by_user(context, user_id): + """Get all key_pairs by user.""" + return IMPL.key_pair_get_all_by_user(context, user_id) + + #################### @@ -365,9 +401,12 @@ def network_index_count(context): return IMPL.network_index_count(context) -def network_index_create(context, values): - """Create a network index from the values dict""" - return IMPL.network_index_create(context, values) +def network_index_create_safe(context, values): + """Create a network index from the values dict + + The index is not returned. If the create violates the unique + constraints because the index already exists, no exception is raised.""" + return IMPL.network_index_create_safe(context, values) def network_set_cidr(context, network_id, cidr): @@ -420,6 +459,21 @@ def export_device_create(context, values): ################### +def auth_destroy_token(context, token): + """Destroy an auth token""" + return IMPL.auth_destroy_token(context, token) + +def auth_get_token(context, token_hash): + """Retrieves a token given the hash representing it""" + return IMPL.auth_get_token(context, token_hash) + +def auth_create_token(context, token): + """Creates a new token""" + return IMPL.auth_create_token(context, token_hash, token) + + +################### + def quota_create(context, values): """Create a quota from the values dictionary.""" @@ -489,14 +543,14 @@ def volume_get_instance(context, volume_id): return IMPL.volume_get_instance(context, volume_id) -def volume_get_by_project(context, project_id): +def volume_get_all_by_project(context, project_id): """Get all volumes belonging to a project.""" - return IMPL.volume_get_by_project(context, project_id) + return IMPL.volume_get_all_by_project(context, project_id) -def volume_get_by_str(context, str_id): - """Get a volume by string id.""" - return IMPL.volume_get_by_str(context, str_id) +def volume_get_by_ec2_id(context, ec2_id): + """Get a volume by ec2 id.""" + return IMPL.volume_get_by_ec2_id(context, ec2_id) def volume_get_shelf_and_blade(context, volume_id): @@ -511,3 +565,122 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + + +################### + + +def user_get(context, id): + """Get user by id""" + return IMPL.user_get(context, id) + + +def user_get_by_uid(context, uid): + """Get user by uid""" + return IMPL.user_get_by_uid(context, uid) + + +def user_get_by_access_key(context, access_key): + """Get user by access key""" + return IMPL.user_get_by_access_key(context, access_key) + + +def user_create(context, values): + """Create a new user""" + return IMPL.user_create(context, values) + + +def user_delete(context, id): + """Delete a user""" + return IMPL.user_delete(context, id) + + +def user_get_all(context): + """Create a new user""" + return IMPL.user_get_all(context) + + +def user_add_role(context, user_id, role): + """Add another global role for user""" + return IMPL.user_add_role(context, user_id, role) + + +def user_remove_role(context, user_id, role): + """Remove global role from user""" + return IMPL.user_remove_role(context, user_id, role) + + +def user_get_roles(context, user_id): + """Get global roles for user""" + return IMPL.user_get_roles(context, user_id) + + +def user_add_project_role(context, user_id, project_id, role): + """Add project role for user""" + return IMPL.user_add_project_role(context, user_id, project_id, role) + + +def user_remove_project_role(context, user_id, project_id, role): + """Remove project role from user""" + return IMPL.user_remove_project_role(context, user_id, project_id, role) + + +def user_get_roles_for_project(context, user_id, project_id): + """Return list of roles a user holds on project""" + return IMPL.user_get_roles_for_project(context, user_id, project_id) + + +def user_update(context, user_id, values): + """Update user""" + return IMPL.user_update(context, user_id, values) + + +def project_get(context, id): + """Get project by id""" + return IMPL.project_get(context, id) + + +def project_create(context, values): + """Create a new project""" + return IMPL.project_create(context, values) + + +def project_add_member(context, project_id, user_id): + """Add user to project""" + return IMPL.project_add_member(context, project_id, user_id) + + +def project_get_all(context): + """Get all projects""" + return IMPL.project_get_all(context) + + +def project_get_by_user(context, user_id): + """Get all projects of which the given user is a member""" + return IMPL.project_get_by_user(context, user_id) + + +def project_remove_member(context, project_id, user_id): + """Remove the given user from the given project""" + return IMPL.project_remove_member(context, project_id, user_id) + + +def project_update(context, project_id, values): + """Update Remove the given user from the given project""" + return IMPL.project_update(context, project_id, values) + + +def project_delete(context, project_id): + """Delete project""" + return IMPL.project_delete(context, project_id) + + +################### + + +def host_get_networks(context, host): + """Return all networks for which the given host is the designated + network host + """ + return IMPL.host_get_networks(context, host) + diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d612fe669..e0c6a34b8 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,62 +19,140 @@ Implementation of SQLAlchemy backend """ +import warnings + from nova import db from nova import exception from nova import flags +from nova import utils from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ -from sqlalchemy.orm import joinedload_all -from sqlalchemy.sql import func +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import joinedload, joinedload_all +from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. +def is_admin_context(context): + """Indicates if the request context is an administrator.""" + if not context: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) + return True + return context.is_admin + + +def is_user_context(context): + """Indicates if the request context is a normal user.""" + if not context: + return False + if not context.user or not context.project: + return False + return True + + +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. + """ + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() + - Currently just looks for a flag called deleted in the context dict. +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. """ - if not hasattr(context, 'get'): + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def can_read_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: return False - return context.get('deleted', False) + return context.read_deleted -################### +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper +################### + +@require_admin_context def service_destroy(context, service_id): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) +@require_admin_context +def service_get(context, service_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result + +@require_admin_context def service_get_all_by_topic(context, topic): session = get_session() return session.query(models.Service ).filter_by(deleted=False + ).filter_by(disabled=False ).filter_by(topic=topic ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +@require_admin_context +def _service_get_all_topic_subquery(context, session, topic, subq, label): sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic ).filter_by(deleted=False + ).filter_by(disabled=False ).outerjoin((subq, models.Service.host == subq.c.host) ).order_by(sort_value ).all() +@require_admin_context def service_get_all_compute_sorted(context): session = get_session() with session.begin(): @@ -99,6 +177,7 @@ def service_get_all_compute_sorted(context): label) +@require_admin_context def service_get_all_network_sorted(context): session = get_session() with session.begin(): @@ -116,6 +195,7 @@ def service_get_all_network_sorted(context): label) +@require_admin_context def service_get_all_volume_sorted(context): session = get_session() with session.begin(): @@ -133,11 +213,22 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +@require_admin_context +def service_get_by_args(context, host, binary): + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result -def service_create(_context, values): +@require_admin_context +def service_create(context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -145,10 +236,11 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +@require_admin_context +def service_update(context, service_id, values): session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -157,12 +249,15 @@ def service_update(_context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +@require_context +def floating_ip_allocate_address(context, host, project_id): + authorize_project_context(context, project_id) session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp ).filter_by(host=host ).filter_by(fixed_ip_id=None + ).filter_by(project_id=None ).filter_by(deleted=False ).with_lockmode('update' ).first() @@ -175,7 +270,8 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +@require_context +def floating_ip_create(context, values): floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -183,7 +279,9 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +@require_context +def floating_ip_count_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -191,39 +289,53 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +@require_context +def floating_ip_deallocate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) -def floating_ip_destroy(_context, address): +@require_context +def floating_ip_destroy(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +@require_context +def floating_ip_disassociate(context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -234,7 +346,8 @@ def floating_ip_disassociate(_context, address): return fixed_ip_address -def floating_ip_get_all(_context): +@require_admin_context +def floating_ip_get_all(context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -242,7 +355,8 @@ def floating_ip_get_all(_context): ).all() -def floating_ip_get_all_by_host(_context, host): +@require_admin_context +def floating_ip_get_all_by_host(context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -250,24 +364,42 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) - -def floating_ip_get_instance(_context, address): +@require_context +def floating_ip_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + return session.query(models.FloatingIp + ).options(joinedload_all('fixed_ip.instance') + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() + + +@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not session: + session = get_session() + + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) + + return result ################### -def fixed_ip_associate(_context, address, instance_id): +@require_context +def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -278,12 +410,12 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -300,14 +432,17 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -316,34 +451,72 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -def fixed_ip_disassociate(_context, address): +@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) -def fixed_ip_get_by_address(_context, address): - return models.FixedIp.find_by_str(address) +@require_admin_context +def fixed_ip_disassociate_all_by_timeout(_context, host, time): + session = get_session() + # NOTE(vish): The nested select is because sqlite doesn't support + # JOINs in UPDATEs. + result = session.execute('UPDATE fixed_ips SET instance_id = NULL, ' + 'leased = 0 ' + 'WHERE network_id IN (SELECT id FROM networks ' + 'WHERE host = :host) ' + 'AND updated_at < :time ' + 'AND instance_id IS NOT NULL ' + 'AND allocated = 0', + {'host': host, + 'time': time.isoformat()}) + return result.rowcount + + +@require_context +def fixed_ip_get_by_address(context, address, session=None): + if not session: + session = get_session() + result = session.query(models.FixedIp + ).filter_by(address=address + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance + return result -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance + + +@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): +@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -352,15 +525,24 @@ def fixed_ip_update(_context, address, values): ################### -def instance_create(_context, values): +@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value - instance_ref.save() + + session = get_session() + with session.begin(): + while instance_ref.ec2_id == None: + ec2_id = utils.generate_uid(instance_ref.__prefix__) + if not instance_ec2_id_exists(context, ec2_id, session=session): + instance_ref.ec2_id = ec2_id + instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): +@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -371,60 +553,130 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): +@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +@require_context +def instance_get(context, instance_id, session=None): + if not session: + session = get_session() + result = None + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + return result + + +@require_admin_context def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() -def instance_get_by_project(context, project_id): +@require_admin_context +def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) + ).filter_by(user_id=user_id ).all() -def instance_get_by_reservation(_context, reservation_id): +@require_context +def instance_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False + ).filter_by(project_id=project_id + ).filter_by(deleted=can_read_deleted(context) ).all() -def instance_get_by_str(context, str_id): - return models.Instance.find_by_str(str_id, deleted=_deleted(context)) +@require_context +def instance_get_all_by_reservation(context, reservation_id): + session = get_session() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=can_read_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() + + +@require_context +def instance_get_by_ec2_id(context, ec2_id): + session = get_session() + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('Instance %s not found' % (ec2_id)) + + return result + +@require_context +def instance_ec2_id_exists(context, ec2_id, session=None): + if not session: + session = get_session() + return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): + +@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): +@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -433,12 +685,14 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +@require_admin_context def instance_is_vpn(context, instance_id): # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +@require_admin_context def instance_set_state(context, instance_id, state, description=None): # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state @@ -450,10 +704,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -462,11 +717,75 @@ def instance_update(_context, instance_id, values): ################### -def network_count(_context): - return models.Network.count() +@require_context +def key_pair_create(context, values): + key_pair_ref = models.KeyPair() + for (key, value) in values.iteritems(): + key_pair_ref[key] = value + key_pair_ref.save() + return key_pair_ref + + +@require_context +def key_pair_destroy(context, user_id, name): + authorize_user_context(context, user_id) + session = get_session() + with session.begin(): + key_pair_ref = key_pair_get(context, user_id, name, session=session) + key_pair_ref.delete(session=session) + + +@require_context +def key_pair_destroy_all_by_user(context, user_id): + authorize_user_context(context, user_id) + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update key_pairs set deleted=1 where user_id=:id', + {'id': user_id}) + + +@require_context +def key_pair_get(context, user_id, name, session=None): + authorize_user_context(context, user_id) + + if not session: + session = get_session() + + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result + + +@require_context +def key_pair_get_all_by_user(context, user_id): + authorize_user_context(context, user_id) + session = get_session() + return session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(deleted=False + ).all() + + +################### -def network_count_allocated_ips(_context, network_id): +@require_admin_context +def network_count(context): + session = get_session() + return session.query(models.Network + ).filter_by(deleted=can_read_deleted(context) + ).count() + + +@require_admin_context +def network_count_allocated_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -475,7 +794,8 @@ def network_count_allocated_ips(_context, network_id): ).count() -def network_count_available_ips(_context, network_id): +@require_admin_context +def network_count_available_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -485,7 +805,8 @@ def network_count_available_ips(_context, network_id): ).count() -def network_count_reserved_ips(_context, network_id): +@require_admin_context +def network_count_reserved_ips(context, network_id): session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -494,7 +815,8 @@ def network_count_reserved_ips(_context, network_id): ).count() -def network_create(_context, values): +@require_admin_context +def network_create(context, values): network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -502,7 +824,8 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +@require_admin_context +def network_destroy(context, network_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -520,34 +843,59 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +@require_admin_context +def network_get_associated_fixed_ips(context, network_id): session = get_session() return session.query(models.FixedIp + ).options(joinedload_all('instance') ).filter_by(network_id=network_id ).filter(models.FixedIp.instance_id != None ).filter_by(deleted=False ).all() -def network_get_by_bridge(_context, bridge): +@require_admin_context +def network_get_by_bridge(context, bridge): session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + + return result -def network_get_index(_context, network_id): +@require_admin_context +def network_get_index(context, network_id): session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -555,48 +903,63 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +@require_admin_context +def network_index_count(context): + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=can_read_deleted(context) + ).count() -def network_index_create(_context, values): +@require_admin_context +def network_index_create_safe(context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value - network_index_ref.save() + try: + network_index_ref.save() + except IntegrityError: + pass -def network_set_host(_context, network_id, host_id): +@require_admin_context +def network_set_host(context, network_id, host_id): session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + return network_ref['host'] -def network_update(_context, network_id, values): + +@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -605,15 +968,18 @@ def network_update(_context, network_id, values): ################### -def project_get_network(_context, project_id): +@require_context +def project_get_network(context, project_id): session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -623,14 +989,20 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +@require_admin_context +def export_device_count(context): + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=can_read_deleted(context) + ).count() -def export_device_create(_context, values): +@require_admin_context +def export_device_create(context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -641,7 +1013,46 @@ def export_device_create(_context, values): ################### -def quota_create(_context, values): +def auth_destroy_token(_context, token): + session = get_session() + session.delete(token) + +def auth_get_token(_context, token_hash): + session = get_session() + tk = session.query(models.AuthToken + ).filter_by(token_hash=token_hash) + if not tk: + raise exception.NotFound('Token %s does not exist' % token_hash) + return tk + +def auth_create_token(_context, token): + tk = models.AuthToken() + for k,v in token.iteritems(): + tk[k] = v + tk.save() + return tk + + +################### + + +@require_admin_context +def quota_get(context, project_id, session=None): + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +@require_admin_context +def quota_create(context, values): quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -649,30 +1060,29 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - - -def quota_update(_context, project_id, values): +@require_admin_context +def quota_update(context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +@require_admin_context +def quota_destroy(context, project_id): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### -def volume_allocate_shelf_and_blade(_context, volume_id): +@require_admin_context +def volume_allocate_shelf_and_blade(context, volume_id): session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -689,27 +1099,36 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +@require_admin_context +def volume_attached(context, volume_id, instance_id, mountpoint): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) -def volume_create(_context, values): +@require_context +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value - volume_ref.save() + + session = get_session() + with session.begin(): + while volume_ref.ec2_id == None: + ec2_id = utils.generate_uid(volume_ref.__prefix__) + if not volume_ec2_id_exists(context, ec2_id, session=session): + volume_ref.ec2_id = ec2_id + volume_ref.save(session=session) return volume_ref -def volume_data_get_for_project(_context, project_id): +@require_admin_context +def volume_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -720,7 +1139,8 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +@require_admin_context +def volume_destroy(context, volume_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -731,10 +1151,11 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +@require_admin_context +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -742,46 +1163,334 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) +@require_context +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) + + return result +@require_admin_context def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) + return session.query(models.Volume + ).filter_by(deleted=can_read_deleted(context) + ).all() +@require_context +def volume_get_all_by_project(context, project_id): + authorize_project_context(context, project_id) -def volume_get_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() -def volume_get_by_str(context, str_id): - return models.Volume.find_by_str(str_id, deleted=_deleted(context)) +@require_context +def volume_get_by_ec2_id(context, ec2_id): + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=can_read_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() + + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result + + +@require_context +def volume_ec2_id_exists(context, ec2_id, session=None): + if not session: + session = get_session() + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -def volume_get_instance(_context, volume_id): + +@require_admin_context +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): +@require_admin_context +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: - raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + if not result: + raise exception.NotFound('No export device found for volume %s' % + volume_id) + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): + +@require_context +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) + + +################### + + +@require_admin_context +def user_get(context, id, session=None): + if not session: + session = get_session() + + result = session.query(models.User + ).filter_by(id=id + ).filter_by(deleted=can_read_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No user for id %s' % id) + + return result + + +@require_admin_context +def user_get_by_access_key(context, access_key, session=None): + if not session: + session = get_session() + + result = session.query(models.User + ).filter_by(access_key=access_key + ).filter_by(deleted=can_read_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No user for id %s' % id) + + return result + + +@require_admin_context +def user_create(_context, values): + user_ref = models.User() + for (key, value) in values.iteritems(): + user_ref[key] = value + user_ref.save() + return user_ref + + +@require_admin_context +def user_delete(context, id): + session = get_session() + with session.begin(): + session.execute('delete from user_project_association where user_id=:id', + {'id': id}) + session.execute('delete from user_role_association where user_id=:id', + {'id': id}) + session.execute('delete from user_project_role_association where user_id=:id', + {'id': id}) + user_ref = user_get(context, id, session=session) + session.delete(user_ref) + + +def user_get_all(context): + session = get_session() + return session.query(models.User + ).filter_by(deleted=can_read_deleted(context) + ).all() + + +def project_create(_context, values): + project_ref = models.Project() + for (key, value) in values.iteritems(): + project_ref[key] = value + project_ref.save() + return project_ref + + +def project_add_member(context, project_id, user_id): + session = get_session() + with session.begin(): + project_ref = project_get(context, project_id, session=session) + user_ref = user_get(context, user_id, session=session) + + project_ref.members += [user_ref] + project_ref.save(session=session) + + +def project_get(context, id, session=None): + if not session: + session = get_session() + + result = session.query(models.Project + ).filter_by(deleted=False + ).filter_by(id=id + ).options(joinedload_all('members') + ).first() + + if not result: + raise exception.NotFound("No project with id %s" % id) + + return result + + +def project_get_all(context): + session = get_session() + return session.query(models.Project + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload_all('members') + ).all() + + +def project_get_by_user(context, user_id): + session = get_session() + user = session.query(models.User + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload_all('projects') + ).first() + return user.projects + + +def project_remove_member(context, project_id, user_id): + session = get_session() + project = project_get(context, project_id, session=session) + user = user_get(context, user_id, session=session) + + if user in project.members: + project.members.remove(user) + project.save(session=session) + + +def user_update(context, user_id, values): + session = get_session() + with session.begin(): + user_ref = user_get(context, user_id, session=session) + for (key, value) in values.iteritems(): + user_ref[key] = value + user_ref.save(session=session) + + +def project_update(context, project_id, values): + session = get_session() + with session.begin(): + project_ref = project_get(context, project_id, session=session) + for (key, value) in values.iteritems(): + project_ref[key] = value + project_ref.save(session=session) + + +def project_delete(context, id): + session = get_session() + with session.begin(): + session.execute('delete from user_project_association where project_id=:id', + {'id': id}) + session.execute('delete from user_project_role_association where project_id=:id', + {'id': id}) + project_ref = project_get(context, id, session=session) + session.delete(project_ref) + + +def user_get_roles(context, user_id): + session = get_session() + with session.begin(): + user_ref = user_get(context, user_id, session=session) + return [role.role for role in user_ref['roles']] + + +def user_get_roles_for_project(context, user_id, project_id): + session = get_session() + with session.begin(): + res = session.query(models.UserProjectRoleAssociation + ).filter_by(user_id=user_id + ).filter_by(project_id=project_id + ).all() + return [association.role for association in res] + +def user_remove_project_role(context, user_id, project_id, role): + session = get_session() + with session.begin(): + session.execute('delete from user_project_role_association where ' + \ + 'user_id=:user_id and project_id=:project_id and ' + \ + 'role=:role', { 'user_id' : user_id, + 'project_id' : project_id, + 'role' : role }) + + +def user_remove_role(context, user_id, role): + session = get_session() + with session.begin(): + res = session.query(models.UserRoleAssociation + ).filter_by(user_id=user_id + ).filter_by(role=role + ).all() + for role in res: + session.delete(role) + + +def user_add_role(context, user_id, role): + session = get_session() + with session.begin(): + user_ref = user_get(context, user_id, session=session) + models.UserRoleAssociation(user=user_ref, role=role).save(session=session) + + +def user_add_project_role(context, user_id, project_id, role): + session = get_session() + with session.begin(): + user_ref = user_get(context, user_id, session=session) + project_ref = project_get(context, project_id, session=session) + models.UserProjectRoleAssociation(user_id=user_ref['id'], + project_id=project_ref['id'], + role=role).save(session=session) + + +################### + + +def host_get_networks(context, host): + session = get_session() + with session.begin(): + return session.query(models.Network + ).filter_by(deleted=False + ).filter_by(host=host + ).all() diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 41013f41b..673c8e94f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -27,7 +27,9 @@ import datetime from sqlalchemy.orm import relationship, backref, exc, object_mapper from sqlalchemy import Column, Integer, String from sqlalchemy import ForeignKey, DateTime, Boolean, Text +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.schema import ForeignKeyConstraint from nova.db.sqlalchemy.session import get_session @@ -50,44 +52,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" @@ -98,7 +62,13 @@ class NovaBase(object): if not session: session = get_session() session.add(self) - session.flush() + try: + session.flush() + except IntegrityError, e: + if str(e).endswith('is not unique'): + raise exception.Duplicate(str(e)) + else: + raise def delete(self, session=None): """Delete this object""" @@ -126,6 +96,7 @@ class NovaBase(object): # __tablename__ = 'images' # __prefix__ = 'ami' # id = Column(Integer, primary_key=True) +# ec2_id = Column(String(12), unique=True) # user_id = Column(String(255)) # project_id = Column(String(255)) # image_type = Column(String(255)) @@ -173,21 +144,7 @@ class Service(BASE, NovaBase): binary = Column(String(255)) topic = Column(String(255)) report_count = Column(Integer, nullable=False, default=0) - - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] + disabled = Column(Boolean, default=False) class Instance(BASE, NovaBase): @@ -195,6 +152,9 @@ class Instance(BASE, NovaBase): __tablename__ = 'instances' __prefix__ = 'i' id = Column(Integer, primary_key=True) + ec2_id = Column(String(10), unique=True) + + admin_pass = Column(String(255)) user_id = Column(String(255)) project_id = Column(String(255)) @@ -209,11 +169,14 @@ class Instance(BASE, NovaBase): @property def name(self): - return self.str_id + return self.ec2_id image_id = Column(String(255)) kernel_id = Column(String(255)) ramdisk_id = Column(String(255)) + + server_name = Column(String(255)) + # image_id = Column(Integer, ForeignKey('images.id'), nullable=True) # kernel_id = Column(Integer, ForeignKey('images.id'), nullable=True) # ramdisk_id = Column(Integer, ForeignKey('images.id'), nullable=True) @@ -233,7 +196,6 @@ class Instance(BASE, NovaBase): vcpus = Column(Integer) local_gb = Column(Integer) - hostname = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) @@ -247,6 +209,10 @@ class Instance(BASE, NovaBase): scheduled_at = Column(DateTime) launched_at = Column(DateTime) terminated_at = Column(DateTime) + + display_name = Column(String(255)) + display_description = Column(String(255)) + # TODO(vish): see Ewan's email about state improvements, probably # should be in a driver base class or some such # vmstate_state = running, halted, suspended, paused @@ -264,6 +230,7 @@ class Volume(BASE, NovaBase): __tablename__ = 'volumes' __prefix__ = 'vol' id = Column(Integer, primary_key=True) + ec2_id = Column(String(12), unique=True) user_id = Column(String(255)) project_id = Column(String(255)) @@ -272,7 +239,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -282,6 +253,10 @@ class Volume(BASE, NovaBase): launched_at = Column(DateTime) terminated_at = Column(DateTime) + display_name = Column(String(255)) + display_description = Column(String(255)) + + class Quota(BASE, NovaBase): """Represents quota overrides for a project""" __tablename__ = 'quotas' @@ -299,18 +274,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" @@ -319,8 +282,27 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') + + +class KeyPair(BASE, NovaBase): + """Represents a public key pair for ssh""" + __tablename__ = 'key_pairs' + id = Column(Integer, primary_key=True) + name = Column(String(255)) + + user_id = Column(String(255)) + + fingerprint = Column(String(255)) + public_key = Column(Text) + + @property + def str_id(self): + return '%s.%s' % (self.user_id, self.name) class Network(BASE, NovaBase): @@ -355,10 +337,26 @@ class NetworkIndex(BASE, NovaBase): """ __tablename__ = 'network_indexes' id = Column(Integer, primary_key=True) - index = Column(Integer) + index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + + +class AuthToken(BASE, NovaBase): + """Represents an authorization token for all API transactions. Fields + are a string representing the actual token and a user id for mapping + to the actual user""" + __tablename__ = 'auth_tokens' + token_hash = Column(String(255), primary_key=True) + user_id = Column(Integer) + server_manageent_url = Column(String(255)) + storage_url = Column(String(255)) + cdn_management_url = Column(String(255)) + # TODO(vish): can these both come from the same baseclass? @@ -370,8 +368,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -380,18 +381,66 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] + +class User(BASE, NovaBase): + """Represents a user""" + __tablename__ = 'users' + id = Column(String(255), primary_key=True) + + name = Column(String(255)) + access_key = Column(String(255)) + secret_key = Column(String(255)) + + is_admin = Column(Boolean) + + +class Project(BASE, NovaBase): + """Represents a project""" + __tablename__ = 'projects' + id = Column(String(255), primary_key=True) + name = Column(String(255)) + description = Column(String(255)) + + project_manager = Column(String(255), ForeignKey(User.id)) + + members = relationship(User, + secondary='user_project_association', + backref='projects') + + +class UserProjectRoleAssociation(BASE, NovaBase): + __tablename__ = 'user_project_role_association' + user_id = Column(String(255), primary_key=True) + user = relationship(User, + primaryjoin=user_id==User.id, + foreign_keys=[User.id], + uselist=False) + + project_id = Column(String(255), primary_key=True) + project = relationship(Project, + primaryjoin=project_id==Project.id, + foreign_keys=[Project.id], + uselist=False) + + role = Column(String(255), primary_key=True) + ForeignKeyConstraint(['user_id', + 'project_id'], + ['user_project_association.user_id', + 'user_project_association.project_id']) + + +class UserRoleAssociation(BASE, NovaBase): + __tablename__ = 'user_role_association' + user_id = Column(String(255), ForeignKey('users.id'), primary_key=True) + user = relationship(User, backref='roles') + role = Column(String(255), primary_key=True) + + +class UserProjectAssociation(BASE, NovaBase): + __tablename__ = 'user_project_association' + user_id = Column(String(255), ForeignKey(User.id), primary_key=True) + project_id = Column(String(255), ForeignKey(Project.id), primary_key=True) + class FloatingIp(BASE, NovaBase): @@ -400,34 +449,21 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @property - def str_id(self): - return self.address - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" from sqlalchemy import create_engine models = (Service, Instance, Volume, ExportDevice, - FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host) + FixedIp, FloatingIp, Network, NetworkIndex, + AuthToken, UserProjectAssociation, User, Project) # , Image, Host) engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/endpoint/__init__.py b/nova/endpoint/__init__.py deleted file mode 100644 index e69de29bb..000000000 --- a/nova/endpoint/__init__.py +++ /dev/null diff --git a/nova/endpoint/api.py b/nova/endpoint/api.py deleted file mode 100755 index 12eedfe67..000000000 --- a/nova/endpoint/api.py +++ /dev/null @@ -1,347 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Tornado REST API Request Handlers for Nova functions -Most calls are proxied into the responsible controller. -""" - -import logging -import multiprocessing -import random -import re -import urllib -# TODO(termie): replace minidom with etree -from xml.dom import minidom - -import tornado.web -from twisted.internet import defer - -from nova import crypto -from nova import exception -from nova import flags -from nova import utils -from nova.auth import manager -import nova.cloudpipe.api -from nova.endpoint import cloud - - -FLAGS = flags.FLAGS -flags.DEFINE_integer('cc_port', 8773, 'cloud controller port') - - -_log = logging.getLogger("api") -_log.setLevel(logging.DEBUG) - - -_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))') - - -def _camelcase_to_underscore(str): - return _c2u.sub(r'_\1', str).lower().strip('_') - - -def _underscore_to_camelcase(str): - return ''.join([x[:1].upper() + x[1:] for x in str.split('_')]) - - -def _underscore_to_xmlcase(str): - res = _underscore_to_camelcase(str) - return res[:1].lower() + res[1:] - - -class APIRequestContext(object): - def __init__(self, handler, user, project): - self.handler = handler - self.user = user - self.project = project - self.request_id = ''.join( - [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') - for x in xrange(20)] - ) - - -class APIRequest(object): - def __init__(self, controller, action): - self.controller = controller - self.action = action - - def send(self, context, **kwargs): - - try: - method = getattr(self.controller, - _camelcase_to_underscore(self.action)) - except AttributeError: - _error = ('Unsupported API request: controller = %s,' - 'action = %s') % (self.controller, self.action) - _log.warning(_error) - # TODO: Raise custom exception, trap in apiserver, - # and reraise as 400 error. - raise Exception(_error) - - args = {} - for key, value in kwargs.items(): - parts = key.split(".") - key = _camelcase_to_underscore(parts[0]) - if len(parts) > 1: - d = args.get(key, {}) - d[parts[1]] = value[0] - value = d - else: - value = value[0] - args[key] = value - - for key in args.keys(): - if isinstance(args[key], dict): - if args[key] != {} and args[key].keys()[0].isdigit(): - s = args[key].items() - s.sort() - args[key] = [v for k, v in s] - - d = defer.maybeDeferred(method, context, **args) - d.addCallback(self._render_response, context.request_id) - return d - - def _render_response(self, response_data, request_id): - xml = minidom.Document() - - response_el = xml.createElement(self.action + 'Response') - response_el.setAttribute('xmlns', - 'http://ec2.amazonaws.com/doc/2009-11-30/') - request_id_el = xml.createElement('requestId') - request_id_el.appendChild(xml.createTextNode(request_id)) - response_el.appendChild(request_id_el) - if(response_data == True): - self._render_dict(xml, response_el, {'return': 'true'}) - else: - self._render_dict(xml, response_el, response_data) - - xml.appendChild(response_el) - - response = xml.toxml() - xml.unlink() - _log.debug(response) - return response - - def _render_dict(self, xml, el, data): - try: - for key in data.keys(): - val = data[key] - el.appendChild(self._render_data(xml, key, val)) - except: - _log.debug(data) - raise - - def _render_data(self, xml, el_name, data): - el_name = _underscore_to_xmlcase(el_name) - data_el = xml.createElement(el_name) - - if isinstance(data, list): - for item in data: - data_el.appendChild(self._render_data(xml, 'item', item)) - elif isinstance(data, dict): - self._render_dict(xml, data_el, data) - elif hasattr(data, '__dict__'): - self._render_dict(xml, data_el, data.__dict__) - elif isinstance(data, bool): - data_el.appendChild(xml.createTextNode(str(data).lower())) - elif data != None: - data_el.appendChild(xml.createTextNode(str(data))) - - return data_el - - -class RootRequestHandler(tornado.web.RequestHandler): - def get(self): - # available api versions - versions = [ - '1.0', - '2007-01-19', - '2007-03-01', - '2007-08-29', - '2007-10-10', - '2007-12-15', - '2008-02-01', - '2008-09-01', - '2009-04-04', - ] - for version in versions: - self.write('%s\n' % version) - self.finish() - - -class MetadataRequestHandler(tornado.web.RequestHandler): - def print_data(self, data): - if isinstance(data, dict): - output = '' - for key in data: - if key == '_name': - continue - output += key - if isinstance(data[key], dict): - if '_name' in data[key]: - output += '=' + str(data[key]['_name']) - else: - output += '/' - output += '\n' - self.write(output[:-1]) # cut off last \n - elif isinstance(data, list): - self.write('\n'.join(data)) - else: - self.write(str(data)) - - def lookup(self, path, data): - items = path.split('/') - for item in items: - if item: - if not isinstance(data, dict): - return data - if not item in data: - return None - data = data[item] - return data - - def get(self, path): - cc = self.application.controllers['Cloud'] - meta_data = cc.get_metadata(self.request.remote_ip) - if meta_data is None: - _log.error('Failed to get metadata for ip: %s' % - self.request.remote_ip) - raise tornado.web.HTTPError(404) - data = self.lookup(path, meta_data) - if data is None: - raise tornado.web.HTTPError(404) - self.print_data(data) - self.finish() - - -class APIRequestHandler(tornado.web.RequestHandler): - def get(self, controller_name): - self.execute(controller_name) - - @tornado.web.asynchronous - def execute(self, controller_name): - # Obtain the appropriate controller for this request. - try: - controller = self.application.controllers[controller_name] - except KeyError: - self._error('unhandled', 'no controller named %s' % controller_name) - return - - args = self.request.arguments - - # Read request signature. - try: - signature = args.pop('Signature')[0] - except: - raise tornado.web.HTTPError(400) - - # Make a copy of args for authentication and signature verification. - auth_params = {} - for key, value in args.items(): - auth_params[key] = value[0] - - # Get requested action and remove authentication args for final request. - try: - action = args.pop('Action')[0] - access = args.pop('AWSAccessKeyId')[0] - args.pop('SignatureMethod') - args.pop('SignatureVersion') - args.pop('Version') - args.pop('Timestamp') - except: - raise tornado.web.HTTPError(400) - - # Authenticate the request. - try: - (user, project) = manager.AuthManager().authenticate( - access, - signature, - auth_params, - self.request.method, - self.request.host, - self.request.path - ) - - except exception.Error, ex: - logging.debug("Authentication Failure: %s" % ex) - raise tornado.web.HTTPError(403) - - _log.debug('action: %s' % action) - - for key, value in args.items(): - _log.debug('arg: %s\t\tval: %s' % (key, value)) - - request = APIRequest(controller, action) - context = APIRequestContext(self, user, project) - d = request.send(context, **args) - # d.addCallback(utils.debug) - - # TODO: Wrap response in AWS XML format - d.addCallbacks(self._write_callback, self._error_callback) - - def _write_callback(self, data): - self.set_header('Content-Type', 'text/xml') - self.write(data) - self.finish() - - def _error_callback(self, failure): - try: - failure.raiseException() - except exception.ApiError as ex: - if ex.code: - self._error(ex.code, ex.message) - else: - self._error(type(ex).__name__, ex.message) - # TODO(vish): do something more useful with unknown exceptions - except Exception as ex: - self._error(type(ex).__name__, str(ex)) - raise - - def post(self, controller_name): - self.execute(controller_name) - - def _error(self, code, message): - self._status_code = 400 - self.set_header('Content-Type', 'text/xml') - self.write('<?xml version="1.0"?>\n') - self.write('<Response><Errors><Error><Code>%s</Code>' - '<Message>%s</Message></Error></Errors>' - '<RequestID>?</RequestID></Response>' % (code, message)) - self.finish() - - -class APIServerApplication(tornado.web.Application): - def __init__(self, controllers): - tornado.web.Application.__init__(self, [ - (r'/', RootRequestHandler), - (r'/cloudpipe/(.*)', nova.cloudpipe.api.CloudPipeRequestHandler), - (r'/cloudpipe', nova.cloudpipe.api.CloudPipeRequestHandler), - (r'/services/([A-Za-z0-9]+)/', APIRequestHandler), - (r'/latest/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2009-04-04/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2008-09-01/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2008-02-01/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2007-12-15/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2007-10-10/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2007-08-29/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2007-03-01/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/2007-01-19/([-A-Za-z0-9/]*)', MetadataRequestHandler), - (r'/1.0/([-A-Za-z0-9/]*)', MetadataRequestHandler), - ], pool=multiprocessing.Pool(4)) - self.controllers = controllers diff --git a/nova/flags.py b/nova/flags.py index ed0baee65..c32cdd7a4 100644 --- a/nova/flags.py +++ b/nova/flags.py @@ -167,6 +167,9 @@ def DECLARE(name, module_string, flag_values=FLAGS): # Define any app-specific flags in their own files, docs at: # http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#39 +DEFINE_list('region_list', + [], + 'list of region=url pairs separated by commas') DEFINE_string('connection_type', 'libvirt', 'libvirt, xenapi or fake') DEFINE_integer('s3_port', 3333, 's3 port') DEFINE_string('s3_host', '127.0.0.1', 's3 host') @@ -185,6 +188,8 @@ DEFINE_string('rabbit_userid', 'guest', 'rabbit userid') DEFINE_string('rabbit_password', 'guest', 'rabbit password') DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host') DEFINE_string('control_exchange', 'nova', 'the main exchange to connect to') +DEFINE_string('cc_host', '127.0.0.1', 'ip of api server') +DEFINE_integer('cc_port', 8773, 'cloud controller port') DEFINE_string('ec2_url', 'http://127.0.0.1:8773/services/Cloud', 'Url to ec2 api server') diff --git a/nova/manager.py b/nova/manager.py index e9aa50c56..56ba7d3f6 100644 --- a/nova/manager.py +++ b/nova/manager.py @@ -22,6 +22,7 @@ Base class for managers of different parts of the system from nova import utils from nova import flags +from twisted.internet import defer FLAGS = flags.FLAGS flags.DEFINE_string('db_driver', 'nova.db.api', @@ -37,3 +38,15 @@ class Manager(object): if not db_driver: db_driver = FLAGS.db_driver self.db = utils.import_object(db_driver) # pylint: disable-msg=C0103 + + @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield + + def init_host(self): + """Do any initialization that needs to be run if this is a standalone service. + + Child classes should override this method. + """ + pass diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 41aeb5da7..37f9c8253 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -28,6 +28,11 @@ from nova import flags from nova import utils +def _bin_file(script): + """Return the absolute path to scipt in the bin directory""" + return os.path.abspath(os.path.join(__file__, "../../../bin", script)) + + FLAGS = flags.FLAGS flags.DEFINE_string('dhcpbridge_flagfile', '/etc/nova/nova-dhcpbridge.conf', @@ -36,13 +41,36 @@ flags.DEFINE_string('dhcpbridge_flagfile', flags.DEFINE_string('networks_path', utils.abspath('../networks'), 'Location to keep network config files') flags.DEFINE_string('public_interface', 'vlan1', - 'Interface for public IP addresses') + 'Interface for public IP addresses') flags.DEFINE_string('bridge_dev', 'eth0', 'network device for bridges') - +flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'), + 'location of nova-dhcpbridge') +flags.DEFINE_string('routing_source_ip', '127.0.0.1', + 'Public IP of network host') +flags.DEFINE_bool('use_nova_chains', False, + 'use the nova_ routing chains instead of default') DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)] +def init_host(): + """Basic networking setup goes here""" + # NOTE(devcamcar): Cloud public DNAT entries, CloudPipe port + # forwarding entries and a default DNAT entry. + _confirm_rule("PREROUTING", "-t nat -s 0.0.0.0/0 " + "-d 169.254.169.254/32 -p tcp -m tcp --dport 80 -j DNAT " + "--to-destination %s:%s" % (FLAGS.cc_host, FLAGS.cc_port)) + + # NOTE(devcamcar): Cloud public SNAT entries and the default + # SNAT rule for outbound traffic. + _confirm_rule("POSTROUTING", "-t nat -s %s " + "-j SNAT --to-source %s" + % (FLAGS.private_range, FLAGS.routing_source_ip)) + + _confirm_rule("POSTROUTING", "-t nat -s %s -j MASQUERADE" % + FLAGS.private_range) + _confirm_rule("POSTROUTING", "-t nat -s %(range)s -d %(range)s -j ACCEPT" % + {'range': FLAGS.private_range}) def bind_floating_ip(floating_ip): """Bind ip to public interface""" @@ -58,37 +86,37 @@ def unbind_floating_ip(floating_ip): def ensure_vlan_forward(public_ip, port, private_ip): """Sets up forwarding rules for vlan""" - _confirm_rule("FORWARD -d %s -p udp --dport 1194 -j ACCEPT" % private_ip) - _confirm_rule( - "PREROUTING -t nat -d %s -p udp --dport %s -j DNAT --to %s:1194" + _confirm_rule("FORWARD", "-d %s -p udp --dport 1194 -j ACCEPT" % + private_ip) + _confirm_rule("PREROUTING", + "-t nat -d %s -p udp --dport %s -j DNAT --to %s:1194" % (public_ip, port, private_ip)) def ensure_floating_forward(floating_ip, fixed_ip): """Ensure floating ip forwarding rule""" - _confirm_rule("PREROUTING -t nat -d %s -j DNAT --to %s" + _confirm_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s" % (floating_ip, fixed_ip)) - _confirm_rule("POSTROUTING -t nat -s %s -j SNAT --to %s" + _confirm_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s" % (fixed_ip, floating_ip)) # TODO(joshua): Get these from the secgroup datastore entries - _confirm_rule("FORWARD -d %s -p icmp -j ACCEPT" + _confirm_rule("FORWARD", "-d %s -p icmp -j ACCEPT" % (fixed_ip)) for (protocol, port) in DEFAULT_PORTS: - _confirm_rule( - "FORWARD -d %s -p %s --dport %s -j ACCEPT" + _confirm_rule("FORWARD","-d %s -p %s --dport %s -j ACCEPT" % (fixed_ip, protocol, port)) def remove_floating_forward(floating_ip, fixed_ip): """Remove forwarding for floating ip""" - _remove_rule("PREROUTING -t nat -d %s -j DNAT --to %s" + _remove_rule("PREROUTING", "-t nat -d %s -j DNAT --to %s" % (floating_ip, fixed_ip)) - _remove_rule("POSTROUTING -t nat -s %s -j SNAT --to %s" + _remove_rule("POSTROUTING", "-t nat -s %s -j SNAT --to %s" % (fixed_ip, floating_ip)) - _remove_rule("FORWARD -d %s -p icmp -j ACCEPT" + _remove_rule("FORWARD", "-d %s -p icmp -j ACCEPT" % (fixed_ip)) for (protocol, port) in DEFAULT_PORTS: - _remove_rule("FORWARD -d %s -p %s --dport %s -j ACCEPT" + _remove_rule("FORWARD", "-d %s -p %s --dport %s -j ACCEPT" % (fixed_ip, protocol, port)) @@ -118,22 +146,24 @@ def ensure_bridge(bridge, interface, net_attrs=None): # _execute("sudo brctl setageing %s 10" % bridge) _execute("sudo brctl stp %s off" % bridge) _execute("sudo brctl addif %s %s" % (bridge, interface)) - if net_attrs: - _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ - (bridge, - net_attrs['gateway'], - net_attrs['broadcast'], - net_attrs['netmask'])) - _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge) - else: - _execute("sudo ifconfig %s up" % bridge) + if net_attrs: + _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ + (bridge, + net_attrs['gateway'], + net_attrs['broadcast'], + net_attrs['netmask'])) + else: + _execute("sudo ifconfig %s up" % bridge) + _confirm_rule("FORWARD", "--in-interface %s -j ACCEPT" % bridge) + _confirm_rule("FORWARD", "--out-interface %s -j ACCEPT" % bridge) def get_dhcp_hosts(context, network_id): """Get a string containing a network's hosts config in dnsmasq format""" hosts = [] - for fixed_ip in db.network_get_associated_fixed_ips(context, network_id): - hosts.append(_host_dhcp(fixed_ip['str_id'])) + for fixed_ip_ref in db.network_get_associated_fixed_ips(context, + network_id): + hosts.append(_host_dhcp(fixed_ip_ref)) return '\n'.join(hosts) @@ -149,9 +179,14 @@ def update_dhcp(context, network_id): signal causing it to reload, otherwise spawn a new instance """ network_ref = db.network_get(context, network_id) - with open(_dhcp_file(network_ref['vlan'], 'conf'), 'w') as f: + + conffile = _dhcp_file(network_ref['vlan'], 'conf') + with open(conffile, 'w') as f: f.write(get_dhcp_hosts(context, network_id)) + # Make sure dnsmasq can actually read it (it setuid()s to "nobody") + os.chmod(conffile, 0644) + pid = _dnsmasq_pid_for(network_ref['vlan']) # if dnsmasq is already running, then tell it to reload @@ -159,7 +194,7 @@ def update_dhcp(context, network_id): # TODO(ja): use "/proc/%d/cmdline" % (pid) to determine if pid refers # correct dnsmasq process try: - os.kill(pid, signal.SIGHUP) + _execute('sudo kill -HUP %d' % pid) return except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Hupping dnsmasq threw %s", exc) @@ -171,12 +206,12 @@ def update_dhcp(context, network_id): _execute(command, addl_env=env) -def _host_dhcp(address): +def _host_dhcp(fixed_ip_ref): """Return a host string for an address""" - instance_ref = db.fixed_ip_get_instance(None, address) + instance_ref = fixed_ip_ref['instance'] return "%s,%s.novalocal,%s" % (instance_ref['mac_address'], instance_ref['hostname'], - address) + fixed_ip_ref['address']) def _execute(cmd, *args, **kwargs): @@ -194,15 +229,19 @@ def _device_exists(device): return not err -def _confirm_rule(cmd): +def _confirm_rule(chain, cmd): """Delete and re-add iptables rule""" - _execute("sudo iptables --delete %s" % (cmd), check_exit_code=False) - _execute("sudo iptables -I %s" % (cmd)) + if FLAGS.use_nova_chains: + chain = "nova_%s" % chain.lower() + _execute("sudo iptables --delete %s %s" % (chain, cmd), check_exit_code=False) + _execute("sudo iptables -I %s %s" % (chain, cmd)) -def _remove_rule(cmd): +def _remove_rule(chain, cmd): """Remove iptables rule""" - _execute("sudo iptables --delete %s" % (cmd)) + if FLAGS.use_nova_chains: + chain = "%S" % chain.lower() + _execute("sudo iptables --delete %s %s" % (chain, cmd)) def _dnsmasq_cmd(net): @@ -216,7 +255,7 @@ def _dnsmasq_cmd(net): ' --except-interface=lo', ' --dhcp-range=%s,static,120s' % net['dhcp_start'], ' --dhcp-hostsfile=%s' % _dhcp_file(net['vlan'], 'conf'), - ' --dhcp-script=%s' % _bin_file('nova-dhcpbridge'), + ' --dhcp-script=%s' % FLAGS.dhcpbridge, ' --leasefile-ro'] return ''.join(cmd) @@ -227,7 +266,7 @@ def _stop_dnsmasq(network): if pid: try: - os.kill(pid, signal.SIGTERM) + _execute('sudo kill -TERM %d' % pid) except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Killing dnsmasq threw %s", exc) @@ -235,12 +274,10 @@ def _stop_dnsmasq(network): def _dhcp_file(vlan, kind): """Return path to a pid, leases or conf file for a vlan""" - return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) - + if not os.path.exists(FLAGS.networks_path): + os.makedirs(FLAGS.networks_path) -def _bin_file(script): - """Return the absolute path to scipt in the bin directory""" - return os.path.abspath(os.path.join(__file__, "../../../bin", script)) + return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) def _dnsmasq_pid_for(vlan): diff --git a/nova/network/manager.py b/nova/network/manager.py index 191c1d364..9c1846dd9 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -20,10 +20,12 @@ Network Hosts are responsible for allocating ips and setting up network """ +import datetime import logging import math import IPy +from twisted.internet import defer from nova import db from nova import exception @@ -62,7 +64,9 @@ flags.DEFINE_integer('cnt_vpn_clients', 5, flags.DEFINE_string('network_driver', 'nova.network.linux_net', 'Driver to use for network creation') flags.DEFINE_bool('update_dhcp_on_disassociate', False, - 'Whether to update dhcp when fixed_ip is disassocated') + 'Whether to update dhcp when fixed_ip is disassociated') +flags.DEFINE_integer('fixed_ip_disassociate_timeout', 600, + 'Seconds after which a deallocated ip is disassociated') class AddressAlreadyAllocated(exception.Error): @@ -81,6 +85,12 @@ class NetworkManager(manager.Manager): self.driver = utils.import_object(network_driver) super(NetworkManager, self).__init__(*args, **kwargs) + def init_host(self): + # Set up networking for the projects for which we're already + # the designated network host. + for network in self.db.host_get_networks(None, self.host): + self._on_set_network_host(None, network['id']) + def set_network_host(self, context, project_id): """Safely sets the host of the projects network""" logging.debug("setting network host") @@ -88,7 +98,7 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) @@ -218,6 +228,27 @@ class FlatManager(NetworkManager): class VlanManager(NetworkManager): """Vlan network with dhcp""" + + @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield super(VlanManager, self).periodic_tasks(context) + now = datetime.datetime.utcnow() + timeout = FLAGS.fixed_ip_disassociate_timeout + time = now - datetime.timedelta(seconds=timeout) + num = self.db.fixed_ip_disassociate_all_by_timeout(context, + self.host, + time) + if num: + logging.debug("Dissassociated %s stale fixed ip(s)", num) + + def init_host(self): + """Do any initialization that needs to be run if this is a + standalone service. + """ + super(VlanManager, self).init_host() + self.driver.init_host() + def allocate_fixed_ip(self, context, instance_id, *args, **kwargs): """Gets a fixed ip from the pool""" network_ref = self.db.project_get_network(context, context.project.id) @@ -225,7 +256,7 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) @@ -235,14 +266,6 @@ class VlanManager(NetworkManager): """Returns a fixed ip to the pool""" self.db.fixed_ip_update(context, address, {'allocated': False}) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['leased']: - self.db.fixed_ip_disassociate(context, address) - # NOTE(vish): dhcp server isn't updated until next setup, this - # means there will stale entries in the conf file - # the code below will update the file if necessary - if FLAGS.update_dhcp_on_disassociate: - network_ref = self.db.fixed_ip_get_network(context, address) - self.driver.update_dhcp(context, network_ref['id']) def setup_fixed_ip(self, context, address): @@ -259,10 +282,7 @@ class VlanManager(NetworkManager): """Called by dhcp-bridge when ip is leased""" logging.debug("Leasing IP %s", address) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['allocated']: - logging.warn("IP %s leased that was already deallocated", address) - return - instance_ref = self.db.fixed_ip_get_instance(context, address) + instance_ref = fixed_ip_ref['instance'] if not instance_ref: raise exception.Error("IP %s leased that isn't associated" % address) @@ -270,24 +290,27 @@ class VlanManager(NetworkManager): raise exception.Error("IP %s leased to bad mac %s vs %s" % (address, instance_ref['mac_address'], mac)) self.db.fixed_ip_update(context, - fixed_ip_ref['str_id'], + fixed_ip_ref['address'], {'leased': True}) + if not fixed_ip_ref['allocated']: + logging.warn("IP %s leased that was already deallocated", address) def release_fixed_ip(self, context, mac, address): """Called by dhcp-bridge when ip is released""" logging.debug("Releasing IP %s", address) fixed_ip_ref = self.db.fixed_ip_get_by_address(context, address) - if not fixed_ip_ref['leased']: - logging.warn("IP %s released that was not leased", address) - return - instance_ref = self.db.fixed_ip_get_instance(context, address) + instance_ref = fixed_ip_ref['instance'] if not instance_ref: raise exception.Error("IP %s released that isn't associated" % address) if instance_ref['mac_address'] != mac: raise exception.Error("IP %s released from bad mac %s vs %s" % (address, instance_ref['mac_address'], mac)) - self.db.fixed_ip_update(context, address, {'leased': False}) + if not fixed_ip_ref['leased']: + logging.warn("IP %s released that was not leased", address) + self.db.fixed_ip_update(context, + fixed_ip_ref['str_id'], + {'leased': False}) if not fixed_ip_ref['allocated']: self.db.fixed_ip_disassociate(context, address) # NOTE(vish): dhcp server isn't updated until next setup, this @@ -343,7 +366,7 @@ class VlanManager(NetworkManager): This could use a manage command instead of keying off of a flag""" if not self.db.network_index_count(context): for index in range(FLAGS.num_networks): - self.db.network_index_create(context, {'index': index}) + self.db.network_index_create_safe(context, {'index': index}) def _on_set_network_host(self, context, network_id): """Called when this host becomes the host for a project""" @@ -351,6 +374,7 @@ class VlanManager(NetworkManager): self.driver.ensure_vlan_bridge(network_ref['vlan'], network_ref['bridge'], network_ref) + self.driver.update_dhcp(context, network_id) @property def _bottom_reserved_ips(self): diff --git a/nova/objectstore/__init__.py b/nova/objectstore/__init__.py index b8890ac03..ecad9be7c 100644 --- a/nova/objectstore/__init__.py +++ b/nova/objectstore/__init__.py @@ -22,7 +22,7 @@ .. automodule:: nova.objectstore :platform: Unix - :synopsis: Currently a trivial file-based system, getting extended w/ mongo. + :synopsis: Currently a trivial file-based system, getting extended w/ swift. .. moduleauthor:: Jesse Andrews <jesse@ansolabs.com> .. moduleauthor:: Devin Carlen <devin.carlen@gmail.com> .. moduleauthor:: Vishvananda Ishaya <vishvananda@yahoo.com> diff --git a/nova/objectstore/handler.py b/nova/objectstore/handler.py index 5c3dc286b..dfee64aca 100644 --- a/nova/objectstore/handler.py +++ b/nova/objectstore/handler.py @@ -55,7 +55,7 @@ from twisted.web import static from nova import exception from nova import flags from nova.auth import manager -from nova.endpoint import api +from nova.api.ec2 import context from nova.objectstore import bucket from nova.objectstore import image @@ -131,7 +131,7 @@ def get_context(request): request.uri, headers=request.getAllHeaders(), check_type='s3') - return api.APIRequestContext(None, user, project) + return context.APIRequestContext(user, project) except exception.Error as ex: logging.debug("Authentication Failure: %s", ex) raise exception.NotAuthorized @@ -352,6 +352,8 @@ class ImagesResource(resource.Resource): m[u'imageType'] = m['type'] elif 'imageType' in m: m[u'type'] = m['imageType'] + if 'displayName' not in m: + m[u'displayName'] = u'' return m request.write(json.dumps([decorate(i.metadata) for i in images])) @@ -382,16 +384,25 @@ class ImagesResource(resource.Resource): def render_POST(self, request): # pylint: disable-msg=R0201 """Update image attributes: public/private""" + # image_id required for all requests image_id = get_argument(request, 'image_id', u'') - operation = get_argument(request, 'operation', u'') - image_object = image.Image(image_id) - if not image_object.is_authorized(request.context): + logging.debug("not authorized for render_POST in images") raise exception.NotAuthorized - image_object.set_public(operation=='add') - + operation = get_argument(request, 'operation', u'') + if operation: + # operation implies publicity toggle + logging.debug("handling publicity toggle") + image_object.set_public(operation=='add') + else: + # other attributes imply update + logging.debug("update user fields") + clean_args = {} + for arg in request.args.keys(): + clean_args[arg] = request.args[arg][0] + image_object.update_user_editable_fields(clean_args) return '' def render_DELETE(self, request): # pylint: disable-msg=R0201 diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py index f3c02a425..def1b8167 100644 --- a/nova/objectstore/image.py +++ b/nova/objectstore/image.py @@ -82,6 +82,16 @@ class Image(object): with open(os.path.join(self.path, 'info.json'), 'w') as f: json.dump(md, f) + def update_user_editable_fields(self, args): + """args is from the request parameters, so requires extra cleaning""" + fields = {'display_name': 'displayName', 'description': 'description'} + info = self.metadata + for field in fields.keys(): + if field in args: + info[fields[field]] = args[field] + with open(os.path.join(self.path, 'info.json'), 'w') as f: + json.dump(info, f) + @staticmethod def all(): images = [] diff --git a/nova/quota.py b/nova/quota.py index f0e51feeb..edbb83111 100644 --- a/nova/quota.py +++ b/nova/quota.py @@ -37,7 +37,7 @@ flags.DEFINE_integer('quota_gigabytes', 1000, flags.DEFINE_integer('quota_floating_ips', 10, 'number of floating ips allowed per project') -def _get_quota(context, project_id): +def get_quota(context, project_id): rval = {'instances': FLAGS.quota_instances, 'cores': FLAGS.quota_cores, 'volumes': FLAGS.quota_volumes, @@ -57,7 +57,7 @@ def allowed_instances(context, num_instances, instance_type): project_id = context.project.id used_instances, used_cores = db.instance_data_get_for_project(context, project_id) - quota = _get_quota(context, project_id) + quota = get_quota(context, project_id) allowed_instances = quota['instances'] - used_instances allowed_cores = quota['cores'] - used_cores type_cores = instance_types.INSTANCE_TYPES[instance_type]['vcpus'] @@ -72,9 +72,10 @@ def allowed_volumes(context, num_volumes, size): project_id = context.project.id used_volumes, used_gigabytes = db.volume_data_get_for_project(context, project_id) - quota = _get_quota(context, project_id) + quota = get_quota(context, project_id) allowed_volumes = quota['volumes'] - used_volumes allowed_gigabytes = quota['gigabytes'] - used_gigabytes + size = int(size) num_gigabytes = num_volumes * size allowed_volumes = min(allowed_volumes, int(allowed_gigabytes // size)) @@ -85,7 +86,7 @@ def allowed_floating_ips(context, num_floating_ips): """Check quota and return min(num_floating_ips, allowed_floating_ips)""" project_id = context.project.id used_floating_ips = db.floating_ip_count_by_project(context, project_id) - quota = _get_quota(context, project_id) + quota = get_quota(context, project_id) allowed_floating_ips = quota['floating_ips'] - used_floating_ips return min(num_floating_ips, allowed_floating_ips) diff --git a/nova/rpc.py b/nova/rpc.py index 84a9b5590..fe52ad35f 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -46,9 +46,9 @@ LOG.setLevel(logging.DEBUG) class Connection(carrot_connection.BrokerConnection): """Connection instance object""" @classmethod - def instance(cls): + def instance(cls, new=False): """Returns the instance""" - if not hasattr(cls, '_instance'): + if new or not hasattr(cls, '_instance'): params = dict(hostname=FLAGS.rabbit_host, port=FLAGS.rabbit_port, userid=FLAGS.rabbit_userid, @@ -60,7 +60,10 @@ class Connection(carrot_connection.BrokerConnection): # NOTE(vish): magic is fun! # pylint: disable-msg=W0142 - cls._instance = cls(**params) + if new: + return cls(**params) + else: + cls._instance = cls(**params) return cls._instance @classmethod @@ -81,21 +84,6 @@ class Consumer(messaging.Consumer): self.failed_connection = False super(Consumer, self).__init__(*args, **kwargs) - # TODO(termie): it would be nice to give these some way of automatically - # cleaning up after themselves - def attach_to_tornado(self, io_inst=None): - """Attach a callback to tornado that fires 10 times a second""" - from tornado import ioloop - if io_inst is None: - io_inst = ioloop.IOLoop.instance() - - injected = ioloop.PeriodicCallback( - lambda: self.fetch(enable_callbacks=True), 100, io_loop=io_inst) - injected.start() - return injected - - attachToTornado = attach_to_tornado - def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False): """Wraps the parent fetch with some logic for failed connections""" # TODO(vish): the logic for failed connections and logging should be @@ -123,6 +111,7 @@ class Consumer(messaging.Consumer): """Attach a callback to twisted that fires 10 times a second""" loop = task.LoopingCall(self.fetch, enable_callbacks=True) loop.start(interval=0.1) + return loop class Publisher(messaging.Publisher): @@ -265,6 +254,41 @@ def call(topic, msg): msg.update({'_msg_id': msg_id}) LOG.debug("MSG_ID is %s" % (msg_id)) + class WaitMessage(object): + + def __call__(self, data, message): + """Acks message and sets result.""" + message.ack() + if data['failure']: + self.result = RemoteError(*data['failure']) + else: + self.result = data['result'] + + wait_msg = WaitMessage() + conn = Connection.instance(True) + consumer = DirectConsumer(connection=conn, msg_id=msg_id) + consumer.register_callback(wait_msg) + + conn = Connection.instance() + publisher = TopicPublisher(connection=conn, topic=topic) + publisher.send(msg) + publisher.close() + + try: + consumer.wait(limit=1) + except StopIteration: + pass + consumer.close() + return wait_msg.result + + +def call_twisted(topic, msg): + """Sends a message on a topic and wait for a response""" + LOG.debug("Making asynchronous call...") + msg_id = uuid.uuid4().hex + msg.update({'_msg_id': msg_id}) + LOG.debug("MSG_ID is %s" % (msg_id)) + conn = Connection.instance() d = defer.Deferred() consumer = DirectConsumer(connection=conn, msg_id=msg_id) @@ -278,7 +302,7 @@ def call(topic, msg): return d.callback(data['result']) consumer.register_callback(deferred_receive) - injected = consumer.attach_to_tornado() + injected = consumer.attach_to_twisted() # clean up after the injected listened and return x d.addCallback(lambda x: injected.stop() and x or x) diff --git a/nova/scheduler/driver.py b/nova/scheduler/driver.py index 2e6a5a835..c89d25a47 100644 --- a/nova/scheduler/driver.py +++ b/nova/scheduler/driver.py @@ -42,7 +42,8 @@ class Scheduler(object): def service_is_up(service): """Check whether a service is up based on last heartbeat.""" last_heartbeat = service['updated_at'] or service['created_at'] - elapsed = datetime.datetime.now() - last_heartbeat + # Timestamps in DB are UTC. + elapsed = datetime.datetime.utcnow() - last_heartbeat return elapsed < datetime.timedelta(seconds=FLAGS.service_down_time) def hosts_up(self, context, topic): diff --git a/nova/service.py b/nova/service.py index 870dd6ceb..115e0ff32 100644 --- a/nova/service.py +++ b/nova/service.py @@ -37,7 +37,11 @@ from nova import utils FLAGS = flags.FLAGS flags.DEFINE_integer('report_interval', 10, - 'seconds between nodes reporting state to cloud', + 'seconds between nodes reporting state to datastore', + lower_bound=1) + +flags.DEFINE_integer('periodic_interval', 60, + 'seconds between running periodic tasks', lower_bound=1) @@ -48,10 +52,17 @@ class Service(object, service.Service): self.host = host self.binary = binary self.topic = topic - manager_class = utils.import_class(manager) - self.manager = manager_class(host=host, *args, **kwargs) - self.model_disconnected = False + self.manager_class_name = manager super(Service, self).__init__(*args, **kwargs) + self.saved_args, self.saved_kwargs = args, kwargs + + + def startService(self): # pylint: disable-msg C0103 + manager_class = utils.import_class(self.manager_class_name) + self.manager = manager_class(host=self.host, *self.saved_args, + **self.saved_kwargs) + self.manager.init_host() + self.model_disconnected = False try: service_ref = db.service_get_by_args(None, self.host, @@ -80,7 +91,8 @@ class Service(object, service.Service): binary=None, topic=None, manager=None, - report_interval=None): + report_interval=None, + periodic_interval=None): """Instantiates class and passes back application object. Args: @@ -89,6 +101,7 @@ class Service(object, service.Service): topic, defaults to bin_name - "nova-" part manager, defaults to FLAGS.<topic>_manager report_interval, defaults to FLAGS.report_interval + periodic_interval, defaults to FLAGS.periodic_interval """ if not host: host = FLAGS.host @@ -100,6 +113,8 @@ class Service(object, service.Service): manager = FLAGS.get('%s_manager' % topic, None) if not report_interval: report_interval = FLAGS.report_interval + if not periodic_interval: + periodic_interval = FLAGS.periodic_interval logging.warn("Starting %s node", topic) service_obj = cls(host, binary, topic, manager) conn = rpc.Connection.instance() @@ -112,11 +127,14 @@ class Service(object, service.Service): topic='%s.%s' % (topic, host), proxy=service_obj) + consumer_all.attach_to_twisted() + consumer_node.attach_to_twisted() + pulse = task.LoopingCall(service_obj.report_state) pulse.start(interval=report_interval, now=False) - consumer_all.attach_to_twisted() - consumer_node.attach_to_twisted() + pulse = task.LoopingCall(service_obj.periodic_tasks) + pulse.start(interval=periodic_interval, now=False) # This is the parent service that twistd will be looking for when it # parses this file, return it so that we can get it into globals. @@ -132,6 +150,11 @@ class Service(object, service.Service): logging.warn("Service killed that has no database entry") @defer.inlineCallbacks + def periodic_tasks(self, context=None): + """Tasks to be run at a periodic interval""" + yield self.manager.periodic_tasks(context) + + @defer.inlineCallbacks def report_state(self, context=None): """Update the state of this service in the datastore.""" try: diff --git a/nova/test.py b/nova/test.py index c392c8a84..1f4b33272 100644 --- a/nova/test.py +++ b/nova/test.py @@ -33,6 +33,7 @@ from twisted.trial import unittest from nova import fakerabbit from nova import flags +from nova import rpc FLAGS = flags.FLAGS @@ -62,19 +63,29 @@ class TrialTestCase(unittest.TestCase): self.mox = mox.Mox() self.stubs = stubout.StubOutForTesting() self.flag_overrides = {} + self.injected = [] + self._monkeyPatchAttach() def tearDown(self): # pylint: disable-msg=C0103 """Runs after each test method to finalize/tear down test environment""" - super(TrialTestCase, self).tearDown() self.reset_flags() self.mox.UnsetStubs() self.stubs.UnsetAll() self.stubs.SmartUnsetAll() self.mox.VerifyAll() + + rpc.Consumer.attach_to_twisted = self.originalAttach + for x in self.injected: + try: + x.stop() + except AssertionError: + pass if FLAGS.fake_rabbit: fakerabbit.reset_all() + super(TrialTestCase, self).tearDown() + def flags(self, **kw): """Override flag variables for a test""" for k, v in kw.iteritems(): @@ -90,16 +101,51 @@ class TrialTestCase(unittest.TestCase): for k, v in self.flag_overrides.iteritems(): setattr(FLAGS, k, v) + def run(self, result=None): + test_method = getattr(self, self._testMethodName) + setattr(self, + self._testMethodName, + self._maybeInlineCallbacks(test_method, result)) + rv = super(TrialTestCase, self).run(result) + setattr(self, self._testMethodName, test_method) + return rv + + def _maybeInlineCallbacks(self, func, result): + def _wrapped(): + g = func() + if isinstance(g, defer.Deferred): + return g + if not hasattr(g, 'send'): + return defer.succeed(g) + + inlined = defer.inlineCallbacks(func) + d = inlined() + return d + _wrapped.func_name = func.func_name + return _wrapped + + def _monkeyPatchAttach(self): + self.originalAttach = rpc.Consumer.attach_to_twisted + def _wrapped(innerSelf): + rv = self.originalAttach(innerSelf) + self.injected.append(rv) + return rv + + _wrapped.func_name = self.originalAttach.func_name + rpc.Consumer.attach_to_twisted = _wrapped + class BaseTestCase(TrialTestCase): # TODO(jaypipes): Can this be moved into the TrialTestCase class? - """Base test case class for all unit tests.""" + """Base test case class for all unit tests. + + DEPRECATED: This is being removed once Tornado is gone, use TrialTestCase. + """ def setUp(self): # pylint: disable-msg=C0103 """Run before each test method to initialize test environment""" super(BaseTestCase, self).setUp() # TODO(termie): we could possibly keep a more global registry of # the injected listeners... this is fine for now though - self.injected = [] self.ioloop = ioloop.IOLoop.instance() self._waiting = None @@ -109,8 +155,6 @@ class BaseTestCase(TrialTestCase): def tearDown(self):# pylint: disable-msg=C0103 """Runs after each test method to finalize/tear down test environment""" super(BaseTestCase, self).tearDown() - for x in self.injected: - x.stop() if FLAGS.fake_rabbit: fakerabbit.reset_all() diff --git a/nova/tests/access_unittest.py b/nova/tests/access_unittest.py index 59e1683db..4b40ffd0a 100644 --- a/nova/tests/access_unittest.py +++ b/nova/tests/access_unittest.py @@ -18,19 +18,20 @@ import unittest import logging +import webob from nova import exception from nova import flags from nova import test +from nova.api import ec2 from nova.auth import manager -from nova.auth import rbac FLAGS = flags.FLAGS class Context(object): pass -class AccessTestCase(test.BaseTestCase): +class AccessTestCase(test.TrialTestCase): def setUp(self): super(AccessTestCase, self).setUp() um = manager.AuthManager() @@ -72,9 +73,17 @@ class AccessTestCase(test.BaseTestCase): try: self.project.add_role(self.testsys, 'sysadmin') except: pass - self.context = Context() - self.context.project = self.project #user is set in each test + def noopWSGIApp(environ, start_response): + start_response('200 OK', []) + return [''] + self.mw = ec2.Authorizer(noopWSGIApp) + self.mw.action_roles = {'str': { + '_allow_all': ['all'], + '_allow_none': [], + '_allow_project_manager': ['projectmanager'], + '_allow_sys_and_net': ['sysadmin', 'netadmin'], + '_allow_sysadmin': ['sysadmin']}} def tearDown(self): um = manager.AuthManager() @@ -87,76 +96,46 @@ class AccessTestCase(test.BaseTestCase): um.delete_user('testsys') super(AccessTestCase, self).tearDown() + def response_status(self, user, methodName): + context = Context() + context.project = self.project + context.user = user + environ = {'ec2.context' : context, + 'ec2.controller': 'some string', + 'ec2.action': methodName} + req = webob.Request.blank('/', environ) + resp = req.get_response(self.mw) + return resp.status_int + + def shouldAllow(self, user, methodName): + self.assertEqual(200, self.response_status(user, methodName)) + + def shouldDeny(self, user, methodName): + self.assertEqual(401, self.response_status(user, methodName)) + def test_001_allow_all(self): - self.context.user = self.testadmin - self.assertTrue(self._allow_all(self.context)) - self.context.user = self.testpmsys - self.assertTrue(self._allow_all(self.context)) - self.context.user = self.testnet - self.assertTrue(self._allow_all(self.context)) - self.context.user = self.testsys - self.assertTrue(self._allow_all(self.context)) + users = [self.testadmin, self.testpmsys, self.testnet, self.testsys] + for user in users: + self.shouldAllow(user, '_allow_all') def test_002_allow_none(self): - self.context.user = self.testadmin - self.assertTrue(self._allow_none(self.context)) - self.context.user = self.testpmsys - self.assertRaises(exception.NotAuthorized, self._allow_none, self.context) - self.context.user = self.testnet - self.assertRaises(exception.NotAuthorized, self._allow_none, self.context) - self.context.user = self.testsys - self.assertRaises(exception.NotAuthorized, self._allow_none, self.context) + self.shouldAllow(self.testadmin, '_allow_none') + users = [self.testpmsys, self.testnet, self.testsys] + for user in users: + self.shouldDeny(user, '_allow_none') def test_003_allow_project_manager(self): - self.context.user = self.testadmin - self.assertTrue(self._allow_project_manager(self.context)) - self.context.user = self.testpmsys - self.assertTrue(self._allow_project_manager(self.context)) - self.context.user = self.testnet - self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context) - self.context.user = self.testsys - self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context) + for user in [self.testadmin, self.testpmsys]: + self.shouldAllow(user, '_allow_project_manager') + for user in [self.testnet, self.testsys]: + self.shouldDeny(user, '_allow_project_manager') def test_004_allow_sys_and_net(self): - self.context.user = self.testadmin - self.assertTrue(self._allow_sys_and_net(self.context)) - self.context.user = self.testpmsys # doesn't have the per project sysadmin - self.assertRaises(exception.NotAuthorized, self._allow_sys_and_net, self.context) - self.context.user = self.testnet - self.assertTrue(self._allow_sys_and_net(self.context)) - self.context.user = self.testsys - self.assertTrue(self._allow_sys_and_net(self.context)) - - def test_005_allow_sys_no_pm(self): - self.context.user = self.testadmin - self.assertTrue(self._allow_sys_no_pm(self.context)) - self.context.user = self.testpmsys - self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context) - self.context.user = self.testnet - self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context) - self.context.user = self.testsys - self.assertTrue(self._allow_sys_no_pm(self.context)) - - @rbac.allow('all') - def _allow_all(self, context): - return True - - @rbac.allow('none') - def _allow_none(self, context): - return True - - @rbac.allow('projectmanager') - def _allow_project_manager(self, context): - return True - - @rbac.allow('sysadmin', 'netadmin') - def _allow_sys_and_net(self, context): - return True - - @rbac.allow('sysadmin') - @rbac.deny('projectmanager') - def _allow_sys_no_pm(self, context): - return True + for user in [self.testadmin, self.testnet, self.testsys]: + self.shouldAllow(user, '_allow_sys_and_net') + # denied because it doesn't have the per project sysadmin + for user in [self.testpmsys]: + self.shouldDeny(user, '_allow_sys_and_net') if __name__ == "__main__": # TODO: Implement use_fake as an option diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py index 59c4adc3d..fc1ab9ae2 100644 --- a/nova/tests/api/__init__.py +++ b/nova/tests/api/__init__.py @@ -25,6 +25,7 @@ import stubout import webob import webob.dec +import nova.exception from nova import api from nova.tests.api.test_helper import * @@ -36,24 +37,46 @@ class Test(unittest.TestCase): def tearDown(self): # pylint: disable-msg=C0103 self.stubs.UnsetAll() + def _request(self, url, subdomain, **kwargs): + environ_keys = {'HTTP_HOST': '%s.example.com' % subdomain} + environ_keys.update(kwargs) + req = webob.Request.blank(url, environ_keys) + return req.get_response(api.API()) + def test_rackspace(self): self.stubs.Set(api.rackspace, 'API', APIStub) - result = webob.Request.blank('/v1.0/cloud').get_response(api.API()) + result = self._request('/v1.0/cloud', 'rs') self.assertEqual(result.body, "/cloud") def test_ec2(self): self.stubs.Set(api.ec2, 'API', APIStub) - result = webob.Request.blank('/ec2/cloud').get_response(api.API()) + result = self._request('/services/cloud', 'ec2') self.assertEqual(result.body, "/cloud") def test_not_found(self): self.stubs.Set(api.ec2, 'API', APIStub) self.stubs.Set(api.rackspace, 'API', APIStub) - result = webob.Request.blank('/test/cloud').get_response(api.API()) + result = self._request('/test/cloud', 'ec2') self.assertNotEqual(result.body, "/cloud") - def test_query_api_version(self): - pass + def test_query_api_versions(self): + result = self._request('/', 'rs') + self.assertTrue('CURRENT' in result.body) + + def test_metadata(self): + def go(url): + result = self._request(url, 'ec2', + REMOTE_ADDR='128.192.151.2') + # Each should get to the ORM layer and fail to find the IP + self.assertRaises(nova.exception.NotFound, go, '/latest/') + self.assertRaises(nova.exception.NotFound, go, '/2009-04-04/') + self.assertRaises(nova.exception.NotFound, go, '/1.0/') + + def test_ec2_root(self): + result = self._request('/', 'ec2') + self.assertTrue('2007-12-15\n' in result.body) + + if __name__ == '__main__': unittest.main() diff --git a/nova/tests/api/rackspace/__init__.py b/nova/tests/api/rackspace/__init__.py index e69de29bb..bfd0f87a7 100644 --- a/nova/tests/api/rackspace/__init__.py +++ b/nova/tests/api/rackspace/__init__.py @@ -0,0 +1,108 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from nova.api.rackspace import limited +from nova.api.rackspace import RateLimitingMiddleware +from nova.tests.api.test_helper import * +from webob import Request + + +class RateLimitingMiddlewareTest(unittest.TestCase): + + def test_get_action_name(self): + middleware = RateLimitingMiddleware(APIStub()) + def verify(method, url, action_name): + req = Request.blank(url) + req.method = method + action = middleware.get_action_name(req) + self.assertEqual(action, action_name) + verify('PUT', '/servers/4', 'PUT') + verify('DELETE', '/servers/4', 'DELETE') + verify('POST', '/images/4', 'POST') + verify('POST', '/servers/4', 'POST servers') + verify('GET', '/foo?a=4&changes-since=never&b=5', 'GET changes-since') + verify('GET', '/foo?a=4&monkeys-since=never&b=5', None) + verify('GET', '/servers/4', None) + verify('HEAD', '/servers/4', None) + + def exhaust(self, middleware, method, url, username, times): + req = Request.blank(url, dict(REQUEST_METHOD=method), + headers={'X-Auth-User': username}) + for i in range(times): + resp = req.get_response(middleware) + self.assertEqual(resp.status_int, 200) + resp = req.get_response(middleware) + self.assertEqual(resp.status_int, 413) + self.assertTrue('Retry-After' in resp.headers) + + def test_single_action(self): + middleware = RateLimitingMiddleware(APIStub()) + self.exhaust(middleware, 'DELETE', '/servers/4', 'usr1', 100) + self.exhaust(middleware, 'DELETE', '/servers/4', 'usr2', 100) + + def test_POST_servers_action_implies_POST_action(self): + middleware = RateLimitingMiddleware(APIStub()) + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) + self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10) + self.assertTrue(set(middleware.limiter._levels) == + set(['usr1:POST', 'usr1:POST servers', 'usr2:POST'])) + + def test_POST_servers_action_correctly_ratelimited(self): + middleware = RateLimitingMiddleware(APIStub()) + # Use up all of our "POST" allowance for the minute, 5 times + for i in range(5): + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) + # Reset the 'POST' action counter. + del middleware.limiter._levels['usr1:POST'] + # All 50 daily "POST servers" actions should be all used up + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 0) + + def test_proxy_ctor_works(self): + middleware = RateLimitingMiddleware(APIStub()) + self.assertEqual(middleware.limiter.__class__.__name__, "Limiter") + middleware = RateLimitingMiddleware(APIStub(), service_host='foobar') + self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy") + + +class LimiterTest(unittest.TestCase): + + def testLimiter(self): + items = range(2000) + req = Request.blank('/') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=3') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=2005') + self.assertEqual(limited(items, req), []) + req = Request.blank('/?limit=10') + self.assertEqual(limited(items, req), items[ :10]) + req = Request.blank('/?limit=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?limit=3000') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=1&limit=3') + self.assertEqual(limited(items, req), items[1:4]) + req = Request.blank('/?offset=3&limit=0') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3&limit=1500') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3000&limit=10') + self.assertEqual(limited(items, req), []) diff --git a/nova/tests/api/rackspace/auth.py b/nova/tests/api/rackspace/auth.py new file mode 100644 index 000000000..56677c2f4 --- /dev/null +++ b/nova/tests/api/rackspace/auth.py @@ -0,0 +1,108 @@ +import datetime +import unittest + +import stubout +import webob +import webob.dec + +import nova.api +import nova.api.rackspace.auth +from nova import auth +from nova.tests.api.rackspace import test_helper + +class Test(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + self.stubs.Set(nova.api.rackspace.auth.BasicApiAuthManager, + '__init__', test_helper.fake_auth_init) + test_helper.FakeAuthManager.auth_data = {} + test_helper.FakeAuthDatabase.data = {} + test_helper.stub_out_rate_limiting(self.stubs) + test_helper.stub_for_testing(self.stubs) + + def tearDown(self): + self.stubs.UnsetAll() + test_helper.fake_data_store = {} + + def test_authorize_user(self): + f = test_helper.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + def test_authorize_token(self): + f = test_helper.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-Server-Management-Url'], + "https://foo/v1.0/") + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + token = result.headers['X-Auth-Token'] + self.stubs.Set(nova.api.rackspace, 'APIRouter', + test_helper.FakeRouter) + req = webob.Request.blank('/v1.0/fake') + req.headers['X-Auth-Token'] = token + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '200 OK') + self.assertEqual(result.headers['X-Test-Success'], 'True') + + def test_token_expiry(self): + self.destroy_called = False + token_hash = 'bacon' + + def destroy_token_mock(meh, context, token): + self.destroy_called = True + + def bad_token(meh, context, token_hash): + return { 'token_hash':token_hash, + 'created_at':datetime.datetime(1990, 1, 1) } + + self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_destroy_token', + destroy_token_mock) + + self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_get_token', + bad_token) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'bacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + self.assertEqual(self.destroy_called, True) + + def test_bad_user(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_no_user(self): + req = webob.Request.blank('/v1.0/') + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_bad_token(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'baconbaconbacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/rackspace/flavors.py b/nova/tests/api/rackspace/flavors.py index fb8ba94a5..d25a2e2be 100644 --- a/nova/tests/api/rackspace/flavors.py +++ b/nova/tests/api/rackspace/flavors.py @@ -16,19 +16,31 @@ # under the License. import unittest +import stubout +import nova.api from nova.api.rackspace import flavors +from nova.tests.api.rackspace import test_helper from nova.tests.api.test_helper import * class FlavorsTest(unittest.TestCase): def setUp(self): self.stubs = stubout.StubOutForTesting() + test_helper.FakeAuthManager.auth_data = {} + test_helper.FakeAuthDatabase.data = {} + test_helper.stub_for_testing(self.stubs) + test_helper.stub_out_rate_limiting(self.stubs) + test_helper.stub_out_auth(self.stubs) def tearDown(self): self.stubs.UnsetAll() def test_get_flavor_list(self): - pass + req = webob.Request.blank('/v1.0/flavors') + res = req.get_response(nova.api.API()) def test_get_flavor_by_id(self): pass + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/rackspace/images.py b/nova/tests/api/rackspace/images.py index 560d8c898..4c9987e8b 100644 --- a/nova/tests/api/rackspace/images.py +++ b/nova/tests/api/rackspace/images.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +import stubout import unittest from nova.api.rackspace import images diff --git a/nova/tests/api/rackspace/servers.py b/nova/tests/api/rackspace/servers.py index 6d628e78a..69ad2c1d3 100644 --- a/nova/tests/api/rackspace/servers.py +++ b/nova/tests/api/rackspace/servers.py @@ -15,44 +15,231 @@ # License for the specific language governing permissions and limitations # under the License. +import json import unittest +import stubout + +from nova import db +from nova import flags +import nova.api.rackspace from nova.api.rackspace import servers +import nova.db.api +from nova.db.sqlalchemy.models import Instance +import nova.rpc from nova.tests.api.test_helper import * +from nova.tests.api.rackspace import test_helper + +FLAGS = flags.FLAGS + +def return_server(context, id): + return stub_instance(id) + +def return_servers(context, user_id=1): + return [stub_instance(i, user_id) for i in xrange(5)] + + +def stub_instance(id, user_id=1): + return Instance( + id=id, state=0, image_id=10, server_name='server%s'%id, + user_id=user_id + ) class ServersTest(unittest.TestCase): def setUp(self): self.stubs = stubout.StubOutForTesting() + test_helper.FakeAuthManager.auth_data = {} + test_helper.FakeAuthDatabase.data = {} + test_helper.stub_for_testing(self.stubs) + test_helper.stub_out_rate_limiting(self.stubs) + test_helper.stub_out_auth(self.stubs) + test_helper.stub_out_id_translator(self.stubs) + test_helper.stub_out_key_pair_funcs(self.stubs) + test_helper.stub_out_image_service(self.stubs) + self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) + self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server) + self.stubs.Set(nova.db.api, 'instance_get_all_by_user', + return_servers) def tearDown(self): self.stubs.UnsetAll() + def test_get_server_by_id(self): + req = webob.Request.blank('/v1.0/servers/1') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['server']['id'], '1') + self.assertEqual(res_dict['server']['name'], 'server1') + def test_get_server_list(self): - pass + req = webob.Request.blank('/v1.0/servers') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s.get('imageId', None), None) + i += 1 def test_create_instance(self): - pass + def server_update(context, id, params): + pass - def test_get_server_by_id(self): - pass + def instance_create(context, inst): + class Foo(object): + ec2_id = 1 + return Foo() + + def fake_method(*args, **kwargs): + pass + + def project_get_network(context, user_id): + return dict(id='1', host='localhost') + + def queue_get_for(context, *args): + return 'network_topic' + + self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) + self.stubs.Set(nova.db.api, 'instance_create', instance_create) + self.stubs.Set(nova.rpc, 'cast', fake_method) + self.stubs.Set(nova.rpc, 'call', fake_method) + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) + self.stubs.Set(nova.network.manager.FlatManager, 'allocate_fixed_ip', + fake_method) + + test_helper.stub_out_id_translator(self.stubs) + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers') + req.method = 'POST' + req.body = json.dumps(body) + + res = req.get_response(nova.api.API()) + + self.assertEqual(res.status_int, 200) - def test_get_backup_schedule(self): - pass + def test_update_no_body(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status_int, 422) - def test_get_server_details(self): - pass + def test_update_bad_params(self): + """ Confirm that update is filtering params """ + inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) - def test_get_server_ips(self): - pass + def server_update(context, id, params): + self.update_called = True + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_update_server(self): + inst_dict = dict(name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_create_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'POST' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_delete_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'DELETE' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_server_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_all_server_details(self): + req = webob.Request.blank('/v1.0/servers/detail') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s['imageId'], 10) + i += 1 def test_server_reboot(self): - pass + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) def test_server_rebuild(self): - pass + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) def test_server_resize(self): - pass + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) def test_delete_server_instance(self): - pass + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'DELETE' + + self.server_delete_called = False + def instance_destroy_mock(context, id): + self.server_delete_called = True + + self.stubs.Set(nova.db.api, 'instance_destroy', + instance_destroy_mock) + + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '202 Accepted') + self.assertEqual(self.server_delete_called, True) + +if __name__ == "__main__": + unittest.main() diff --git a/nova/tests/api/rackspace/sharedipgroups.py b/nova/tests/api/rackspace/sharedipgroups.py index b4b281db7..1906b54f5 100644 --- a/nova/tests/api/rackspace/sharedipgroups.py +++ b/nova/tests/api/rackspace/sharedipgroups.py @@ -15,6 +15,7 @@ # License for the specific language governing permissions and limitations # under the License. +import stubout import unittest from nova.api.rackspace import sharedipgroups diff --git a/nova/tests/api/rackspace/test_helper.py b/nova/tests/api/rackspace/test_helper.py new file mode 100644 index 000000000..2cf154f63 --- /dev/null +++ b/nova/tests/api/rackspace/test_helper.py @@ -0,0 +1,134 @@ +import datetime +import json + +import webob +import webob.dec + +from nova import auth +from nova import utils +from nova import flags +import nova.api.rackspace.auth +import nova.api.rackspace._id_translator +from nova.image import service +from nova.wsgi import Router + +FLAGS = flags.FLAGS + +class Context(object): + pass + +class FakeRouter(Router): + def __init__(self): + pass + + @webob.dec.wsgify + def __call__(self, req): + res = webob.Response() + res.status = '200' + res.headers['X-Test-Success'] = 'True' + return res + +def fake_auth_init(self): + self.db = FakeAuthDatabase() + self.context = Context() + self.auth = FakeAuthManager() + self.host = 'foo' + +@webob.dec.wsgify +def fake_wsgi(self, req): + req.environ['nova.context'] = dict(user=dict(id=1)) + if req.body: + req.environ['inst_dict'] = json.loads(req.body) + return self.application + +def stub_out_key_pair_funcs(stubs): + def key_pair(context, user_id): + return [dict(name='key', public_key='public_key')] + stubs.Set(nova.db.api, 'key_pair_get_all_by_user', + key_pair) + +def stub_out_image_service(stubs): + def fake_image_show(meh, id): + return dict(kernelId=1, ramdiskId=1) + + stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) + +def stub_out_id_translator(stubs): + class FakeTranslator(object): + def __init__(self, id_type, service_name): + pass + + def to_rs_id(self, id): + return id + + def from_rs_id(self, id): + return id + + stubs.Set(nova.api.rackspace._id_translator, + 'RackspaceAPIIdTranslator', FakeTranslator) + +def stub_out_auth(stubs): + def fake_auth_init(self, app): + self.application = app + + stubs.Set(nova.api.rackspace.AuthMiddleware, + '__init__', fake_auth_init) + stubs.Set(nova.api.rackspace.AuthMiddleware, + '__call__', fake_wsgi) + +def stub_out_rate_limiting(stubs): + def fake_rate_init(self, app): + super(nova.api.rackspace.RateLimitingMiddleware, self).__init__(app) + self.application = app + + stubs.Set(nova.api.rackspace.RateLimitingMiddleware, + '__init__', fake_rate_init) + + stubs.Set(nova.api.rackspace.RateLimitingMiddleware, + '__call__', fake_wsgi) + +def stub_for_testing(stubs): + def get_my_ip(): + return '127.0.0.1' + stubs.Set(nova.utils, 'get_my_ip', get_my_ip) + FLAGS.FAKE_subdomain = 'rs' + +class FakeAuthDatabase(object): + data = {} + + @staticmethod + def auth_get_token(context, token_hash): + return FakeAuthDatabase.data.get(token_hash, None) + + @staticmethod + def auth_create_token(context, token): + token['created_at'] = datetime.datetime.now() + FakeAuthDatabase.data[token['token_hash']] = token + + @staticmethod + def auth_destroy_token(context, token): + if FakeAuthDatabase.data.has_key(token['token_hash']): + del FakeAuthDatabase.data['token_hash'] + +class FakeAuthManager(object): + auth_data = {} + + def add_user(self, key, user): + FakeAuthManager.auth_data[key] = user + + def get_user(self, uid): + for k, v in FakeAuthManager.auth_data.iteritems(): + if v['uid'] == uid: + return v + return None + + def get_user_from_access_key(self, key): + return FakeAuthManager.auth_data.get(key, None) + +class FakeRateLimiter(object): + def __init__(self, application): + self.application = application + + @webob.dec.wsgify + def __call__(self, req): + return self.application diff --git a/nova/tests/api/rackspace/testfaults.py b/nova/tests/api/rackspace/testfaults.py new file mode 100644 index 000000000..b2931bc98 --- /dev/null +++ b/nova/tests/api/rackspace/testfaults.py @@ -0,0 +1,40 @@ +import unittest +import webob +import webob.dec +import webob.exc + +from nova.api.rackspace import faults + +class TestFaults(unittest.TestCase): + + def test_fault_parts(self): + req = webob.Request.blank('/.xml') + f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) + resp = req.get_response(f) + + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['<badRequest', 'code="400">']) + body_without_spaces = ''.join(resp.body.split()) + self.assertTrue('<message>scram</message>' in body_without_spaces) + + def test_retry_header(self): + req = webob.Request.blank('/.xml') + exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', + headers={'Retry-After': 4}) + f = faults.Fault(exc) + resp = req.get_response(f) + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['<overLimit', 'code="413">']) + body_sans_spaces = ''.join(resp.body.split()) + self.assertTrue('<message>sorry</message>' in body_sans_spaces) + self.assertTrue('<retryAfter>4</retryAfter>' in body_sans_spaces) + self.assertEqual(resp.headers['Retry-After'], 4) + + def test_raise(self): + @webob.dec.wsgify + def raiser(req): + raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) + req = webob.Request.blank('/.xml') + resp = req.get_response(raiser) + self.assertEqual(resp.status_int, 404) + self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/api/test_helper.py b/nova/tests/api/test_helper.py index 8151a4af6..d0a2cc027 100644 --- a/nova/tests/api/test_helper.py +++ b/nova/tests/api/test_helper.py @@ -1,4 +1,5 @@ import webob.dec +from nova import wsgi class APIStub(object): """Class to verify request and mark it was called.""" diff --git a/nova/tests/api/wsgi_test.py b/nova/tests/api/wsgi_test.py index 786dc1bce..9425b01d0 100644 --- a/nova/tests/api/wsgi_test.py +++ b/nova/tests/api/wsgi_test.py @@ -91,6 +91,57 @@ class Test(unittest.TestCase): result = webob.Request.blank('/test/123').get_response(Router()) self.assertNotEqual(result.body, "123") - def test_serializer(self): - # TODO(eday): Placeholder for serializer testing. - pass + +class SerializerTest(unittest.TestCase): + + def match(self, url, accept, expect): + input_dict = dict(servers=dict(a=(2,3))) + expected_xml = '<servers><a>(2,3)</a></servers>' + expected_json = '{"servers":{"a":[2,3]}}' + req = webob.Request.blank(url, headers=dict(Accept=accept)) + result = wsgi.Serializer(req.environ).to_content_type(input_dict) + result = result.replace('\n', '').replace(' ', '') + if expect == 'xml': + self.assertEqual(result, expected_xml) + elif expect == 'json': + self.assertEqual(result, expected_json) + else: + raise "Bad expect value" + + def test_basic(self): + self.match('/servers/4.json', None, expect='json') + self.match('/servers/4', 'application/json', expect='json') + self.match('/servers/4', 'application/xml', expect='xml') + self.match('/servers/4.xml', None, expect='xml') + + def test_defaults_to_json(self): + self.match('/servers/4', None, expect='json') + self.match('/servers/4', 'text/html', expect='json') + + def test_suffix_takes_precedence_over_accept_header(self): + self.match('/servers/4.xml', 'application/json', expect='xml') + self.match('/servers/4.xml.', 'application/json', expect='json') + + def test_deserialize(self): + xml = """ + <a a1="1" a2="2"> + <bs><b>1</b><b>2</b><b>3</b><b><c c1="1"/></b></bs> + <d><e>1</e></d> + <f>1</f> + </a> + """.strip() + as_dict = dict(a={ + 'a1': '1', + 'a2': '2', + 'bs': ['1', '2', '3', {'c': dict(c1='1')}], + 'd': {'e': '1'}, + 'f': '1'}) + metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})} + serializer = wsgi.Serializer({}, metadata) + self.assertEqual(serializer.deserialize(xml), as_dict) + + def test_deserialize_empty_xml(self): + xml = """<a></a>""" + as_dict = {"a": {}} + serializer = wsgi.Serializer({}) + self.assertEqual(serializer.deserialize(xml), as_dict) diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 462d1b295..c040cdad3 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -23,60 +23,17 @@ from boto.ec2 import regioninfo import httplib import random import StringIO -from tornado import httpserver -from twisted.internet import defer +import webob from nova import flags from nova import test +from nova import api +from nova.api.ec2 import cloud from nova.auth import manager -from nova.endpoint import api -from nova.endpoint import cloud FLAGS = flags.FLAGS - - -# NOTE(termie): These are a bunch of helper methods and classes to short -# circuit boto calls and feed them into our tornado handlers, -# it's pretty damn circuitous so apologies if you have to fix -# a bug in it -# NOTE(jaypipes) The pylint disables here are for R0913 (too many args) which -# isn't controllable since boto's HTTPRequest needs that many -# args, and for the version-differentiated import of tornado's -# httputil. -# NOTE(jaypipes): The disable-msg=E1101 and E1103 below is because pylint is -# unable to introspect the deferred's return value properly - -def boto_to_tornado(method, path, headers, data, # pylint: disable-msg=R0913 - host, connection=None): - """ translate boto requests into tornado requests - - connection should be a FakeTornadoHttpConnection instance - """ - try: - headers = httpserver.HTTPHeaders() - except AttributeError: - from tornado import httputil # pylint: disable-msg=E0611 - headers = httputil.HTTPHeaders() - for k, v in headers.iteritems(): - headers[k] = v - - req = httpserver.HTTPRequest(method=method, - uri=path, - headers=headers, - body=data, - host=host, - remote_ip='127.0.0.1', - connection=connection) - return req - - -def raw_to_httpresponse(response_string): - """translate a raw tornado http response into an httplib.HTTPResponse""" - sock = FakeHttplibSocket(response_string) - resp = httplib.HTTPResponse(sock) - resp.begin() - return resp +FLAGS.FAKE_subdomain = 'ec2' class FakeHttplibSocket(object): @@ -89,85 +46,35 @@ class FakeHttplibSocket(object): return self._buffer -class FakeTornadoStream(object): - """a fake stream to satisfy tornado's assumptions, trivial""" - def set_close_callback(self, _func): - """Dummy callback for stream""" - pass - - -class FakeTornadoConnection(object): - """A fake connection object for tornado to pass to its handlers - - web requests are expected to write to this as they get data and call - finish when they are done with the request, we buffer the writes and - kick off a callback when it is done so that we can feed the result back - into boto. - """ - def __init__(self, deferred): - self._deferred = deferred - self._buffer = StringIO.StringIO() - - def write(self, chunk): - """Writes a chunk of data to the internal buffer""" - self._buffer.write(chunk) - - def finish(self): - """Finalizes the connection and returns the buffered data via the - deferred callback. - """ - data = self._buffer.getvalue() - self._deferred.callback(data) - - xheaders = None - - @property - def stream(self): # pylint: disable-msg=R0201 - """Required property for interfacing with tornado""" - return FakeTornadoStream() - - class FakeHttplibConnection(object): """A fake httplib.HTTPConnection for boto to use requests made via this connection actually get translated and routed into - our tornado app, we then wait for the response and turn it back into + our WSGI app, we then wait for the response and turn it back into the httplib.HTTPResponse that boto expects. """ def __init__(self, app, host, is_secure=False): self.app = app self.host = host - self.deferred = defer.Deferred() def request(self, method, path, data, headers): - """Creates a connection to a fake tornado and sets - up a deferred request with the supplied data and - headers""" - conn = FakeTornadoConnection(self.deferred) - request = boto_to_tornado(connection=conn, - method=method, - path=path, - headers=headers, - data=data, - host=self.host) - self.app(request) - self.deferred.addCallback(raw_to_httpresponse) + req = webob.Request.blank(path) + req.method = method + req.body = data + req.headers = headers + req.headers['Accept'] = 'text/html' + req.host = self.host + # Call the WSGI app, get the HTTP response + resp = str(req.get_response(self.app)) + # For some reason, the response doesn't have "HTTP/1.0 " prepended; I + # guess that's a function the web server usually provides. + resp = "HTTP/1.0 %s" % resp + sock = FakeHttplibSocket(resp) + self.http_response = httplib.HTTPResponse(sock) + self.http_response.begin() def getresponse(self): - """A bit of deferred magic for catching the response - from the previously deferred request""" - @defer.inlineCallbacks - def _waiter(): - """Callback that simply yields the deferred's - return value.""" - result = yield self.deferred - defer.returnValue(result) - d = _waiter() - # NOTE(termie): defer.returnValue above should ensure that - # this deferred has already been called by the time - # we get here, we are going to cheat and return - # the result of the callback - return d.result # pylint: disable-msg=E1101 + return self.http_response def close(self): """Required for compatibility with boto/tornado""" @@ -180,17 +87,16 @@ class ApiEc2TestCase(test.BaseTestCase): super(ApiEc2TestCase, self).setUp() self.manager = manager.AuthManager() - self.cloud = cloud.CloudController() self.host = '127.0.0.1' - self.app = api.APIServerApplication({'Cloud': self.cloud}) + self.app = api.API() self.ec2 = boto.connect_ec2( aws_access_key_id='fake', aws_secret_access_key='fake', is_secure=False, region=regioninfo.RegionInfo(None, 'test', self.host), - port=FLAGS.cc_port, + port=8773, path='/services/Cloud') self.mox.StubOutWithMock(self.ec2, 'new_http_connection') @@ -198,7 +104,7 @@ class ApiEc2TestCase(test.BaseTestCase): def expect_http(self, host=None, is_secure=False): """Returns a new EC2 connection""" http = FakeHttplibConnection( - self.app, '%s:%d' % (self.host, FLAGS.cc_port), False) + self.app, '%s:8773' % (self.host), False) # pylint: disable-msg=E1103 self.ec2.new_http_connection(host, is_secure).AndReturn(http) return http @@ -224,7 +130,8 @@ class ApiEc2TestCase(test.BaseTestCase): for x in range(random.randint(4, 8))) user = self.manager.create_user('fake', 'fake', 'fake') project = self.manager.create_project('fake', 'fake', 'fake') - self.manager.generate_key_pair(user.id, keyname) + # NOTE(vish): create depends on pool, so call helper directly + cloud._gen_key(None, user.id, keyname) rv = self.ec2.get_all_key_pairs() results = [k for k in rv if k.name == keyname] diff --git a/nova/tests/auth_unittest.py b/nova/tests/auth_unittest.py index b54e68274..99f7ab599 100644 --- a/nova/tests/auth_unittest.py +++ b/nova/tests/auth_unittest.py @@ -17,8 +17,6 @@ # under the License. import logging -from M2Crypto import BIO -from M2Crypto import RSA from M2Crypto import X509 import unittest @@ -26,29 +24,76 @@ from nova import crypto from nova import flags from nova import test from nova.auth import manager -from nova.endpoint import cloud +from nova.api.ec2 import cloud FLAGS = flags.FLAGS - -class AuthTestCase(test.BaseTestCase): +class user_generator(object): + def __init__(self, manager, **user_state): + if 'name' not in user_state: + user_state['name'] = 'test1' + self.manager = manager + self.user = manager.create_user(**user_state) + + def __enter__(self): + return self.user + + def __exit__(self, value, type, trace): + self.manager.delete_user(self.user) + +class project_generator(object): + def __init__(self, manager, **project_state): + if 'name' not in project_state: + project_state['name'] = 'testproj' + if 'manager_user' not in project_state: + project_state['manager_user'] = 'test1' + self.manager = manager + self.project = manager.create_project(**project_state) + + def __enter__(self): + return self.project + + def __exit__(self, value, type, trace): + self.manager.delete_project(self.project) + +class user_and_project_generator(object): + def __init__(self, manager, user_state={}, project_state={}): + self.manager = manager + if 'name' not in user_state: + user_state['name'] = 'test1' + if 'name' not in project_state: + project_state['name'] = 'testproj' + if 'manager_user' not in project_state: + project_state['manager_user'] = 'test1' + self.user = manager.create_user(**user_state) + self.project = manager.create_project(**project_state) + + def __enter__(self): + return (self.user, self.project) + + def __exit__(self, value, type, trace): + self.manager.delete_user(self.user) + self.manager.delete_project(self.project) + +class AuthManagerTestCase(object): def setUp(self): - super(AuthTestCase, self).setUp() + FLAGS.auth_driver = self.auth_driver + super(AuthManagerTestCase, self).setUp() self.flags(connection_type='fake') self.manager = manager.AuthManager() - def test_001_can_create_users(self): - self.manager.create_user('test1', 'access', 'secret') - self.manager.create_user('test2') - - def test_002_can_get_user(self): - user = self.manager.get_user('test1') + def test_create_and_find_user(self): + with user_generator(self.manager): + self.assert_(self.manager.get_user('test1')) - def test_003_can_retreive_properties(self): - user = self.manager.get_user('test1') - self.assertEqual('test1', user.id) - self.assertEqual('access', user.access) - self.assertEqual('secret', user.secret) + def test_create_and_find_with_properties(self): + with user_generator(self.manager, name="herbert", secret="classified", + access="private-party"): + u = self.manager.get_user('herbert') + self.assertEqual('herbert', u.id) + self.assertEqual('herbert', u.name) + self.assertEqual('classified', u.secret) + self.assertEqual('private-party', u.access) def test_004_signature_is_valid(self): #self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? )) @@ -65,156 +110,222 @@ class AuthTestCase(test.BaseTestCase): 'export S3_URL="http://127.0.0.1:3333/"\n' + 'export EC2_USER_ID="test1"\n') - def test_006_test_key_storage(self): - user = self.manager.get_user('test1') - user.create_key_pair('public', 'key', 'fingerprint') - key = user.get_key_pair('public') - self.assertEqual('key', key.public_key) - self.assertEqual('fingerprint', key.fingerprint) - - def test_007_test_key_generation(self): - user = self.manager.get_user('test1') - private_key, fingerprint = user.generate_key_pair('public2') - key = RSA.load_key_string(private_key, callback=lambda: None) - bio = BIO.MemoryBuffer() - public_key = user.get_key_pair('public2').public_key - key.save_pub_key_bio(bio) - converted = crypto.ssl_pub_to_ssh_pub(bio.read()) - # assert key fields are equal - self.assertEqual(public_key.split(" ")[1].strip(), - converted.split(" ")[1].strip()) - - def test_008_can_list_key_pairs(self): - keys = self.manager.get_user('test1').get_key_pairs() - self.assertTrue(filter(lambda k: k.name == 'public', keys)) - self.assertTrue(filter(lambda k: k.name == 'public2', keys)) - - def test_009_can_delete_key_pair(self): - self.manager.get_user('test1').delete_key_pair('public') - keys = self.manager.get_user('test1').get_key_pairs() - self.assertFalse(filter(lambda k: k.name == 'public', keys)) - - def test_010_can_list_users(self): - users = self.manager.get_users() - logging.warn(users) - self.assertTrue(filter(lambda u: u.id == 'test1', users)) - - def test_101_can_add_user_role(self): - self.assertFalse(self.manager.has_role('test1', 'itsec')) - self.manager.add_role('test1', 'itsec') - self.assertTrue(self.manager.has_role('test1', 'itsec')) - - def test_199_can_remove_user_role(self): - self.assertTrue(self.manager.has_role('test1', 'itsec')) - self.manager.remove_role('test1', 'itsec') - self.assertFalse(self.manager.has_role('test1', 'itsec')) - - def test_201_can_create_project(self): - project = self.manager.create_project('testproj', 'test1', 'A test project', ['test1']) - self.assertTrue(filter(lambda p: p.name == 'testproj', self.manager.get_projects())) - self.assertEqual(project.name, 'testproj') - self.assertEqual(project.description, 'A test project') - self.assertEqual(project.project_manager_id, 'test1') - self.assertTrue(project.has_member('test1')) - - def test_202_user1_is_project_member(self): - self.assertTrue(self.manager.get_user('test1').is_project_member('testproj')) - - def test_203_user2_is_not_project_member(self): - self.assertFalse(self.manager.get_user('test2').is_project_member('testproj')) - - def test_204_user1_is_project_manager(self): - self.assertTrue(self.manager.get_user('test1').is_project_manager('testproj')) - - def test_205_user2_is_not_project_manager(self): - self.assertFalse(self.manager.get_user('test2').is_project_manager('testproj')) - - def test_206_can_add_user_to_project(self): - self.manager.add_to_project('test2', 'testproj') - self.assertTrue(self.manager.get_project('testproj').has_member('test2')) - - def test_207_can_remove_user_from_project(self): - self.manager.remove_from_project('test2', 'testproj') - self.assertFalse(self.manager.get_project('testproj').has_member('test2')) - - def test_208_can_remove_add_user_with_role(self): - self.manager.add_to_project('test2', 'testproj') - self.manager.add_role('test2', 'developer', 'testproj') - self.manager.remove_from_project('test2', 'testproj') - self.assertFalse(self.manager.has_role('test2', 'developer', 'testproj')) - self.manager.add_to_project('test2', 'testproj') - self.manager.remove_from_project('test2', 'testproj') - - def test_209_can_generate_x509(self): - # MUST HAVE RUN CLOUD SETUP BY NOW - self.cloud = cloud.CloudController() - self.cloud.setup() - _key, cert_str = self.manager._generate_x509_cert('test1', 'testproj') - logging.debug(cert_str) - - # Need to verify that it's signed by the right intermediate CA - full_chain = crypto.fetch_ca(project_id='testproj', chain=True) - int_cert = crypto.fetch_ca(project_id='testproj', chain=False) - cloud_cert = crypto.fetch_ca() - logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain) - signed_cert = X509.load_cert_string(cert_str) - chain_cert = X509.load_cert_string(full_chain) - int_cert = X509.load_cert_string(int_cert) - cloud_cert = X509.load_cert_string(cloud_cert) - self.assertTrue(signed_cert.verify(chain_cert.get_pubkey())) - self.assertTrue(signed_cert.verify(int_cert.get_pubkey())) - - if not FLAGS.use_intermediate_ca: - self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey())) - else: - self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey())) - - def test_210_can_add_project_role(self): - project = self.manager.get_project('testproj') - self.assertFalse(project.has_role('test1', 'sysadmin')) - self.manager.add_role('test1', 'sysadmin') - self.assertFalse(project.has_role('test1', 'sysadmin')) - project.add_role('test1', 'sysadmin') - self.assertTrue(project.has_role('test1', 'sysadmin')) - - def test_211_can_list_project_roles(self): - project = self.manager.get_project('testproj') - user = self.manager.get_user('test1') - self.manager.add_role(user, 'netadmin', project) - roles = self.manager.get_user_roles(user) - self.assertTrue('sysadmin' in roles) - self.assertFalse('netadmin' in roles) - project_roles = self.manager.get_user_roles(user, project) - self.assertTrue('sysadmin' in project_roles) - self.assertTrue('netadmin' in project_roles) - # has role should be false because global role is missing - self.assertFalse(self.manager.has_role(user, 'netadmin', project)) - - - def test_212_can_remove_project_role(self): - project = self.manager.get_project('testproj') - self.assertTrue(project.has_role('test1', 'sysadmin')) - project.remove_role('test1', 'sysadmin') - self.assertFalse(project.has_role('test1', 'sysadmin')) - self.manager.remove_role('test1', 'sysadmin') - self.assertFalse(project.has_role('test1', 'sysadmin')) - - def test_214_can_retrieve_project_by_user(self): - project = self.manager.create_project('testproj2', 'test2', 'Another test project', ['test2']) - self.assert_(len(self.manager.get_projects()) > 1) - self.assertEqual(len(self.manager.get_projects('test2')), 1) - - def test_299_can_delete_project(self): - self.manager.delete_project('testproj') - self.assertFalse(filter(lambda p: p.name == 'testproj', self.manager.get_projects())) - self.manager.delete_project('testproj2') - - def test_999_can_delete_users(self): + def test_can_list_users(self): + with user_generator(self.manager): + with user_generator(self.manager, name="test2"): + users = self.manager.get_users() + self.assert_(filter(lambda u: u.id == 'test1', users)) + self.assert_(filter(lambda u: u.id == 'test2', users)) + self.assert_(not filter(lambda u: u.id == 'test3', users)) + + def test_can_add_and_remove_user_role(self): + with user_generator(self.manager): + self.assertFalse(self.manager.has_role('test1', 'itsec')) + self.manager.add_role('test1', 'itsec') + self.assertTrue(self.manager.has_role('test1', 'itsec')) + self.manager.remove_role('test1', 'itsec') + self.assertFalse(self.manager.has_role('test1', 'itsec')) + + def test_can_create_and_get_project(self): + with user_and_project_generator(self.manager) as (u,p): + self.assert_(self.manager.get_user('test1')) + self.assert_(self.manager.get_user('test1')) + self.assert_(self.manager.get_project('testproj')) + + def test_can_list_projects(self): + with user_and_project_generator(self.manager): + with project_generator(self.manager, name="testproj2"): + projects = self.manager.get_projects() + self.assert_(filter(lambda p: p.name == 'testproj', projects)) + self.assert_(filter(lambda p: p.name == 'testproj2', projects)) + self.assert_(not filter(lambda p: p.name == 'testproj3', + projects)) + + def test_can_create_and_get_project_with_attributes(self): + with user_generator(self.manager): + with project_generator(self.manager, description='A test project'): + project = self.manager.get_project('testproj') + self.assertEqual('A test project', project.description) + + def test_can_create_project_with_manager(self): + with user_and_project_generator(self.manager) as (user, project): + self.assertEqual('test1', project.project_manager_id) + self.assertTrue(self.manager.is_project_manager(user, project)) + + def test_create_project_assigns_manager_to_members(self): + with user_and_project_generator(self.manager) as (user, project): + self.assertTrue(self.manager.is_project_member(user, project)) + + def test_no_extra_project_members(self): + with user_generator(self.manager, name='test2') as baduser: + with user_and_project_generator(self.manager) as (user, project): + self.assertFalse(self.manager.is_project_member(baduser, + project)) + + def test_no_extra_project_managers(self): + with user_generator(self.manager, name='test2') as baduser: + with user_and_project_generator(self.manager) as (user, project): + self.assertFalse(self.manager.is_project_manager(baduser, + project)) + + def test_can_add_user_to_project(self): + with user_generator(self.manager, name='test2') as user: + with user_and_project_generator(self.manager) as (_user, project): + self.manager.add_to_project(user, project) + project = self.manager.get_project('testproj') + self.assertTrue(self.manager.is_project_member(user, project)) + + def test_can_remove_user_from_project(self): + with user_generator(self.manager, name='test2') as user: + with user_and_project_generator(self.manager) as (_user, project): + self.manager.add_to_project(user, project) + project = self.manager.get_project('testproj') + self.assertTrue(self.manager.is_project_member(user, project)) + self.manager.remove_from_project(user, project) + project = self.manager.get_project('testproj') + self.assertFalse(self.manager.is_project_member(user, project)) + + def test_can_add_remove_user_with_role(self): + with user_generator(self.manager, name='test2') as user: + with user_and_project_generator(self.manager) as (_user, project): + # NOTE(todd): after modifying users you must reload project + self.manager.add_to_project(user, project) + project = self.manager.get_project('testproj') + self.manager.add_role(user, 'developer', project) + self.assertTrue(self.manager.is_project_member(user, project)) + self.manager.remove_from_project(user, project) + project = self.manager.get_project('testproj') + self.assertFalse(self.manager.has_role(user, 'developer', + project)) + self.assertFalse(self.manager.is_project_member(user, project)) + + def test_can_generate_x509(self): + # NOTE(todd): this doesn't assert against the auth manager + # so it probably belongs in crypto_unittest + # but I'm leaving it where I found it. + with user_and_project_generator(self.manager) as (user, project): + # NOTE(todd): Should mention why we must setup controller first + # (somebody please clue me in) + cloud_controller = cloud.CloudController() + cloud_controller.setup() + _key, cert_str = self.manager._generate_x509_cert('test1', + 'testproj') + logging.debug(cert_str) + + # Need to verify that it's signed by the right intermediate CA + full_chain = crypto.fetch_ca(project_id='testproj', chain=True) + int_cert = crypto.fetch_ca(project_id='testproj', chain=False) + cloud_cert = crypto.fetch_ca() + logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain) + signed_cert = X509.load_cert_string(cert_str) + chain_cert = X509.load_cert_string(full_chain) + int_cert = X509.load_cert_string(int_cert) + cloud_cert = X509.load_cert_string(cloud_cert) + self.assertTrue(signed_cert.verify(chain_cert.get_pubkey())) + self.assertTrue(signed_cert.verify(int_cert.get_pubkey())) + if not FLAGS.use_intermediate_ca: + self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey())) + else: + self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey())) + + def test_adding_role_to_project_is_ignored_unless_added_to_user(self): + with user_and_project_generator(self.manager) as (user, project): + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + self.manager.add_role(user, 'sysadmin', project) + # NOTE(todd): it will still show up in get_user_roles(u, project) + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + self.manager.add_role(user, 'sysadmin') + self.assertTrue(self.manager.has_role(user, 'sysadmin', project)) + + def test_add_user_role_doesnt_infect_project_roles(self): + with user_and_project_generator(self.manager) as (user, project): + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + self.manager.add_role(user, 'sysadmin') + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + + def test_can_list_user_roles(self): + with user_and_project_generator(self.manager) as (user, project): + self.manager.add_role(user, 'sysadmin') + roles = self.manager.get_user_roles(user) + self.assertTrue('sysadmin' in roles) + self.assertFalse('netadmin' in roles) + + def test_can_list_project_roles(self): + with user_and_project_generator(self.manager) as (user, project): + self.manager.add_role(user, 'sysadmin') + self.manager.add_role(user, 'sysadmin', project) + self.manager.add_role(user, 'netadmin', project) + project_roles = self.manager.get_user_roles(user, project) + self.assertTrue('sysadmin' in project_roles) + self.assertTrue('netadmin' in project_roles) + # has role should be false user-level role is missing + self.assertFalse(self.manager.has_role(user, 'netadmin', project)) + + def test_can_remove_user_roles(self): + with user_and_project_generator(self.manager) as (user, project): + self.manager.add_role(user, 'sysadmin') + self.assertTrue(self.manager.has_role(user, 'sysadmin')) + self.manager.remove_role(user, 'sysadmin') + self.assertFalse(self.manager.has_role(user, 'sysadmin')) + + def test_removing_user_role_hides_it_from_project(self): + with user_and_project_generator(self.manager) as (user, project): + self.manager.add_role(user, 'sysadmin') + self.manager.add_role(user, 'sysadmin', project) + self.assertTrue(self.manager.has_role(user, 'sysadmin', project)) + self.manager.remove_role(user, 'sysadmin') + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + + def test_can_remove_project_role_but_keep_user_role(self): + with user_and_project_generator(self.manager) as (user, project): + self.manager.add_role(user, 'sysadmin') + self.manager.add_role(user, 'sysadmin', project) + self.assertTrue(self.manager.has_role(user, 'sysadmin')) + self.manager.remove_role(user, 'sysadmin', project) + self.assertFalse(self.manager.has_role(user, 'sysadmin', project)) + self.assertTrue(self.manager.has_role(user, 'sysadmin')) + + def test_can_retrieve_project_by_user(self): + with user_and_project_generator(self.manager) as (user, project): + self.assertEqual(1, len(self.manager.get_projects('test1'))) + + def test_can_modify_project(self): + with user_and_project_generator(self.manager): + with user_generator(self.manager, name='test2'): + self.manager.modify_project('testproj', 'test2', 'new desc') + project = self.manager.get_project('testproj') + self.assertEqual('test2', project.project_manager_id) + self.assertEqual('new desc', project.description) + + def test_can_delete_project(self): + with user_generator(self.manager): + self.manager.create_project('testproj', 'test1') + self.assert_(self.manager.get_project('testproj')) + self.manager.delete_project('testproj') + projectlist = self.manager.get_projects() + self.assert_(not filter(lambda p: p.name == 'testproj', + projectlist)) + + def test_can_delete_user(self): + self.manager.create_user('test1') + self.assert_(self.manager.get_user('test1')) self.manager.delete_user('test1') - users = self.manager.get_users() - self.assertFalse(filter(lambda u: u.id == 'test1', users)) - self.manager.delete_user('test2') - self.assertEqual(self.manager.get_user('test2'), None) + userlist = self.manager.get_users() + self.assert_(not filter(lambda u: u.id == 'test1', userlist)) + + def test_can_modify_users(self): + with user_generator(self.manager): + self.manager.modify_user('test1', 'access', 'secret', True) + user = self.manager.get_user('test1') + self.assertEqual('access', user.access) + self.assertEqual('secret', user.secret) + self.assertTrue(user.is_admin()) + +class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase): + auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' + +class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase): + auth_driver = 'nova.auth.dbdriver.DbDriver' if __name__ == "__main__": diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py index c36d5a34f..ae7dea1db 100644 --- a/nova/tests/cloud_unittest.py +++ b/nova/tests/cloud_unittest.py @@ -16,31 +16,45 @@ # License for the specific language governing permissions and limitations # under the License. +import json import logging +from M2Crypto import BIO +from M2Crypto import RSA +import os import StringIO +import tempfile import time -from tornado import ioloop + from twisted.internet import defer import unittest from xml.etree import ElementTree +from nova import crypto +from nova import db from nova import flags from nova import rpc from nova import test from nova import utils from nova.auth import manager from nova.compute import power_state -from nova.endpoint import api -from nova.endpoint import cloud +from nova.api.ec2 import context +from nova.api.ec2 import cloud +from nova.objectstore import image FLAGS = flags.FLAGS -class CloudTestCase(test.BaseTestCase): +# Temp dirs for working with image attributes through the cloud controller +# (stole this from objectstore_unittest.py) +OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-') +IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images') +os.makedirs(IMAGES_PATH) + +class CloudTestCase(test.TrialTestCase): def setUp(self): super(CloudTestCase, self).setUp() - self.flags(connection_type='fake') + self.flags(connection_type='fake', images_path=IMAGES_PATH) self.conn = rpc.Connection.instance() logging.getLogger().setLevel(logging.DEBUG) @@ -51,20 +65,24 @@ class CloudTestCase(test.BaseTestCase): # set up a service self.compute = utils.import_class(FLAGS.compute_manager) self.compute_consumer = rpc.AdapterConsumer(connection=self.conn, - topic=FLAGS.compute_topic, - proxy=self.compute) - self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop)) + topic=FLAGS.compute_topic, + proxy=self.compute) + self.compute_consumer.attach_to_twisted() - try: - manager.AuthManager().create_user('admin', 'admin', 'admin') - except: pass - admin = manager.AuthManager().get_user('admin') - project = manager.AuthManager().create_project('proj', 'admin', 'proj') - self.context = api.APIRequestContext(handler=None,project=project,user=admin) + self.manager = manager.AuthManager() + self.user = self.manager.create_user('admin', 'admin', 'admin', True) + self.project = self.manager.create_project('proj', 'admin', 'proj') + self.context = context.APIRequestContext(user=self.user, + project=self.project) def tearDown(self): - manager.AuthManager().delete_project('proj') - manager.AuthManager().delete_user('admin') + self.manager.delete_project(self.project) + self.manager.delete_user(self.user) + super(CloudTestCase, self).tearDown() + + def _create_key(self, name): + # NOTE(vish): create depends on pool, so just call helper directly + return cloud._gen_key(self.context, self.context.user.id, name) def test_console_output(self): if FLAGS.connection_type == 'fake': @@ -77,6 +95,33 @@ class CloudTestCase(test.BaseTestCase): self.assert_(output) rv = yield self.compute.terminate_instance(instance_id) + + def test_key_generation(self): + result = self._create_key('test') + private_key = result['private_key'] + key = RSA.load_key_string(private_key, callback=lambda: None) + bio = BIO.MemoryBuffer() + public_key = db.key_pair_get(self.context, + self.context.user.id, + 'test')['public_key'] + key.save_pub_key_bio(bio) + converted = crypto.ssl_pub_to_ssh_pub(bio.read()) + # assert key fields are equal + self.assertEqual(public_key.split(" ")[1].strip(), + converted.split(" ")[1].strip()) + + def test_describe_key_pairs(self): + self._create_key('test1') + self._create_key('test2') + result = self.cloud.describe_key_pairs(self.context) + keys = result["keypairsSet"] + self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys)) + self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys)) + + def test_delete_key_pair(self): + self._create_key('test') + self.cloud.delete_key_pair(self.context, 'test') + def test_run_instances(self): if FLAGS.connection_type == 'fake': logging.debug("Can't test instances without a real virtual env.") @@ -156,3 +201,67 @@ class CloudTestCase(test.BaseTestCase): #for i in xrange(4): # data = self.cloud.get_metadata(instance(i)['private_dns_name']) # self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i) + + @staticmethod + def _fake_set_image_description(ctxt, image_id, description): + from nova.objectstore import handler + class req: + pass + request = req() + request.context = ctxt + request.args = {'image_id': [image_id], + 'description': [description]} + + resource = handler.ImagesResource() + resource.render_POST(request) + + def test_user_editable_image_endpoint(self): + pathdir = os.path.join(FLAGS.images_path, 'ami-testing') + os.mkdir(pathdir) + info = {'isPublic': False} + with open(os.path.join(pathdir, 'info.json'), 'w') as f: + json.dump(info, f) + img = image.Image('ami-testing') + # self.cloud.set_image_description(self.context, 'ami-testing', + # 'Foo Img') + # NOTE(vish): Above won't work unless we start objectstore or create + # a fake version of api/ec2/images.py conn that can + # call methods directly instead of going through boto. + # for now, just cheat and call the method directly + self._fake_set_image_description(self.context, 'ami-testing', + 'Foo Img') + self.assertEqual('Foo Img', img.metadata['description']) + self._fake_set_image_description(self.context, 'ami-testing', '') + self.assertEqual('', img.metadata['description']) + + def test_update_of_instance_display_fields(self): + inst = db.instance_create({}, {}) + self.cloud.update_instance(self.context, inst['ec2_id'], + display_name='c00l 1m4g3') + inst = db.instance_get({}, inst['id']) + self.assertEqual('c00l 1m4g3', inst['display_name']) + db.instance_destroy({}, inst['id']) + + def test_update_of_instance_wont_update_private_fields(self): + inst = db.instance_create({}, {}) + self.cloud.update_instance(self.context, inst['id'], + mac_address='DE:AD:BE:EF') + inst = db.instance_get({}, inst['id']) + self.assertEqual(None, inst['mac_address']) + db.instance_destroy({}, inst['id']) + + def test_update_of_volume_display_fields(self): + vol = db.volume_create({}, {}) + self.cloud.update_volume(self.context, vol['id'], + display_name='c00l v0lum3') + vol = db.volume_get({}, vol['id']) + self.assertEqual('c00l v0lum3', vol['display_name']) + db.volume_destroy({}, vol['id']) + + def test_update_of_volume_wont_update_private_fields(self): + vol = db.volume_create({}, {}) + self.cloud.update_volume(self.context, vol['id'], + mountpoint='/not/here') + vol = db.volume_get({}, vol['id']) + self.assertEqual(None, vol['mountpoint']) + db.volume_destroy({}, vol['id']) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,7 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py index 8f4754650..4bbef8832 100644 --- a/nova/tests/fake_flags.py +++ b/nova/tests/fake_flags.py @@ -24,7 +24,7 @@ flags.DECLARE('volume_driver', 'nova.volume.manager') FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver' FLAGS.connection_type = 'fake' FLAGS.fake_rabbit = True -FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' +FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver' flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('num_networks', 'nova.network.manager') flags.DECLARE('fake_network', 'nova.network.manager') diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index dc5277f02..59b0a36e4 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -28,7 +28,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager -from nova.endpoint import api +from nova.api.ec2 import context FLAGS = flags.FLAGS @@ -49,19 +49,19 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) - self.context = api.APIRequestContext(None, project=None, user=self.user) + self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i self.projects.append(self.manager.create_project(name, 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + user_context = context.get_admin_context(user=self.user) + + self.network.set_network_host(user_context, self.projects[i].id) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -74,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -81,9 +90,15 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: @@ -109,7 +124,7 @@ class NetworkTestCase(test.TrialTestCase): address = self._create_address(0) self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) lease_ip(address) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) # Doesn't go away until it's dhcp released self.assertTrue(is_allocated_in_project(address, self.projects[0].id)) @@ -130,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -148,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -181,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -242,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index dece4b5d5..5a599ff3a 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -53,7 +53,7 @@ os.makedirs(os.path.join(OSS_TEMPDIR, 'images')) os.makedirs(os.path.join(OSS_TEMPDIR, 'buckets')) -class ObjectStoreTestCase(test.BaseTestCase): +class ObjectStoreTestCase(test.TrialTestCase): """Test objectstore API directly.""" def setUp(self): # pylint: disable-msg=C0103 @@ -164,6 +164,12 @@ class ObjectStoreTestCase(test.BaseTestCase): self.context.project = self.auth_manager.get_project('proj2') self.assertFalse(my_img.is_authorized(self.context)) + # change user-editable fields + my_img.update_user_editable_fields({'display_name': 'my cool image'}) + self.assertEqual('my cool image', my_img.metadata['displayName']) + my_img.update_user_editable_fields({'display_name': ''}) + self.assert_(not my_img.metadata['displayName']) + class TestHTTPChannel(http.HTTPChannel): """Dummy site required for twisted.web""" diff --git a/nova/tests/quota_unittest.py b/nova/tests/quota_unittest.py index cab9f663d..370ccd506 100644 --- a/nova/tests/quota_unittest.py +++ b/nova/tests/quota_unittest.py @@ -25,8 +25,8 @@ from nova import quota from nova import test from nova import utils from nova.auth import manager -from nova.endpoint import cloud -from nova.endpoint import api +from nova.api.ec2 import cloud +from nova.api.ec2 import context FLAGS = flags.FLAGS @@ -48,9 +48,8 @@ class QuotaTestCase(test.TrialTestCase): self.user = self.manager.create_user('admin', 'admin', 'admin', True) self.project = self.manager.create_project('admin', 'admin', 'admin') self.network = utils.import_object(FLAGS.network_manager) - self.context = api.APIRequestContext(handler=None, - project=self.project, - user=self.user) + self.context = context.APIRequestContext(project=self.project, + user=self.user) def tearDown(self): # pylint: disable-msg=C0103 manager.AuthManager().delete_project(self.project) @@ -95,11 +94,11 @@ class QuotaTestCase(test.TrialTestCase): for i in range(FLAGS.quota_instances): instance_id = self._create_instance() instance_ids.append(instance_id) - self.assertFailure(self.cloud.run_instances(self.context, - min_count=1, - max_count=1, - instance_type='m1.small'), - cloud.QuotaError) + self.assertRaises(cloud.QuotaError, self.cloud.run_instances, + self.context, + min_count=1, + max_count=1, + instance_type='m1.small') for instance_id in instance_ids: db.instance_destroy(self.context, instance_id) @@ -107,11 +106,11 @@ class QuotaTestCase(test.TrialTestCase): instance_ids = [] instance_id = self._create_instance(cores=4) instance_ids.append(instance_id) - self.assertFailure(self.cloud.run_instances(self.context, - min_count=1, - max_count=1, - instance_type='m1.small'), - cloud.QuotaError) + self.assertRaises(cloud.QuotaError, self.cloud.run_instances, + self.context, + min_count=1, + max_count=1, + instance_type='m1.small') for instance_id in instance_ids: db.instance_destroy(self.context, instance_id) @@ -120,10 +119,9 @@ class QuotaTestCase(test.TrialTestCase): for i in range(FLAGS.quota_volumes): volume_id = self._create_volume() volume_ids.append(volume_id) - self.assertRaises(cloud.QuotaError, - self.cloud.create_volume, - self.context, - size=10) + self.assertRaises(cloud.QuotaError, self.cloud.create_volume, + self.context, + size=10) for volume_id in volume_ids: db.volume_destroy(self.context, volume_id) @@ -151,5 +149,4 @@ class QuotaTestCase(test.TrialTestCase): # make an rpc.call, the test just finishes with OK. It # appears to be something in the magic inline callbacks # that is breaking. - self.assertFailure(self.cloud.allocate_address(self.context), - cloud.QuotaError) + self.assertRaises(cloud.QuotaError, self.cloud.allocate_address, self.context) diff --git a/nova/tests/rpc_unittest.py b/nova/tests/rpc_unittest.py index e12a28fbc..9652841f2 100644 --- a/nova/tests/rpc_unittest.py +++ b/nova/tests/rpc_unittest.py @@ -30,7 +30,7 @@ from nova import test FLAGS = flags.FLAGS -class RpcTestCase(test.BaseTestCase): +class RpcTestCase(test.TrialTestCase): """Test cases for rpc""" def setUp(self): # pylint: disable-msg=C0103 super(RpcTestCase, self).setUp() @@ -39,14 +39,13 @@ class RpcTestCase(test.BaseTestCase): self.consumer = rpc.AdapterConsumer(connection=self.conn, topic='test', proxy=self.receiver) - - self.injected.append(self.consumer.attach_to_tornado(self.ioloop)) + self.consumer.attach_to_twisted() def test_call_succeed(self): """Get a value through rpc call""" value = 42 - result = yield rpc.call('test', {"method": "echo", - "args": {"value": value}}) + result = yield rpc.call_twisted('test', {"method": "echo", + "args": {"value": value}}) self.assertEqual(value, result) def test_call_exception(self): @@ -57,12 +56,12 @@ class RpcTestCase(test.BaseTestCase): to an int in the test. """ value = 42 - self.assertFailure(rpc.call('test', {"method": "fail", - "args": {"value": value}}), + self.assertFailure(rpc.call_twisted('test', {"method": "fail", + "args": {"value": value}}), rpc.RemoteError) try: - yield rpc.call('test', {"method": "fail", - "args": {"value": value}}) + yield rpc.call_twisted('test', {"method": "fail", + "args": {"value": value}}) self.fail("should have thrown rpc.RemoteError") except rpc.RemoteError as exc: self.assertEqual(int(exc.value), value) diff --git a/nova/tests/scheduler_unittest.py b/nova/tests/scheduler_unittest.py index fde30f81e..53a8be144 100644 --- a/nova/tests/scheduler_unittest.py +++ b/nova/tests/scheduler_unittest.py @@ -117,10 +117,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() hosts = self.scheduler.driver.hosts_up(self.context, 'compute') self.assertEqual(len(hosts), 2) compute1.kill() @@ -132,10 +134,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() instance_id1 = self._create_instance() compute1.run_instance(self.context, instance_id1) instance_id2 = self._create_instance() @@ -153,10 +157,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() instance_ids1 = [] instance_ids2 = [] for index in xrange(FLAGS.max_cores): @@ -184,10 +190,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-volume', 'volume', FLAGS.volume_manager) + volume1.startService() volume2 = service.Service('host2', 'nova-volume', 'volume', FLAGS.volume_manager) + volume2.startService() volume_id1 = self._create_volume() volume1.create_volume(self.context, volume_id1) volume_id2 = self._create_volume() @@ -205,10 +213,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-volume', 'volume', FLAGS.volume_manager) + volume1.startService() volume2 = service.Service('host2', 'nova-volume', 'volume', FLAGS.volume_manager) + volume2.startService() volume_ids1 = [] volume_ids2 = [] for index in xrange(FLAGS.max_gigabytes): diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py index 01da0eb8a..6afeec377 100644 --- a/nova/tests/service_unittest.py +++ b/nova/tests/service_unittest.py @@ -22,6 +22,8 @@ Unit Tests for remote procedure calls using queue import mox +from twisted.application.app import startApplication + from nova import exception from nova import flags from nova import rpc @@ -65,15 +67,20 @@ class ServiceTestCase(test.BaseTestCase): proxy=mox.IsA(service.Service)).AndReturn( rpc.AdapterConsumer) + rpc.AdapterConsumer.attach_to_twisted() + rpc.AdapterConsumer.attach_to_twisted() + # Stub out looping call a bit needlessly since we don't have an easy # way to cancel it (yet) when the tests finishes service.task.LoopingCall(mox.IgnoreArg()).AndReturn( service.task.LoopingCall) service.task.LoopingCall.start(interval=mox.IgnoreArg(), now=mox.IgnoreArg()) + service.task.LoopingCall(mox.IgnoreArg()).AndReturn( + service.task.LoopingCall) + service.task.LoopingCall.start(interval=mox.IgnoreArg(), + now=mox.IgnoreArg()) - rpc.AdapterConsumer.attach_to_twisted() - rpc.AdapterConsumer.attach_to_twisted() service_create = {'host': host, 'binary': binary, 'topic': topic, @@ -91,6 +98,7 @@ class ServiceTestCase(test.BaseTestCase): self.mox.ReplayAll() app = service.Service.create(host=host, binary=binary) + startApplication(app, False) self.assert_(app) # We're testing sort of weird behavior in how report_state decides diff --git a/nova/utils.py b/nova/utils.py index 8939043e6..d18dd9843 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -39,17 +39,6 @@ from nova.exception import ProcessExecutionError FLAGS = flags.FLAGS TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" -class ProcessExecutionError(IOError): - def __init__( self, stdout=None, stderr=None, exit_code=None, cmd=None, - description=None): - if description is None: - description = "Unexpected error while running command." - if exit_code is None: - exit_code = '-' - message = "%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r" % ( - description, cmd, exit_code, stdout, stderr) - IOError.__init__(self, message) - def import_class(import_str): """Returns a class from a string including module and class""" mod_str, _sep, class_str = import_str.rpartition('.') diff --git a/nova/virt/xenapi.py b/nova/virt/xenapi.py index 5fdd2b9fc..04e830b64 100644 --- a/nova/virt/xenapi.py +++ b/nova/virt/xenapi.py @@ -42,10 +42,12 @@ from twisted.internet import defer from twisted.internet import reactor from twisted.internet import task +from nova import db from nova import flags from nova import process from nova import utils from nova.auth.manager import AuthManager +from nova.compute import instance_types from nova.compute import power_state from nova.virt import images @@ -103,8 +105,8 @@ class XenAPIConnection(object): self._conn.login_with_password(user, pw) def list_instances(self): - result = [self._conn.xenapi.VM.get_name_label(vm) \ - for vm in self._conn.xenapi.VM.get_all()] + return [self._conn.xenapi.VM.get_name_label(vm) \ + for vm in self._conn.xenapi.VM.get_all()] @defer.inlineCallbacks def spawn(self, instance): @@ -113,32 +115,24 @@ class XenAPIConnection(object): raise Exception('Attempted to create non-unique name %s' % instance.name) - if 'bridge_name' in instance.datamodel: - network_ref = \ - yield self._find_network_with_bridge( - instance.datamodel['bridge_name']) - else: - network_ref = None - - if 'mac_address' in instance.datamodel: - mac_address = instance.datamodel['mac_address'] - else: - mac_address = '' + network = db.project_get_network(None, instance.project_id) + network_ref = \ + yield self._find_network_with_bridge(network.bridge) - user = AuthManager().get_user(instance.datamodel['user_id']) - project = AuthManager().get_project(instance.datamodel['project_id']) + user = AuthManager().get_user(instance.user_id) + project = AuthManager().get_project(instance.project_id) vdi_uuid = yield self._fetch_image( - instance.datamodel['image_id'], user, project, True) + instance.image_id, user, project, True) kernel = yield self._fetch_image( - instance.datamodel['kernel_id'], user, project, False) + instance.kernel_id, user, project, False) ramdisk = yield self._fetch_image( - instance.datamodel['ramdisk_id'], user, project, False) + instance.ramdisk_id, user, project, False) vdi_ref = yield self._call_xenapi('VDI.get_by_uuid', vdi_uuid) vm_ref = yield self._create_vm(instance, kernel, ramdisk) yield self._create_vbd(vm_ref, vdi_ref, 0, True) if network_ref: - yield self._create_vif(vm_ref, network_ref, mac_address) + yield self._create_vif(vm_ref, network_ref, instance.mac_address) logging.debug('Starting VM %s...', vm_ref) yield self._call_xenapi('VM.start', vm_ref, False, False) logging.info('Spawning VM %s created %s.', instance.name, vm_ref) @@ -148,8 +142,9 @@ class XenAPIConnection(object): """Create a VM record. Returns a Deferred that gives the new VM reference.""" - mem = str(long(instance.datamodel['memory_kb']) * 1024) - vcpus = str(instance.datamodel['vcpus']) + instance_type = instance_types.INSTANCE_TYPES[instance.instance_type] + mem = str(long(instance_type['memory_mb']) * 1024 * 1024) + vcpus = str(instance_type['vcpus']) rec = { 'name_label': instance.name, 'name_description': '', diff --git a/nova/volume/manager.py b/nova/volume/manager.py index 034763512..8508f27b2 100644 --- a/nova/volume/manager.py +++ b/nova/volume/manager.py @@ -77,7 +77,7 @@ class AOEManager(manager.Manager): size = volume_ref['size'] logging.debug("volume %s: creating lv of size %sG", volume_id, size) - yield self.driver.create_volume(volume_ref['str_id'], size) + yield self.driver.create_volume(volume_ref['ec2_id'], size) logging.debug("volume %s: allocating shelf & blade", volume_id) self._ensure_blades(context) @@ -87,7 +87,7 @@ class AOEManager(manager.Manager): logging.debug("volume %s: exporting shelf %s & blade %s", volume_id, shelf_id, blade_id) - yield self.driver.create_export(volume_ref['str_id'], + yield self.driver.create_export(volume_ref['ec2_id'], shelf_id, blade_id) @@ -111,10 +111,10 @@ class AOEManager(manager.Manager): raise exception.Error("Volume is not local to this node") shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context, volume_id) - yield self.driver.remove_export(volume_ref['str_id'], + yield self.driver.remove_export(volume_ref['ec2_id'], shelf_id, blade_id) - yield self.driver.delete_volume(volume_ref['str_id']) + yield self.driver.delete_volume(volume_ref['ec2_id']) self.db.volume_destroy(context, volume_id) defer.returnValue(True) @@ -125,7 +125,7 @@ class AOEManager(manager.Manager): Returns path to device. """ volume_ref = self.db.volume_get(context, volume_id) - yield self.driver.discover_volume(volume_ref['str_id']) + yield self.driver.discover_volume(volume_ref['ec2_id']) shelf_id, blade_id = self.db.volume_get_shelf_and_blade(context, volume_id) defer.returnValue("/dev/etherd/e%s.%s" % (shelf_id, blade_id)) diff --git a/nova/wsgi.py b/nova/wsgi.py index 8a4e2a9f4..b91d91121 100644 --- a/nova/wsgi.py +++ b/nova/wsgi.py @@ -21,14 +21,17 @@ Utility methods for working with WSGI servers """ +import json import logging import sys +from xml.dom import minidom import eventlet import eventlet.wsgi eventlet.patcher.monkey_patch(all=False, socket=True) import routes import routes.middleware +import webob import webob.dec import webob.exc @@ -227,10 +230,19 @@ class Controller(object): serializer = Serializer(request.environ, _metadata) return serializer.to_content_type(data) + def _deserialize(self, data, request): + """ + Deserialize the request body to the response type requested in request. + Uses self._serialization_metadata if it exists, which is a dict mapping + MIME types to information needed to serialize to that type. + """ + _metadata = getattr(type(self), "_serialization_metadata", {}) + serializer = Serializer(request.environ, _metadata) + return serializer.deserialize(data) class Serializer(object): """ - Serializes a dictionary to a Content Type specified by a WSGI environment. + Serializes and deserializes dictionaries to certain MIME types. """ def __init__(self, environ, metadata=None): @@ -239,31 +251,77 @@ class Serializer(object): 'metadata' is an optional dict mapping MIME types to information needed to serialize a dictionary to that type. """ - self.environ = environ self.metadata = metadata or {} - self._methods = { - 'application/json': self._to_json, - 'application/xml': self._to_xml} + req = webob.Request(environ) + suffix = req.path_info.split('.')[-1].lower() + if suffix == 'json': + self.handler = self._to_json + elif suffix == 'xml': + self.handler = self._to_xml + elif 'application/json' in req.accept: + self.handler = self._to_json + elif 'application/xml' in req.accept: + self.handler = self._to_xml + else: + self.handler = self._to_json # default def to_content_type(self, data): """ - Serialize a dictionary into a string. The format of the string - will be decided based on the Content Type requested in self.environ: - by Accept: header, or by URL suffix. + Serialize a dictionary into a string. + + The format of the string will be decided based on the Content Type + requested in self.environ: by Accept: header, or by URL suffix. """ - mimetype = 'application/xml' - # TODO(gundlach): determine mimetype from request - return self._methods.get(mimetype, repr)(data) + return self.handler(data) + + def deserialize(self, datastring): + """ + Deserialize a string to a dictionary. + + The string must be in the format of a supported MIME type. + """ + datastring = datastring.strip() + try: + is_xml = (datastring[0] == '<') + if not is_xml: + return json.loads(datastring) + return self._from_xml(datastring) + except: + return None + + def _from_xml(self, datastring): + xmldata = self.metadata.get('application/xml', {}) + plurals = set(xmldata.get('plurals', {})) + node = minidom.parseString(datastring).childNodes[0] + return {node.nodeName: self._from_xml_node(node, plurals)} + + def _from_xml_node(self, node, listnames): + """ + Convert a minidom node to a simple Python type. + + listnames is a collection of names of XML nodes whose subnodes should + be considered list items. + """ + if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3: + return node.childNodes[0].nodeValue + elif node.nodeName in listnames: + return [self._from_xml_node(n, listnames) for n in node.childNodes] + else: + result = dict() + for attr in node.attributes.keys(): + result[attr] = node.attributes[attr].nodeValue + for child in node.childNodes: + if child.nodeType != node.TEXT_NODE: + result[child.nodeName] = self._from_xml_node(child, listnames) + return result def _to_json(self, data): - import json return json.dumps(data) def _to_xml(self, data): metadata = self.metadata.get('application/xml', {}) # We expect data to contain a single key which is the XML root. root_key = data.keys()[0] - from xml.dom import minidom doc = minidom.Document() node = self._to_xml_node(doc, metadata, root_key, data[root_key]) return node.toprettyxml(indent=' ') @@ -1,7 +1,8 @@ [Messages Control] # W0511: TODOs in code comments are fine. # W0142: *args and **kwargs are fine. -disable-msg=W0511,W0142 +# W0622: Redefining id is fine. +disable-msg=W0511,W0142,W0622 [Basic] # Variable names can be 1 to 31 characters long, with lowercase and underscores @@ -54,5 +54,5 @@ setup(name='nova', 'bin/nova-manage', 'bin/nova-network', 'bin/nova-objectstore', - 'bin/nova-api-new', + 'bin/nova-scheduler', 'bin/nova-volume']) diff --git a/tools/install_venv.py b/tools/install_venv.py index 5d2369a96..32c372352 100644 --- a/tools/install_venv.py +++ b/tools/install_venv.py @@ -88,6 +88,10 @@ def create_virtualenv(venv=VENV): def install_dependencies(venv=VENV): print 'Installing dependencies with pip (this can take a while)...' + # Install greenlet by hand - just listing it in the requires file does not + # get it in stalled in the right order + run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, 'greenlet'], + redirect_output=False) run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, '-r', PIP_REQUIRES], redirect_output=False) run_command(['tools/with_venv.sh', 'pip', 'install', '-E', venv, TWISTED_NOVA], diff --git a/tools/pip-requires b/tools/pip-requires index dd69708ce..1e2707be7 100644 --- a/tools/pip-requires +++ b/tools/pip-requires @@ -7,15 +7,16 @@ amqplib==0.6.1 anyjson==0.2.4 boto==2.0b1 carrot==0.10.5 -eventlet==0.9.10 +eventlet==0.9.12 lockfile==0.8 python-daemon==1.5.5 python-gflags==1.3 redis==2.0.0 routes==1.12.3 tornado==1.0 -webob==0.9.8 +WebOb==0.9.8 wsgiref==0.1.2 zope.interface==3.6.1 mox==0.5.0 -f http://pymox.googlecode.com/files/mox-0.5.0.tar.gz +greenlet==0.3.1 diff --git a/tools/setup_iptables.sh b/tools/setup_iptables.sh new file mode 100755 index 000000000..673353eb4 --- /dev/null +++ b/tools/setup_iptables.sh @@ -0,0 +1,158 @@ +#!/usr/bin/env bash +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# NOTE(vish): This script sets up some reasonable defaults for iptables and +# creates nova-specific chains. If you use this script you should +# run nova-network and nova-compute with --use_nova_chains=True + +# NOTE(vish): If you run nova-api on a different port, make sure to change +# the port here +API_PORT=${API_PORT:-"8773"} +if [ -n "$1" ]; then + CMD=$1 +else + CMD="all" +fi + +if [ -n "$2" ]; then + IP=$2 +else + # NOTE(vish): This will just get the first ip in the list, so if you + # have more than one eth device set up, this will fail, and + # you should explicitly pass in the ip of the instance + IP=`ifconfig | grep -m 1 'inet addr:'| cut -d: -f2 | awk '{print $1}'` +fi + +if [ -n "$3" ]; then + PRIVATE_RANGE=$3 +else + PRIVATE_RANGE="10.0.0.0/12" +fi + + +if [ -n "$4" ]; then + # NOTE(vish): Management IP is the ip over which to allow ssh traffic. It + # will also allow traffic to nova-api + MGMT_IP=$4 +else + MGMT_IP="$IP" +fi +if [ "$CMD" == "clear" ]; then + iptables -P INPUT ACCEPT + iptables -P FORWARD ACCEPT + iptables -P OUTPUT ACCEPT + iptables -F + iptables -t nat -F + iptables -F nova_input + iptables -F nova_output + iptables -F nova_forward + iptables -t nat -F nova_input + iptables -t nat -F nova_output + iptables -t nat -F nova_forward + iptables -t nat -X + iptables -X +fi + +if [ "$CMD" == "base" ] || [ "$CMD" == "all" ]; then + iptables -P INPUT DROP + iptables -A INPUT -m state --state INVALID -j DROP + iptables -A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT + iptables -A INPUT -m tcp -p tcp -d $MGMT_IP --dport 22 -j ACCEPT + iptables -A INPUT -m udp -p udp --dport 123 -j ACCEPT + iptables -N nova_input + iptables -A INPUT -j nova_input + iptables -A INPUT -p icmp -j ACCEPT + iptables -A INPUT -p tcp -j REJECT --reject-with tcp-reset + iptables -A INPUT -j REJECT --reject-with icmp-port-unreachable + + iptables -P FORWARD DROP + iptables -A FORWARD -m state --state INVALID -j DROP + iptables -A FORWARD -m state --state RELATED,ESTABLISHED -j ACCEPT + iptables -A FORWARD -p tcp -m tcp --tcp-flags SYN,RST SYN -j TCPMSS --clamp-mss-to-pmtu + iptables -N nova_forward + iptables -A FORWARD -j nova_forward + + # NOTE(vish): DROP on output is too restrictive for now. We need to add + # in a bunch of more specific output rules to use it. + # iptables -P OUTPUT DROP + iptables -A OUTPUT -m state --state INVALID -j DROP + iptables -A OUTPUT -m state --state RELATED,ESTABLISHED -j ACCEPT + iptables -N nova_output + iptables -A OUTPUT -j nova_output + + iptables -t nat -N nova_prerouting + iptables -t nat -A PREROUTING -j nova_prerouting + + iptables -t nat -N nova_postrouting + iptables -t nat -A POSTROUTING -j nova_postrouting + + iptables -t nat -N nova_output + iptables -t nat -A OUTPUT -j nova_output +fi + +if [ "$CMD" == "ganglia" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 8649 -j ACCEPT + iptables -A nova_input -m udp -p udp -d $IP --dport 8649 -j ACCEPT +fi + +if [ "$CMD" == "web" ] || [ "$CMD" == "all" ]; then + # NOTE(vish): This opens up ports for web access, allowing web-based + # dashboards to work. + iptables -A nova_input -m tcp -p tcp -d $IP --dport 80 -j ACCEPT + iptables -A nova_input -m tcp -p tcp -d $IP --dport 443 -j ACCEPT +fi + +if [ "$CMD" == "objectstore" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 3333 -j ACCEPT +fi + +if [ "$CMD" == "api" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport $API_PORT -j ACCEPT + if [ "$IP" != "$MGMT_IP" ]; then + iptables -A nova_input -m tcp -p tcp -d $MGMT_IP --dport $API_PORT -j ACCEPT + fi +fi + +if [ "$CMD" == "redis" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 6379 -j ACCEPT +fi + +if [ "$CMD" == "mysql" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 3306 -j ACCEPT +fi + +if [ "$CMD" == "rabbitmq" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 4369 -j ACCEPT + iptables -A nova_input -m tcp -p tcp -d $IP --dport 5672 -j ACCEPT + iptables -A nova_input -m tcp -p tcp -d $IP --dport 53284 -j ACCEPT +fi + +if [ "$CMD" == "dnsmasq" ] || [ "$CMD" == "all" ]; then + # NOTE(vish): this could theoretically be setup per network + # for each host, but it seems like overkill + iptables -A nova_input -m tcp -p tcp -s $PRIVATE_RANGE --dport 53 -j ACCEPT + iptables -A nova_input -m udp -p udp -s $PRIVATE_RANGE --dport 53 -j ACCEPT + iptables -A nova_input -m udp -p udp --dport 67 -j ACCEPT +fi + +if [ "$CMD" == "ldap" ] || [ "$CMD" == "all" ]; then + iptables -A nova_input -m tcp -p tcp -d $IP --dport 389 -j ACCEPT +fi + + |
