From f21d8510bb3f55b2b76aab251b0427dbfa69c5d9 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 7 Sep 2010 14:34:27 +0200 Subject: Add a clean-traffic filterref to the libvirt templates to prevent spoofing and snooping attacks from the guests. --- nova/virt/libvirt.qemu.xml.template | 3 +++ nova/virt/libvirt.uml.xml.template | 3 +++ 2 files changed, 6 insertions(+) diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index 307f9d03a..3de1e5009 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -20,6 +20,9 @@ + + + diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index 6f4290f98..e64b172d8 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -14,6 +14,9 @@ + + + -- cgit From 62dad8422532af4257769bbb0e68120b3393739a Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 7 Sep 2010 14:52:38 +0200 Subject: Add stubbed out handler for AuthorizeSecurityGroupIngress EC2 API call. --- nova/endpoint/cloud.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 8e2beb1e3..89402896c 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -240,6 +240,10 @@ class CloudController(object): # Stubbed for now to unblock other things. return groups + @rbac.allow('netadmin') + def authorize_security_group_ingress(self, context, group_name, **kwargs): + return True + @rbac.allow('netadmin') def create_security_group(self, context, group_name, **kwargs): return True -- cgit From bd07d6b3b3e9ed3ef3e65e99b628c8b1aaf2f82c Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Thu, 9 Sep 2010 12:35:46 +0200 Subject: Alright, first hole poked all the way through. We can now create security groups and read them back. --- nova/auth/manager.py | 6 +++++ nova/db/api.py | 22 ++++++++++++++++++ nova/db/sqlalchemy/api.py | 38 +++++++++++++++++++++++++++++++ nova/db/sqlalchemy/models.py | 54 +++++++++++++++++++++++++++++++++++++++++++- nova/endpoint/cloud.py | 14 ++++++++---- nova/tests/api_unittest.py | 34 +++++++++++++++++++++++++--- 6 files changed, 160 insertions(+), 8 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index d5fbec7c5..6aa5721c8 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -640,11 +640,17 @@ class AuthManager(object): with self.driver() as drv: user_dict = drv.create_user(name, access, secret, admin) if user_dict: + db.security_group_create(context={}, + values={ 'name' : 'default', + 'description' : 'default', + 'user_id' : name }) return User(**user_dict) def delete_user(self, user): """Deletes a user""" with self.driver() as drv: + for security_group in db.security_group_get_by_user(context = {}, user_id=user.id): + db.security_group_destroy({}, security_group.id) drv.delete_user(User.safe_id(user)) def generate_key_pair(self, user, key_name): diff --git a/nova/db/api.py b/nova/db/api.py index b49707392..b67e3afe0 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -442,3 +442,25 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + +#################### + + +def security_group_create(context, values): + """Create a new security group""" + return IMPL.security_group_create(context, values) + + +def security_group_get_by_instance(context, instance_id): + """Get security groups to which the instance is assigned""" + return IMPL.security_group_get_by_instance(context, instance_id) + + +def security_group_get_by_user(context, user_id): + """Get security groups owned by the given user""" + return IMPL.security_group_get_by_user(context, user_id) + + +def security_group_destroy(context, security_group_id): + """Deletes a security group""" + return IMPL.security_group_destroy(context, security_group_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 5172b87b3..d790d3fac 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -581,3 +581,41 @@ def volume_update(context, volume_id, values): for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save() + + +################### + + +def security_group_create(_context, values): + security_group_ref = models.SecurityGroup() + for (key, value) in values.iteritems(): + security_group_ref[key] = value + security_group_ref.save() + return security_group_ref + + +def security_group_get_by_instance(_context, instance_id): + with managed_session() as session: + return session.query(models.Instance) \ + .get(instance_id) \ + .security_groups \ + .all() + + +def security_group_get_by_user(_context, user_id): + with managed_session() as session: + return session.query(models.SecurityGroup) \ + .filter_by(user_id=user_id) \ + .filter_by(deleted=False) \ + .all() + +def security_group_destroy(_context, security_group_id): + with managed_session() as session: + security_group = session.query(models.SecurityGroup) \ + .get(security_group_id) + security_group.delete(session=session) + +def security_group_get_all(_context): + return models.SecurityGroup.all() + + diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 310d4640e..28c25bfbc 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -26,7 +26,7 @@ import datetime # TODO(vish): clean up these imports from sqlalchemy.orm import relationship, backref, validates, exc from sqlalchemy.sql import func -from sqlalchemy import Column, Integer, String +from sqlalchemy import Column, Integer, String, Table from sqlalchemy import ForeignKey, DateTime, Boolean, Text from sqlalchemy.ext.declarative import declarative_base @@ -292,6 +292,58 @@ class ExportDevice(BASE, NovaBase): uselist=False)) +security_group_instance_association = Table('security_group_instance_association', + BASE.metadata, + Column('security_group_id', Integer, + ForeignKey('security_group.id')), + Column('instance_id', Integer, + ForeignKey('instances.id'))) + +class SecurityGroup(BASE, NovaBase): + """Represents a security group""" + __tablename__ = 'security_group' + id = Column(Integer, primary_key=True) + + name = Column(String(255)) + description = Column(String(255)) + + user_id = Column(String(255)) + project_id = Column(String(255)) + + instances = relationship(Instance, + secondary=security_group_instance_association, + backref='security_groups') + + @property + def user(self): + return auth.manager.AuthManager().get_user(self.user_id) + + @property + def project(self): + return auth.manager.AuthManager().get_project(self.project_id) + + +class SecurityGroupIngressRule(BASE, NovaBase): + """Represents a rule in a security group""" + __tablename__ = 'security_group_rules' + id = Column(Integer, primary_key=True) + + parent_security_group = Column(Integer, ForeignKey('security_group.id')) + protocol = Column(String(5)) # "tcp", "udp", or "icmp" + fromport = Column(Integer) + toport = Column(Integer) + + # Note: This is not the parent SecurityGroup's owner. It's the owner of + # the SecurityGroup we're granting access. + user_id = Column(String(255)) + group_id = Column(Integer, ForeignKey('security_group.id')) + + @property + def user(self): + return auth.manager.AuthManager().get_user(self.user_id) + + cidr = Column(String(255)) + class Network(BASE, NovaBase): """Represents a network""" __tablename__ = 'networks' diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 44997be59..7df8bd081 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -212,10 +212,12 @@ class CloudController(object): return True @rbac.allow('all') - def describe_security_groups(self, context, group_names, **kwargs): - groups = {'securityGroupSet': []} + def describe_security_groups(self, context, **kwargs): + groups = {'securityGroupSet': + [{ 'groupDescription': group.description, + 'groupName' : group.name, + 'ownerId': context.user.id } for group in db.security_group_get_by_user(context, context.user.id) ] } - # Stubbed for now to unblock other things. return groups @rbac.allow('netadmin') @@ -223,7 +225,11 @@ class CloudController(object): return True @rbac.allow('netadmin') - def create_security_group(self, context, group_name, **kwargs): + def create_security_group(self, context, group_name, group_description): + db.security_group_create(context, + values = { 'user_id' : context.user.id, + 'name': group_name, + 'description': group_description }) return True @rbac.allow('netadmin') diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 462d1b295..87d99607d 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -185,6 +185,9 @@ class ApiEc2TestCase(test.BaseTestCase): self.host = '127.0.0.1' self.app = api.APIServerApplication({'Cloud': self.cloud}) + + def expect_http(self, host=None, is_secure=False): + """Returns a new EC2 connection""" self.ec2 = boto.connect_ec2( aws_access_key_id='fake', aws_secret_access_key='fake', @@ -194,9 +197,6 @@ class ApiEc2TestCase(test.BaseTestCase): path='/services/Cloud') self.mox.StubOutWithMock(self.ec2, 'new_http_connection') - - def expect_http(self, host=None, is_secure=False): - """Returns a new EC2 connection""" http = FakeHttplibConnection( self.app, '%s:%d' % (self.host, FLAGS.cc_port), False) # pylint: disable-msg=E1103 @@ -231,3 +231,31 @@ class ApiEc2TestCase(test.BaseTestCase): self.assertEquals(len(results), 1) self.manager.delete_project(project) self.manager.delete_user(user) + + def test_get_all_security_groups(self): + """Test that operations on security groups stick""" + self.expect_http() + self.mox.ReplayAll() + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + rv = self.ec2.get_all_security_groups() + self.assertEquals(len(rv), 1) + self.assertEquals(rv[0].name, 'default') + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + self.assertEquals(len(rv), 2) + self.assertTrue(security_group_name in [group.name for group in rv]) + + self.manager.delete_project(project) + self.manager.delete_user(user) -- cgit From bd7ac72b9774a181e51dde5dff09ed4c47b556a7 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Thu, 9 Sep 2010 15:13:04 +0200 Subject: AuthorizeSecurityGroupIngress now works. --- nova/db/api.py | 13 +++++++ nova/db/sqlalchemy/api.py | 19 ++++++++++ nova/db/sqlalchemy/models.py | 9 +++-- nova/endpoint/cloud.py | 50 +++++++++++++++++++++++--- nova/tests/api_unittest.py | 83 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 161 insertions(+), 13 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index b67e3afe0..af574d6de 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -461,6 +461,19 @@ def security_group_get_by_user(context, user_id): return IMPL.security_group_get_by_user(context, user_id) +def security_group_get_by_user_and_name(context, user_id, name): + """Get user's named security group""" + return IMPL.security_group_get_by_user_and_name(context, user_id, name) + + def security_group_destroy(context, security_group_id): """Deletes a security group""" return IMPL.security_group_destroy(context, security_group_id) + + +#################### + + +def security_group_rule_create(context, values): + """Create a new security group""" + return IMPL.security_group_rule_create(context, values) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d790d3fac..c8d852f9d 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -609,6 +609,14 @@ def security_group_get_by_user(_context, user_id): .filter_by(deleted=False) \ .all() +def security_group_get_by_user_and_name(_context, user_id, name): + with managed_session() as session: + return session.query(models.SecurityGroup) \ + .filter_by(user_id=user_id) \ + .filter_by(name=name) \ + .filter_by(deleted=False) \ + .one() + def security_group_destroy(_context, security_group_id): with managed_session() as session: security_group = session.query(models.SecurityGroup) \ @@ -619,3 +627,14 @@ def security_group_get_all(_context): return models.SecurityGroup.all() + + +################### + + +def security_group_rule_create(_context, values): + security_group_rule_ref = models.SecurityGroupIngressRule() + for (key, value) in values.iteritems(): + security_group_rule_ref[key] = value + security_group_rule_ref.save() + return security_group_rule_ref diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 28c25bfbc..330262a88 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -330,12 +330,11 @@ class SecurityGroupIngressRule(BASE, NovaBase): parent_security_group = Column(Integer, ForeignKey('security_group.id')) protocol = Column(String(5)) # "tcp", "udp", or "icmp" - fromport = Column(Integer) - toport = Column(Integer) + from_port = Column(Integer) + to_port = Column(Integer) - # Note: This is not the parent SecurityGroup's owner. It's the owner of - # the SecurityGroup we're granting access. - user_id = Column(String(255)) + # Note: This is not the parent SecurityGroup. It's SecurityGroup we're + # granting access for. group_id = Column(Integer, ForeignKey('security_group.id')) @property diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 7df8bd081..0a929b865 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -214,14 +214,54 @@ class CloudController(object): @rbac.allow('all') def describe_security_groups(self, context, **kwargs): groups = {'securityGroupSet': - [{ 'groupDescription': group.description, - 'groupName' : group.name, - 'ownerId': context.user.id } for group in db.security_group_get_by_user(context, context.user.id) ] } + [{ 'groupDescription': group.description, + 'groupName' : group.name, + 'ownerId': context.user.id } for group in \ + db.security_group_get_by_user(context, + context.user.id) ] } return groups @rbac.allow('netadmin') - def authorize_security_group_ingress(self, context, group_name, **kwargs): + def authorize_security_group_ingress(self, context, group_name, + to_port=None, from_port=None, + ip_protocol=None, cidr_ip=None, + user_id=None, + source_security_group_name=None, + source_security_group_owner_id=None): + security_group = db.security_group_get_by_user_and_name(context, + context.user.id, + group_name) + values = { 'parent_security_group' : security_group.id } + + # Aw, crap. + if source_security_group_name: + if source_security_group_owner_id: + other_user_id = source_security_group_owner_id + else: + other_user_id = context.user.id + + foreign_security_group = \ + db.security_group_get_by_user_and_name(context, + other_user_id, + source_security_group_name) + values['group_id'] = foreign_security_group.id + elif cidr_ip: + values['cidr'] = cidr_ip + else: + return { 'return': False } + + if ip_protocol and from_port and to_port: + values['protocol'] = ip_protocol + values['from_port'] = from_port + values['to_port'] = to_port + else: + # If cidr based filtering, protocol and ports are mandatory + if 'cidr' in values: + print values + return None + + security_group_rule = db.security_group_rule_create(context, values) return True @rbac.allow('netadmin') @@ -234,6 +274,8 @@ class CloudController(object): @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): + security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name) + security_group.delete() return True @rbac.allow('projectmanager', 'sysadmin') diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 87d99607d..6cd59541f 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -233,20 +233,29 @@ class ApiEc2TestCase(test.BaseTestCase): self.manager.delete_user(user) def test_get_all_security_groups(self): - """Test that operations on security groups stick""" + """Test that we can retrieve security groups""" self.expect_http() self.mox.ReplayAll() - security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ - for x in range(random.randint(4, 8))) user = self.manager.create_user('fake', 'fake', 'fake', admin=True) project = self.manager.create_project('fake', 'fake', 'fake') rv = self.ec2.get_all_security_groups() + self.assertEquals(len(rv), 1) - self.assertEquals(rv[0].name, 'default') + self.assertEquals(rv[0].name, 'default') + + self.manager.delete_project(project) + self.manager.delete_user(user) + def test_create_delete_security_group(self): + """Test that we can create a security group""" self.expect_http() self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) self.ec2.create_security_group(security_group_name, 'test group') @@ -257,5 +266,71 @@ class ApiEc2TestCase(test.BaseTestCase): self.assertEquals(len(rv), 2) self.assertTrue(security_group_name in [group.name for group in rv]) + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + self.manager.delete_project(project) self.manager.delete_user(user) + + def test_authorize_security_group_cidr(self): + """Test that we can add rules to a security group""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize('tcp', 80, 80, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return + + def test_authorize_security_group_foreign_group(self): + """Test that we can grant another security group access to a security group""" + self.expect_http() + self.mox.ReplayAll() + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + project = self.manager.create_project('fake', 'fake', 'fake') + + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) + + group = self.ec2.create_security_group(security_group_name, 'test group') + + self.expect_http() + self.mox.ReplayAll() + + other_group = self.ec2.create_security_group('appserver', 'The application tier') + + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.authorize(src_group=other_group) + + self.expect_http() + self.mox.ReplayAll() + + self.ec2.delete_security_group(security_group_name) + + self.manager.delete_project(project) + self.manager.delete_user(user) + + return -- cgit From 59a959299d7883c48626d8d5630974d718194960 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Thu, 9 Sep 2010 17:35:02 +0200 Subject: Authorize and Revoke access now works. --- nova/db/api.py | 9 ++++++++ nova/db/sqlalchemy/api.py | 8 +++++++ nova/db/sqlalchemy/models.py | 7 ++++-- nova/endpoint/cloud.py | 51 +++++++++++++++++++++++++++++++++++++++++--- nova/tests/api_unittest.py | 26 ++++++++++++++++++---- 5 files changed, 92 insertions(+), 9 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index af574d6de..63ead04e0 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -477,3 +477,12 @@ def security_group_destroy(context, security_group_id): def security_group_rule_create(context, values): """Create a new security group""" return IMPL.security_group_rule_create(context, values) + + +def security_group_rule_get_by_security_group(context, security_group_id): + """Get all rules for a a given security group""" + return IMPL.security_group_rule_get_by_security_group(context, security_group_id) + +def security_group_rule_destroy(context, security_group_rule_id): + """Deletes a security group rule""" + return IMPL.security_group_rule_destroy(context, security_group_rule_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index c8d852f9d..2db876154 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -25,6 +25,7 @@ from nova import flags from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import managed_session from sqlalchemy import or_ +from sqlalchemy.orm import eagerload FLAGS = flags.FLAGS @@ -615,6 +616,7 @@ def security_group_get_by_user_and_name(_context, user_id, name): .filter_by(user_id=user_id) \ .filter_by(name=name) \ .filter_by(deleted=False) \ + .options(eagerload('rules')) \ .one() def security_group_destroy(_context, security_group_id): @@ -638,3 +640,9 @@ def security_group_rule_create(_context, values): security_group_rule_ref[key] = value security_group_rule_ref.save() return security_group_rule_ref + +def security_group_rule_destroy(_context, security_group_rule_id): + with managed_session() as session: + security_group_rule = session.query(models.SecurityGroupIngressRule) \ + .get(security_group_rule_id) + security_group_rule.delete(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 330262a88..d177688d8 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -328,14 +328,17 @@ class SecurityGroupIngressRule(BASE, NovaBase): __tablename__ = 'security_group_rules' id = Column(Integer, primary_key=True) - parent_security_group = Column(Integer, ForeignKey('security_group.id')) + parent_group_id = Column(Integer, ForeignKey('security_group.id')) + parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id) +# primaryjoin=SecurityGroup().id==parent_group_id) + protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) to_port = Column(Integer) # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. - group_id = Column(Integer, ForeignKey('security_group.id')) +# group_id = Column(Integer, ForeignKey('security_group.id')) @property def user(self): diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 0a929b865..6e32a945b 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -222,6 +222,52 @@ class CloudController(object): return groups + @rbac.allow('netadmin') + def revoke_security_group_ingress(self, context, group_name, + to_port=None, from_port=None, + ip_protocol=None, cidr_ip=None, + user_id=None, + source_security_group_name=None, + source_security_group_owner_id=None): + security_group = db.security_group_get_by_user_and_name(context, + context.user.id, + group_name) + + criteria = {} + + if source_security_group_name: + if source_security_group_owner_id: + other_user_id = source_security_group_owner_id + else: + other_user_id = context.user.id + + foreign_security_group = \ + db.security_group_get_by_user_and_name(context, + other_user_id, + source_security_group_name) + criteria['group_id'] = foreign_security_group.id + elif cidr_ip: + criteria['cidr'] = cidr_ip + else: + return { 'return': False } + + if ip_protocol and from_port and to_port: + criteria['protocol'] = ip_protocol + criteria['from_port'] = from_port + criteria['to_port'] = to_port + else: + # If cidr based filtering, protocol and ports are mandatory + if 'cidr' in criteria: + return { 'return': False } + + for rule in security_group.rules: + for (k,v) in criteria.iteritems(): + if getattr(rule, k, False) != v: + break + # If we make it here, we have a match + db.security_group_rule_destroy(context, rule.id) + return True + @rbac.allow('netadmin') def authorize_security_group_ingress(self, context, group_name, to_port=None, from_port=None, @@ -232,13 +278,12 @@ class CloudController(object): security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name) - values = { 'parent_security_group' : security_group.id } + values = { 'parent_group_id' : security_group.id } - # Aw, crap. if source_security_group_name: if source_security_group_owner_id: other_user_id = source_security_group_owner_id - else: + else: other_user_id = context.user.id foreign_security_group = \ diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 6cd59541f..f25e377d0 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -274,8 +274,11 @@ class ApiEc2TestCase(test.BaseTestCase): self.manager.delete_project(project) self.manager.delete_user(user) - def test_authorize_security_group_cidr(self): - """Test that we can add rules to a security group""" + def test_authorize_revoke_security_group_cidr(self): + """ + Test that we can add and remove CIDR based rules + to a security group + """ self.expect_http() self.mox.ReplayAll() user = self.manager.create_user('fake', 'fake', 'fake', admin=True) @@ -292,6 +295,12 @@ class ApiEc2TestCase(test.BaseTestCase): group.authorize('tcp', 80, 80, '0.0.0.0/0') + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.revoke('tcp', 80, 80, '0.0.0.0/0') + self.expect_http() self.mox.ReplayAll() @@ -302,8 +311,11 @@ class ApiEc2TestCase(test.BaseTestCase): return - def test_authorize_security_group_foreign_group(self): - """Test that we can grant another security group access to a security group""" + def test_authorize_revoke_security_group_foreign_group(self): + """ + Test that we can grant and revoke another security group access + to a security group + """ self.expect_http() self.mox.ReplayAll() user = self.manager.create_user('fake', 'fake', 'fake', admin=True) @@ -325,6 +337,12 @@ class ApiEc2TestCase(test.BaseTestCase): group.authorize(src_group=other_group) + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + group.revoke(src_group=other_group) + self.expect_http() self.mox.ReplayAll() -- cgit From ecbbfa343edf0ca0e82b35dc655fa23701bbdf22 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Fri, 10 Sep 2010 11:47:06 +0200 Subject: Create and delete security groups works. Adding and revoking rules works. DescribeSecurityGroups returns the groups and rules. So, the API seems to be done. Yay. --- nova/db/api.py | 5 ++++ nova/db/sqlalchemy/api.py | 7 ++++++ nova/db/sqlalchemy/models.py | 6 ++--- nova/endpoint/api.py | 1 + nova/endpoint/cloud.py | 43 ++++++++++++++++++++++++++------ nova/tests/api_unittest.py | 58 ++++++++++++++++++++++++++++++++++++++++---- 6 files changed, 104 insertions(+), 16 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 63ead04e0..c7a6da183 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -451,6 +451,11 @@ def security_group_create(context, values): return IMPL.security_group_create(context, values) +def security_group_get_by_id(context, security_group_id): + """Get security group by its internal id""" + return IMPL.security_group_get_by_id(context, security_group_id) + + def security_group_get_by_instance(context, instance_id): """Get security groups to which the instance is assigned""" return IMPL.security_group_get_by_instance(context, instance_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2db876154..4027e901c 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -595,6 +595,12 @@ def security_group_create(_context, values): return security_group_ref +def security_group_get_by_id(_context, security_group_id): + with managed_session() as session: + return session.query(models.SecurityGroup) \ + .get(security_group_id) + + def security_group_get_by_instance(_context, instance_id): with managed_session() as session: return session.query(models.Instance) \ @@ -608,6 +614,7 @@ def security_group_get_by_user(_context, user_id): return session.query(models.SecurityGroup) \ .filter_by(user_id=user_id) \ .filter_by(deleted=False) \ + .options(eagerload('rules')) \ .all() def security_group_get_by_user_and_name(_context, user_id, name): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index d177688d8..27c8e4d4c 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -329,8 +329,8 @@ class SecurityGroupIngressRule(BASE, NovaBase): id = Column(Integer, primary_key=True) parent_group_id = Column(Integer, ForeignKey('security_group.id')) - parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id) -# primaryjoin=SecurityGroup().id==parent_group_id) + parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id, + primaryjoin=parent_group_id==SecurityGroup.id) protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) @@ -338,7 +338,7 @@ class SecurityGroupIngressRule(BASE, NovaBase): # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. -# group_id = Column(Integer, ForeignKey('security_group.id')) + group_id = Column(Integer, ForeignKey('security_group.id')) @property def user(self): diff --git a/nova/endpoint/api.py b/nova/endpoint/api.py index 40be00bb7..1f37aeb02 100755 --- a/nova/endpoint/api.py +++ b/nova/endpoint/api.py @@ -135,6 +135,7 @@ class APIRequest(object): response = xml.toxml() xml.unlink() +# print response _log.debug(response) return response diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 6e32a945b..e6eca9850 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -213,14 +213,41 @@ class CloudController(object): @rbac.allow('all') def describe_security_groups(self, context, **kwargs): - groups = {'securityGroupSet': - [{ 'groupDescription': group.description, - 'groupName' : group.name, - 'ownerId': context.user.id } for group in \ - db.security_group_get_by_user(context, - context.user.id) ] } - - return groups + groups = [] + for group in db.security_group_get_by_user(context, context.user.id): + group_dict = {} + group_dict['groupDescription'] = group.description + group_dict['groupName'] = group.name + group_dict['ownerId'] = context.user.id + group_dict['ipPermissions'] = [] + for rule in group.rules: + rule_dict = {} + rule_dict['ipProtocol'] = rule.protocol + rule_dict['fromPort'] = rule.from_port + rule_dict['toPort'] = rule.to_port + rule_dict['groups'] = [] + rule_dict['ipRanges'] = [] + if rule.group_id: + foreign_group = db.security_group_get_by_id({}, rule.group_id) + rule_dict['groups'] += [ { 'groupName': foreign_group.name, + 'userId': foreign_group.user_id } ] + else: + rule_dict['ipRanges'] += [ { 'cidrIp': rule.cidr } ] + group_dict['ipPermissions'] += [ rule_dict ] + groups += [ group_dict ] + + return {'securityGroupInfo': groups } +# +# [{ 'groupDescription': group.description, +# 'groupName' : group.name, +# 'ownerId': context.user.id, +# 'ipPermissions' : [ +# { 'ipProtocol' : rule.protocol, +# 'fromPort' : rule.from_port, +# 'toPort' : rule.to_port, +# 'ipRanges' : [ { 'cidrIp' : rule.cidr } ] } for rule in group.rules ] } for group in \ +# +# return groups @rbac.allow('netadmin') def revoke_security_group_ingress(self, context, group_name, diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index f25e377d0..7e914e6f5 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -293,19 +293,43 @@ class ApiEc2TestCase(test.BaseTestCase): self.mox.ReplayAll() group.connection = self.ec2 - group.authorize('tcp', 80, 80, '0.0.0.0/0') + group.authorize('tcp', 80, 81, '0.0.0.0/0') + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(int(group.rules[0].from_port), 80) + self.assertEquals(int(group.rules[0].to_port), 81) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') self.expect_http() self.mox.ReplayAll() group.connection = self.ec2 - group.revoke('tcp', 80, 80, '0.0.0.0/0') + group.revoke('tcp', 80, 81, '0.0.0.0/0') self.expect_http() self.mox.ReplayAll() self.ec2.delete_security_group(security_group_name) + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + + rv = self.ec2.get_all_security_groups() + + self.assertEqual(len(rv), 1) + self.assertEqual(rv[0].name, 'default') + self.manager.delete_project(project) self.manager.delete_user(user) @@ -323,13 +347,16 @@ class ApiEc2TestCase(test.BaseTestCase): security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ for x in range(random.randint(4, 8))) + other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ + for x in range(random.randint(4, 8))) group = self.ec2.create_security_group(security_group_name, 'test group') self.expect_http() self.mox.ReplayAll() - other_group = self.ec2.create_security_group('appserver', 'The application tier') + other_group = self.ec2.create_security_group(other_security_group_name, + 'some other group') self.expect_http() self.mox.ReplayAll() @@ -339,9 +366,30 @@ class ApiEc2TestCase(test.BaseTestCase): self.expect_http() self.mox.ReplayAll() - group.connection = self.ec2 - group.revoke(src_group=other_group) + rv = self.ec2.get_all_security_groups() + # I don't bother checkng that we actually find it here, + # because the create/delete unit test further up should + # be good enough for that. + for group in rv: + if group.name == security_group_name: + self.assertEquals(len(group.rules), 1) + self.assertEquals(len(group.rules[0].grants), 1) + self.assertEquals(str(group.rules[0].grants[0]), + '%s-%s' % (other_security_group_name, 'fake')) + + + self.expect_http() + self.mox.ReplayAll() + + rv = self.ec2.get_all_security_groups() + + for group in rv: + if group.name == security_group_name: + self.expect_http() + self.mox.ReplayAll() + group.connection = self.ec2 + group.revoke(src_group=other_group) self.expect_http() self.mox.ReplayAll() -- cgit From c3dd0aa79d982d8f34172e6023d4b632ea23f2b9 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Fri, 10 Sep 2010 14:56:36 +0200 Subject: First pass of nwfilter based security group implementation. It is not where it is supposed to be and it does not actually do anything yet. --- nova/auth/manager.py | 2 +- nova/db/sqlalchemy/api.py | 1 + nova/endpoint/cloud.py | 1 - nova/tests/virt_unittest.py | 50 ++++++++++++++++++++++++++++++++--- nova/virt/libvirt_conn.py | 63 +++++++++++++++++++++++++++++++++++++++++++++ run_tests.py | 1 + 6 files changed, 113 insertions(+), 5 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 6aa5721c8..281e2d8f0 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -649,7 +649,7 @@ class AuthManager(object): def delete_user(self, user): """Deletes a user""" with self.driver() as drv: - for security_group in db.security_group_get_by_user(context = {}, user_id=user.id): + for security_group in db.security_group_get_by_user(context = {}, user_id=User.safe_id(user)): db.security_group_destroy({}, security_group.id) drv.delete_user(User.safe_id(user)) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 4027e901c..622e76cd7 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -598,6 +598,7 @@ def security_group_create(_context, values): def security_group_get_by_id(_context, security_group_id): with managed_session() as session: return session.query(models.SecurityGroup) \ + .options(eagerload('rules')) \ .get(security_group_id) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index e6eca9850..5e5ed6c5e 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -299,7 +299,6 @@ class CloudController(object): def authorize_security_group_ingress(self, context, group_name, to_port=None, from_port=None, ip_protocol=None, cidr_ip=None, - user_id=None, source_security_group_name=None, source_security_group_owner_id=None): security_group = db.security_group_get_by_user_and_name(context, diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 2aab16809..b8dcec12b 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -14,23 +14,30 @@ # License for the specific language governing permissions and limitations # under the License. +from xml.dom.minidom import parseString + from nova import flags from nova import test +from nova.endpoint import cloud from nova.virt import libvirt_conn FLAGS = flags.FLAGS class LibvirtConnTestCase(test.TrialTestCase): - def test_get_uri_and_template(self): + def bitrot_test_get_uri_and_template(self): class MockDataModel(object): + def __getitem__(self, name): + return self.datamodel[name] + def __init__(self): self.datamodel = { 'name' : 'i-cafebabe', 'memory_kb' : '1024000', 'basepath' : '/some/path', 'bridge_name' : 'br100', 'mac_address' : '02:12:34:46:56:67', - 'vcpus' : 2 } + 'vcpus' : 2, + 'project_id' : None } type_uri_map = { 'qemu' : ('qemu:///system', [lambda s: '' in s, @@ -53,7 +60,7 @@ class LibvirtConnTestCase(test.TrialTestCase): self.assertEquals(uri, expected_uri) for i, check in enumerate(checks): - xml = conn.toXml(MockDataModel()) + xml = conn.to_xml(MockDataModel()) self.assertTrue(check(xml), '%s failed check %d' % (xml, i)) # Deliberately not just assigning this string to FLAGS.libvirt_uri and @@ -67,3 +74,40 @@ class LibvirtConnTestCase(test.TrialTestCase): uri, template = conn.get_uri_and_template() self.assertEquals(uri, testuri) + +class NWFilterTestCase(test.TrialTestCase): + def test_stuff(self): + cloud_controller = cloud.CloudController() + class FakeContext(object): + pass + + context = FakeContext() + context.user = FakeContext() + context.user.id = 'fake' + context.user.is_superuser = lambda:True + cloud_controller.create_security_group(context, 'testgroup', 'test group description') + cloud_controller.authorize_security_group_ingress(context, 'testgroup', from_port='80', + to_port='81', ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + fw = libvirt_conn.NWFilterFirewall() + xml = fw.security_group_to_nwfilter_xml(1) + + dom = parseString(xml) + self.assertEqual(dom.firstChild.tagName, 'filter') + + rules = dom.getElementsByTagName('rule') + self.assertEqual(len(rules), 1) + + # It's supposed to allow inbound traffic. + self.assertEqual(rules[0].getAttribute('action'), 'allow') + self.assertEqual(rules[0].getAttribute('direction'), 'in') + + # Must be lower priority than the base filter (which blocks everything) + self.assertTrue(int(rules[0].getAttribute('priority')) < 1000) + + ip_conditions = rules[0].getElementsByTagName('ip') + self.assertEqual(len(ip_conditions), 1) + self.assertEqual(ip_conditions[0].getAttribute('protocol'), 'tcp') + self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') + self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index e26030158..7bf2a68b1 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -426,3 +426,66 @@ class LibvirtConnection(object): """ domain = self._conn.lookupByName(instance_name) return domain.interfaceStats(interface) + + +class NWFilterFirewall(object): + """ + This class implements a network filtering mechanism versatile + enough for EC2 style Security Group filtering by leveraging + libvirt's nwfilter. + + First, all instances get a filter ("nova-base-filter") applied. + This filter drops all incoming ipv4 and ipv6 connections. + Outgoing connections are never blocked. + + Second, every security group maps to a nwfilter filter(*). + NWFilters can be updated at runtime and changes are applied + immediately, so changes to security groups can be applied at + runtime (as mandated by the spec). + + Security group rules are named "nova-secgroup-" where + is the internal id of the security group. They're applied only on + hosts that have instances in the security group in question. + + Updates to security groups are done by updating the data model + (in response to API calls) followed by a request sent to all + the nodes with instances in the security group to refresh the + security group. + + Outstanding questions: + + The name is unique, so would there be any good reason to sync + the uuid across the nodes (by assigning it from the datamodel)? + + + (*) This sentence brought to you by the redundancy department of + redundancy. + """ + + def __init__(self): + pass + + def nova_base_filter(self): + return ''' + 26717364-50cf-42d1-8185-29bf893ab110 + + + + + + +''' + + def security_group_to_nwfilter_xml(self, security_group_id): + security_group = db.security_group_get_by_id({}, security_group_id) + rule_xml = "" + for rule in security_group.rules: + rule_xml += "" + if rule.cidr: + rule_xml += ("") % \ + (rule.cidr, rule.protocol, + rule.from_port, rule.to_port) + rule_xml += "" + xml = '''%s''' % (security_group_id, rule_xml,) + return xml diff --git a/run_tests.py b/run_tests.py index d5dc5f934..75ab561a1 100644 --- a/run_tests.py +++ b/run_tests.py @@ -62,6 +62,7 @@ from nova.tests.rpc_unittest import * from nova.tests.service_unittest import * from nova.tests.validator_unittest import * from nova.tests.volume_unittest import * +from nova.tests.virt_unittest import * FLAGS = flags.FLAGS -- cgit From fffa02ac32055650b2bfffff090ec7d52c86291a Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Fri, 10 Sep 2010 15:32:56 +0200 Subject: Adjust a few things to make the unit tests happy again. --- nova/endpoint/cloud.py | 2 +- nova/tests/virt_unittest.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 3334f09af..930274aed 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -348,7 +348,7 @@ class CloudController(object): @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name) - security_group.delete() + db.security_group_destroy(context, security_group.id) return True @rbac.allow('projectmanager', 'sysadmin') diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index b8dcec12b..1f573c463 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -16,6 +16,7 @@ from xml.dom.minidom import parseString +from nova import db from nova import flags from nova import test from nova.endpoint import cloud @@ -91,7 +92,10 @@ class NWFilterTestCase(test.TrialTestCase): cidr_ip='0.0.0.0/0') fw = libvirt_conn.NWFilterFirewall() - xml = fw.security_group_to_nwfilter_xml(1) + + security_group = db.security_group_get_by_user_and_name({}, 'fake', 'testgroup') + + xml = fw.security_group_to_nwfilter_xml(security_group.id) dom = parseString(xml) self.assertEqual(dom.firstChild.tagName, 'filter') -- cgit From e53676bb32b70ff01ca27c310e558b651590be3d Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 10 Sep 2010 15:26:13 -0700 Subject: Refactored to security group api to support projects --- nova/auth/manager.py | 2 -- nova/db/api.py | 34 ++++++++++-------- nova/db/sqlalchemy/api.py | 76 ++++++++++++++++++++++++---------------- nova/db/sqlalchemy/models.py | 22 ++++++------ nova/endpoint/cloud.py | 83 +++++++++++++++++++++++++++----------------- nova/tests/api_unittest.py | 1 + nova/tests/virt_unittest.py | 4 ++- nova/virt/libvirt_conn.py | 2 +- 8 files changed, 133 insertions(+), 91 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 281e2d8f0..34aa73bf6 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -649,8 +649,6 @@ class AuthManager(object): def delete_user(self, user): """Deletes a user""" with self.driver() as drv: - for security_group in db.security_group_get_by_user(context = {}, user_id=User.safe_id(user)): - db.security_group_destroy({}, security_group.id) drv.delete_user(User.safe_id(user)) def generate_key_pair(self, user, key_name): diff --git a/nova/db/api.py b/nova/db/api.py index 2bcf0bd2b..cdbd15486 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -442,33 +442,39 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + #################### -def security_group_create(context, values): - """Create a new security group""" - return IMPL.security_group_create(context, values) +def security_group_get_all(context): + """Get all security groups""" + return IMPL.security_group_get_all(context) -def security_group_get_by_id(context, security_group_id): +def security_group_get(context, security_group_id): """Get security group by its internal id""" - return IMPL.security_group_get_by_id(context, security_group_id) + return IMPL.security_group_get(context, security_group_id) -def security_group_get_by_instance(context, instance_id): - """Get security groups to which the instance is assigned""" - return IMPL.security_group_get_by_instance(context, instance_id) +def security_group_get_by_name(context, project_id, group_name): + """Returns a security group with the specified name from a project""" + return IMPL.securitygroup_get_by_name(context, project_id, group_name) -def security_group_get_by_user(context, user_id): - """Get security groups owned by the given user""" - return IMPL.security_group_get_by_user(context, user_id) +def security_group_get_by_project(context, project_id): + """Get all security groups belonging to a project""" + return IMPL.securitygroup_get_by_project(context, project_id) -def security_group_get_by_user_and_name(context, user_id, name): - """Get user's named security group""" - return IMPL.security_group_get_by_user_and_name(context, user_id, name) +def security_group_get_by_instance(context, instance_id): + """Get security groups to which the instance is assigned""" + return IMPL.security_group_get_by_instance(context, instance_id) + +def security_group_create(context, values): + """Create a new security group""" + return IMPL.security_group_create(context, values) + def security_group_destroy(context, security_group_id): """Deletes a security group""" diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 1c95efd83..61d733940 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -616,20 +616,45 @@ def volume_update(_context, volume_id, values): ################### -def security_group_create(_context, values): - security_group_ref = models.SecurityGroup() - for (key, value) in values.iteritems(): - security_group_ref[key] = value - security_group_ref.save() - return security_group_ref +def security_group_get_all(_context): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(deleted=False + ).all() -def security_group_get_by_id(_context, security_group_id): +def security_group_get(_context, security_group_id): session = get_session() with session.begin(): return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).get(security_group_id) + + +def securitygroup_get_by_name(context, project_id, group_name): + session = get_session() + group_ref = session.query(models.SecurityGroup ).options(eagerload('rules') - ).get(security_group_id) + ).filter_by(project_id=project_id + ).filter_by(name=group_name + ).filter_by(deleted=False + ).first() + if not group_ref: + raise exception.NotFound( + 'No security group named %s for project: %s' \ + % (group_name, project_id)) + + return group_ref + + +def securitygroup_get_by_project(_context, project_id): + session = get_session() + return session.query(models.SecurityGroup + ).options(eagerload('rules') + ).filter_by(project_id=project_id + ).filter_by(deleted=False + ).all() def security_group_get_by_instance(_context, instance_id): @@ -638,34 +663,27 @@ def security_group_get_by_instance(_context, instance_id): return session.query(models.Instance ).get(instance_id ).security_groups \ - .all() + .filter_by(deleted=False + ).all() -def security_group_get_by_user(_context, user_id): - session = get_session() - with session.begin(): - return session.query(models.SecurityGroup - ).filter_by(user_id=user_id - ).filter_by(deleted=False - ).options(eagerload('rules') - ).all() +def security_group_create(_context, values): + security_group_ref = models.SecurityGroup() + for (key, value) in values.iteritems(): + security_group_ref[key] = value + security_group_ref.save() + return security_group_ref -def security_group_get_by_user_and_name(_context, user_id, name): - session = get_session() - with session.begin(): - return session.query(models.SecurityGroup - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=False - ).options(eagerload('rules') - ).one() def security_group_destroy(_context, security_group_id): session = get_session() with session.begin(): - security_group = session.query(models.SecurityGroup - ).get(security_group_id) - security_group.delete(session=session) + # TODO(vish): do we have to use sql here? + session.execute('update security_group set deleted=1 where id=:id', + {'id': security_group_id}) + session.execute('update security_group_rule set deleted=1 ' + 'where group_id=:id', + {'id': security_group_id}) ################### diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index f27520aa8..3c4b9ddd7 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -306,26 +306,23 @@ class SecurityGroup(BASE, NovaBase): class SecurityGroupIngressRule(BASE, NovaBase): """Represents a rule in a security group""" - __tablename__ = 'security_group_rules' + __tablename__ = 'security_group_rule' id = Column(Integer, primary_key=True) - parent_group_id = Column(Integer, ForeignKey('security_group.id')) - parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id, - primaryjoin=parent_group_id==SecurityGroup.id) + group_id = Column(Integer, ForeignKey('security_group.id')) + group = relationship("SecurityGroup", backref="rules", + foreign_keys=group_id, + primaryjoin=group_id==SecurityGroup.id) protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) to_port = Column(Integer) + cidr = Column(String(255)) # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. - group_id = Column(Integer, ForeignKey('security_group.id')) - - @property - def user(self): - return auth.manager.AuthManager().get_user(self.user_id) + source_group_id = Column(Integer, ForeignKey('security_group.id')) - cidr = Column(String(255)) class Network(BASE, NovaBase): """Represents a network""" @@ -430,8 +427,9 @@ class FloatingIp(BASE, NovaBase): def register_models(): """Register Models and create metadata""" from sqlalchemy import create_engine - models = (Service, Instance, Volume, ExportDevice, - FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host) + models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, + Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule) + # , Image, Host engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 930274aed..4cb09bedb 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -216,7 +216,8 @@ class CloudController(object): @rbac.allow('all') def describe_security_groups(self, context, **kwargs): groups = [] - for group in db.security_group_get_by_user(context, context.user.id): + for group in db.security_group_get_by_project(context, + context.project.id): group_dict = {} group_dict['groupDescription'] = group.description group_dict['groupName'] = group.name @@ -229,10 +230,11 @@ class CloudController(object): rule_dict['toPort'] = rule.to_port rule_dict['groups'] = [] rule_dict['ipRanges'] = [] + import pdb; pdb.set_trace() if rule.group_id: - foreign_group = db.security_group_get_by_id({}, rule.group_id) - rule_dict['groups'] += [ { 'groupName': foreign_group.name, - 'userId': foreign_group.user_id } ] + source_group = db.security_group_get(context, rule.group_id) + rule_dict['groups'] += [ { 'groupName': source_group.name, + 'userId': source_group.user_id } ] else: rule_dict['ipRanges'] += [ { 'cidrIp': rule.cidr } ] group_dict['ipPermissions'] += [ rule_dict ] @@ -258,23 +260,22 @@ class CloudController(object): user_id=None, source_security_group_name=None, source_security_group_owner_id=None): - security_group = db.security_group_get_by_user_and_name(context, - context.user.id, - group_name) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) criteria = {} if source_security_group_name: - if source_security_group_owner_id: - other_user_id = source_security_group_owner_id - else: - other_user_id = context.user.id - - foreign_security_group = \ - db.security_group_get_by_user_and_name(context, - other_user_id, - source_security_group_name) - criteria['group_id'] = foreign_security_group.id + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) + + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + + criteria['group_id'] = source_security_group.id elif cidr_ip: criteria['cidr'] = cidr_ip else: @@ -303,22 +304,20 @@ class CloudController(object): ip_protocol=None, cidr_ip=None, source_security_group_name=None, source_security_group_owner_id=None): - security_group = db.security_group_get_by_user_and_name(context, - context.user.id, - group_name) - values = { 'parent_group_id' : security_group.id } + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + values = { 'group_id' : security_group.id } if source_security_group_name: - if source_security_group_owner_id: - other_user_id = source_security_group_owner_id - else: - other_user_id = context.user.id - - foreign_security_group = \ - db.security_group_get_by_user_and_name(context, - other_user_id, - source_security_group_name) - values['group_id'] = foreign_security_group.id + source_project_id = self._get_source_project_id(context, + source_security_group_owner_id) + + source_security_group = \ + db.security_group_get_by_name(context, + source_project_id, + source_security_group_name) + values['source_group_id'] = source_security_group.id elif cidr_ip: values['cidr'] = cidr_ip else: @@ -336,18 +335,38 @@ class CloudController(object): security_group_rule = db.security_group_rule_create(context, values) return True + + def _get_source_project_id(self, context, source_security_group_owner_id): + if source_security_group_owner_id: + # Parse user:project for source group. + source_parts = source_security_group_owner_id.split(':') + + # If no project name specified, assume it's same as user name. + # Since we're looking up by project name, the user name is not + # used here. It's only read for EC2 API compatibility. + if len(source_parts) == 2: + source_project_id = parts[1] + else: + source_project_id = parts[0] + else: + source_project_id = context.project.id + + return source_project_id @rbac.allow('netadmin') def create_security_group(self, context, group_name, group_description): db.security_group_create(context, values = { 'user_id' : context.user.id, + 'project_id': context.project.id, 'name': group_name, 'description': group_description }) return True @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): - security_group = db.security_group_get_by_user_and_name(context, context.user.id, group_name) + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) db.security_group_destroy(context, security_group.id) return True diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 7e914e6f5..55b7cb4d8 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -304,6 +304,7 @@ class ApiEc2TestCase(test.BaseTestCase): # be good enough for that. for group in rv: if group.name == security_group_name: + import pdb; pdb.set_trace() self.assertEquals(len(group.rules), 1) self.assertEquals(int(group.rules[0].from_port), 80) self.assertEquals(int(group.rules[0].to_port), 81) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 1f573c463..dceced3a9 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -86,6 +86,8 @@ class NWFilterTestCase(test.TrialTestCase): context.user = FakeContext() context.user.id = 'fake' context.user.is_superuser = lambda:True + context.project = FakeContext() + context.project.id = 'fake' cloud_controller.create_security_group(context, 'testgroup', 'test group description') cloud_controller.authorize_security_group_ingress(context, 'testgroup', from_port='80', to_port='81', ip_protocol='tcp', @@ -93,7 +95,7 @@ class NWFilterTestCase(test.TrialTestCase): fw = libvirt_conn.NWFilterFirewall() - security_group = db.security_group_get_by_user_and_name({}, 'fake', 'testgroup') + security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') xml = fw.security_group_to_nwfilter_xml(security_group.id) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 6f708bb80..09c94577c 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -492,7 +492,7 @@ class NWFilterFirewall(object): ''' def security_group_to_nwfilter_xml(self, security_group_id): - security_group = db.security_group_get_by_id({}, security_group_id) + security_group = db.security_group_get({}, security_group_id) rule_xml = "" for rule in security_group.rules: rule_xml += "" -- cgit From 60b6b06d15ed620cf990db10277c4126b686de80 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 10 Sep 2010 19:19:08 -0700 Subject: Finished security group / project refactor --- nova/auth/manager.py | 20 ++++++++++++++++---- nova/db/sqlalchemy/api.py | 2 +- nova/db/sqlalchemy/models.py | 12 ++++++------ nova/endpoint/cloud.py | 5 ++--- nova/tests/api_unittest.py | 2 +- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 34aa73bf6..48d314ae6 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -531,6 +531,12 @@ class AuthManager(object): except: drv.delete_project(project.id) raise + + db.security_group_create(context={}, + values={ 'name': 'default', + 'description': 'default', + 'user_id': manager_user, + 'project_id': project.id }) return project def add_to_project(self, user, project): @@ -586,6 +592,16 @@ class AuthManager(object): except: logging.exception('Could not destroy network for %s', project) + try: + project_id = Project.safe_id(project) + groups = db.security_group_get_by_project(context={}, + project_id=project_id) + for group in groups: + db.security_group_destroy({}, group.id) + except: + logging.exception('Could not destroy security groups for %s', + project) + with self.driver() as drv: drv.delete_project(Project.safe_id(project)) @@ -640,10 +656,6 @@ class AuthManager(object): with self.driver() as drv: user_dict = drv.create_user(name, access, secret, admin) if user_dict: - db.security_group_create(context={}, - values={ 'name' : 'default', - 'description' : 'default', - 'user_id' : name }) return User(**user_dict) def delete_user(self, user): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 61d733940..f3d4b68c4 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -681,7 +681,7 @@ def security_group_destroy(_context, security_group_id): # TODO(vish): do we have to use sql here? session.execute('update security_group set deleted=1 where id=:id', {'id': security_group_id}) - session.execute('update security_group_rule set deleted=1 ' + session.execute('update security_group_rules set deleted=1 ' 'where group_id=:id', {'id': security_group_id}) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 3c4b9ddd7..e79a0415b 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -306,13 +306,13 @@ class SecurityGroup(BASE, NovaBase): class SecurityGroupIngressRule(BASE, NovaBase): """Represents a rule in a security group""" - __tablename__ = 'security_group_rule' + __tablename__ = 'security_group_rules' id = Column(Integer, primary_key=True) - group_id = Column(Integer, ForeignKey('security_group.id')) - group = relationship("SecurityGroup", backref="rules", - foreign_keys=group_id, - primaryjoin=group_id==SecurityGroup.id) + parent_group_id = Column(Integer, ForeignKey('security_group.id')) + parent_group = relationship("SecurityGroup", backref="rules", + foreign_keys=parent_group_id, + primaryjoin=parent_group_id==SecurityGroup.id) protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) @@ -321,7 +321,7 @@ class SecurityGroupIngressRule(BASE, NovaBase): # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. - source_group_id = Column(Integer, ForeignKey('security_group.id')) + group_id = Column(Integer, ForeignKey('security_group.id')) class Network(BASE, NovaBase): diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 4cb09bedb..a26f90753 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -230,7 +230,6 @@ class CloudController(object): rule_dict['toPort'] = rule.to_port rule_dict['groups'] = [] rule_dict['ipRanges'] = [] - import pdb; pdb.set_trace() if rule.group_id: source_group = db.security_group_get(context, rule.group_id) rule_dict['groups'] += [ { 'groupName': source_group.name, @@ -307,7 +306,7 @@ class CloudController(object): security_group = db.security_group_get_by_name(context, context.project.id, group_name) - values = { 'group_id' : security_group.id } + values = { 'parent_group_id' : security_group.id } if source_security_group_name: source_project_id = self._get_source_project_id(context, @@ -317,7 +316,7 @@ class CloudController(object): db.security_group_get_by_name(context, source_project_id, source_security_group_name) - values['source_group_id'] = source_security_group.id + values['group_id'] = source_security_group.id elif cidr_ip: values['cidr'] = cidr_ip else: diff --git a/nova/tests/api_unittest.py b/nova/tests/api_unittest.py index 55b7cb4d8..70669206c 100644 --- a/nova/tests/api_unittest.py +++ b/nova/tests/api_unittest.py @@ -304,7 +304,6 @@ class ApiEc2TestCase(test.BaseTestCase): # be good enough for that. for group in rv: if group.name == security_group_name: - import pdb; pdb.set_trace() self.assertEquals(len(group.rules), 1) self.assertEquals(int(group.rules[0].from_port), 80) self.assertEquals(int(group.rules[0].to_port), 81) @@ -369,6 +368,7 @@ class ApiEc2TestCase(test.BaseTestCase): self.mox.ReplayAll() rv = self.ec2.get_all_security_groups() + # I don't bother checkng that we actually find it here, # because the create/delete unit test further up should # be good enough for that. -- cgit From edccf3f6cf95a4869d7900032a5a6c8eaa65cd18 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Sat, 11 Sep 2010 02:35:25 +0000 Subject: Fixed manager_user reference in create_project --- nova/auth/manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 48d314ae6..5529515a6 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -531,12 +531,12 @@ class AuthManager(object): except: drv.delete_project(project.id) raise - - db.security_group_create(context={}, - values={ 'name': 'default', - 'description': 'default', - 'user_id': manager_user, - 'project_id': project.id }) + + values = {'name': 'default', + 'description': 'default', + 'user_id': User.safe_id(manager_user), + 'project_id': project.id} + db.security_group_create({}, values) return project def add_to_project(self, user, project): -- cgit From f24f20948cf7e6cc0e14c2b1fc41a61d8d2fa34c Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Sat, 11 Sep 2010 11:19:22 -0700 Subject: Security Group API layer cleanup --- nova/db/api.py | 5 +++ nova/db/sqlalchemy/api.py | 11 +++++++ nova/endpoint/cloud.py | 84 ++++++++++++++++++++++++----------------------- 3 files changed, 59 insertions(+), 41 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index cdbd15486..cf39438c2 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -471,6 +471,11 @@ def security_group_get_by_instance(context, instance_id): return IMPL.security_group_get_by_instance(context, instance_id) +def securitygroup_exists(context, project_id, group_name): + """Indicates if a group name exists in a project""" + return IMPL.securitygroup_exists(context, project_id, group_name) + + def security_group_create(context, values): """Create a new security group""" return IMPL.security_group_create(context, values) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index f3d4b68c4..513b47bc9 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -667,8 +667,19 @@ def security_group_get_by_instance(_context, instance_id): ).all() +def securitygroup_exists(_context, project_id, group_name): + try: + group = securitygroup_get_by_name(_context, project_id, group_name) + return group != None + except exception.NotFound: + return False + + def security_group_create(_context, values): security_group_ref = models.SecurityGroup() + # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception + # once save() is called. This will get cleaned up in next orm pass. + security_group_ref.rules for (key, value) in values.iteritems(): security_group_ref[key] = value security_group_ref.save() diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index a26f90753..7408e02e9 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -214,43 +214,40 @@ class CloudController(object): return True @rbac.allow('all') - def describe_security_groups(self, context, **kwargs): - groups = [] - for group in db.security_group_get_by_project(context, - context.project.id): - group_dict = {} - group_dict['groupDescription'] = group.description - group_dict['groupName'] = group.name - group_dict['ownerId'] = context.user.id - group_dict['ipPermissions'] = [] - for rule in group.rules: - rule_dict = {} - rule_dict['ipProtocol'] = rule.protocol - rule_dict['fromPort'] = rule.from_port - rule_dict['toPort'] = rule.to_port - rule_dict['groups'] = [] - rule_dict['ipRanges'] = [] - if rule.group_id: - source_group = db.security_group_get(context, rule.group_id) - rule_dict['groups'] += [ { 'groupName': source_group.name, - 'userId': source_group.user_id } ] - else: - rule_dict['ipRanges'] += [ { 'cidrIp': rule.cidr } ] - group_dict['ipPermissions'] += [ rule_dict ] - groups += [ group_dict ] + def describe_security_groups(self, context, group_name=None, **kwargs): + if context.user.is_admin(): + groups = db.security_group_get_all(context) + else: + groups = db.security_group_get_by_project(context, + context.project.id) + groups = [self._format_security_group(context, g) for g in groups] + if not group_name is None: + groups = [g for g in groups if g.name in group_name] return {'securityGroupInfo': groups } -# -# [{ 'groupDescription': group.description, -# 'groupName' : group.name, -# 'ownerId': context.user.id, -# 'ipPermissions' : [ -# { 'ipProtocol' : rule.protocol, -# 'fromPort' : rule.from_port, -# 'toPort' : rule.to_port, -# 'ipRanges' : [ { 'cidrIp' : rule.cidr } ] } for rule in group.rules ] } for group in \ -# -# return groups + + def _format_security_group(self, context, group): + g = {} + g['groupDescription'] = group.description + g['groupName'] = group.name + g['ownerId'] = context.user.id + g['ipPermissions'] = [] + for rule in group.rules: + r = {} + r['ipProtocol'] = rule.protocol + r['fromPort'] = rule.from_port + r['toPort'] = rule.to_port + r['groups'] = [] + r['ipRanges'] = [] + if rule.group_id: + source_group = db.security_group_get(context, rule.group_id) + r['groups'] += [{'groupName': source_group.name, + 'userId': source_group.user_id}] + else: + r['ipRanges'] += [{'cidrIp': rule.cidr}] + g['ipPermissions'] += [r] + return g + @rbac.allow('netadmin') def revoke_security_group_ingress(self, context, group_name, @@ -354,12 +351,17 @@ class CloudController(object): @rbac.allow('netadmin') def create_security_group(self, context, group_name, group_description): - db.security_group_create(context, - values = { 'user_id' : context.user.id, - 'project_id': context.project.id, - 'name': group_name, - 'description': group_description }) - return True + if db.securitygroup_exists(context, context.project.id, group_name): + raise exception.ApiError('group %s already exists' % group_name) + + group = {'user_id' : context.user.id, + 'project_id': context.project.id, + 'name': group_name, + 'description': group_description} + group_ref = db.security_group_create(context, group) + + return {'securityGroupSet': [self._format_security_group(context, + group_ref)]} @rbac.allow('netadmin') def delete_security_group(self, context, group_name, **kwargs): -- cgit From 2a782110bc51f147bdb35264445badac3b3e8e65 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 13 Sep 2010 11:45:28 +0200 Subject: Filters all get defined when running an instance. --- nova/db/api.py | 5 ++ nova/db/sqlalchemy/api.py | 16 +++++- nova/db/sqlalchemy/models.py | 1 - nova/tests/virt_unittest.py | 101 +++++++++++++++++++++++++++++++----- nova/virt/libvirt.qemu.xml.template | 2 +- nova/virt/libvirt.uml.xml.template | 2 +- nova/virt/libvirt_conn.py | 74 ++++++++++++++++++++++++-- 7 files changed, 179 insertions(+), 22 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index cf39438c2..1d10b1987 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -244,6 +244,11 @@ def instance_update(context, instance_id, values): return IMPL.instance_update(context, instance_id, values) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + return IMPL.instance_add_security_group(context, instance_id, security_group_id) + + #################### diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 513b47bc9..11779e30c 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -238,7 +238,10 @@ def instance_destroy(_context, instance_id): def instance_get(_context, instance_id): - return models.Instance.find(instance_id) + session = get_session() + return session.query(models.Instance + ).options(eagerload('security_groups') + ).get(instance_id) def instance_get_all(_context): @@ -317,6 +320,17 @@ def instance_update(_context, instance_id, values): instance_ref.save(session=session) +def instance_add_security_group(context, instance_id, security_group_id): + """Associate the given security group with the given instance""" + session = get_session() + with session.begin(): + instance_ref = models.Instance.find(instance_id, session=session) + security_group_ref = models.SecurityGroup.find(security_group_id, + session=session) + instance_ref.security_groups += [security_group_ref] + instance_ref.save(session=session) + + ################### diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index e79a0415b..424906c1f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -215,7 +215,6 @@ class Instance(BASE, NovaBase): launch_index = Column(Integer) key_name = Column(String(255)) key_data = Column(Text) - security_group = Column(String(255)) state = Column(Integer) state_description = Column(String(255)) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index dceced3a9..a61849a72 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -77,27 +77,39 @@ class LibvirtConnTestCase(test.TrialTestCase): class NWFilterTestCase(test.TrialTestCase): - def test_stuff(self): - cloud_controller = cloud.CloudController() - class FakeContext(object): + def setUp(self): + super(NWFilterTestCase, self).setUp() + + class Mock(object): pass - context = FakeContext() - context.user = FakeContext() - context.user.id = 'fake' - context.user.is_superuser = lambda:True - context.project = FakeContext() - context.project.id = 'fake' - cloud_controller.create_security_group(context, 'testgroup', 'test group description') - cloud_controller.authorize_security_group_ingress(context, 'testgroup', from_port='80', - to_port='81', ip_protocol='tcp', + self.context = Mock() + self.context.user = Mock() + self.context.user.id = 'fake' + self.context.user.is_superuser = lambda:True + self.context.project = Mock() + self.context.project.id = 'fake' + + self.fake_libvirt_connection = Mock() + + self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection) + + def test_cidr_rule_nwfilter_xml(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', cidr_ip='0.0.0.0/0') - fw = libvirt_conn.NWFilterFirewall() security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') - xml = fw.security_group_to_nwfilter_xml(security_group.id) + xml = self.fw.security_group_to_nwfilter_xml(security_group.id) dom = parseString(xml) self.assertEqual(dom.firstChild.tagName, 'filter') @@ -117,3 +129,64 @@ class NWFilterTestCase(test.TrialTestCase): self.assertEqual(ip_conditions[0].getAttribute('protocol'), 'tcp') self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') + + + self.teardown_security_group() + + def teardown_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.delete_security_group(self.context, 'testgroup') + + + def setup_and_return_security_group(self): + cloud_controller = cloud.CloudController() + cloud_controller.create_security_group(self.context, + 'testgroup', + 'test group description') + cloud_controller.authorize_security_group_ingress(self.context, + 'testgroup', + from_port='80', + to_port='81', + ip_protocol='tcp', + cidr_ip='0.0.0.0/0') + + return db.security_group_get_by_name({}, 'fake', 'testgroup') + + def test_creates_base_rule_first(self): + self.defined_filters = [] + self.fake_libvirt_connection.listNWFilter = lambda:self.defined_filters + self.base_filter_defined = False + self.i = 0 + def _filterDefineXMLMock(xml): + dom = parseString(xml) + name = dom.firstChild.getAttribute('name') + if self.i == 0: + self.assertEqual(dom.firstChild.getAttribute('name'), + 'nova-base-filter') + elif self.i == 1: + self.assertTrue(name.startswith('nova-secgroup-'), + 'unexpected name: %s' % name) + elif self.i == 2: + self.assertTrue(name.startswith('nova-instance-'), + 'unexpected name: %s' % name) + + self.defined_filters.append(name) + self.i += 1 + return True + + def _ensure_all_called(_): + self.assertEqual(self.i, 3) + + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock + + inst_id = db.instance_create({}, { 'user_id' : 'fake', 'project_id' : 'fake' }) + security_group = self.setup_and_return_security_group() + + db.instance_add_security_group({}, inst_id, security_group.id) + instance = db.instance_get({}, inst_id) + + d = self.fw.setup_nwfilters_for_instance(instance) + d.addCallback(_ensure_all_called) + d.addCallback(lambda _:self.teardown_security_group()) + + return d diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index 5d3755b65..cbf501f9c 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -20,7 +20,7 @@ - + diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index 1000da5ab..2030b87d2 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -14,7 +14,7 @@ - + diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 09c94577c..89ede1d1a 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -27,6 +27,7 @@ import shutil from twisted.internet import defer from twisted.internet import task +from twisted.internet import threads from nova import db from nova import exception @@ -216,6 +217,7 @@ class LibvirtConnection(object): instance['id'], power_state.NOSTATE, 'launching') + yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance) yield self._create_image(instance, xml) yield self._conn.createXML(xml, 0) # TODO(termie): this should actually register @@ -442,7 +444,6 @@ class LibvirtConnection(object): domain = self._conn.lookupByName(instance_name) return domain.interfaceStats(interface) - class NWFilterFirewall(object): """ This class implements a network filtering mechanism versatile @@ -467,6 +468,14 @@ class NWFilterFirewall(object): the nodes with instances in the security group to refresh the security group. + Each instance has its own NWFilter, which references the above + mentioned security group NWFilters. This was done because + interfaces can only reference one filter while filters can + reference multiple other filters. This has the added benefit of + actually being able to add and remove security groups from an + instance at run time. This functionality is not exposed anywhere, + though. + Outstanding questions: The name is unique, so would there be any good reason to sync @@ -477,12 +486,14 @@ class NWFilterFirewall(object): redundancy. """ - def __init__(self): - pass + def __init__(self, get_connection): + self._conn = get_connection + def nova_base_filter(self): return ''' 26717364-50cf-42d1-8185-29bf893ab110 + @@ -491,6 +502,60 @@ class NWFilterFirewall(object): ''' + + def setup_nwfilters_for_instance(self, instance): + """ + Creates an NWFilter for the given instance. In the process, + it makes sure the filters for the security groups as well as + the base filter are all in place. + """ + + d = self.ensure_base_filter() + + nwfilter_xml = ("\n" + + " \n" + ) % instance['name'] + + for security_group in instance.security_groups: + d.addCallback(lambda _:self.ensure_security_group_filter(security_group.id)) + + nwfilter_xml += (" \n" + ) % security_group.id + nwfilter_xml += "" + + d.addCallback(lambda _: threads.deferToThread( + self._conn.nwfilterDefineXML, + nwfilter_xml)) + return d + + + def _nwfilter_name_for_security_group(self, security_group_id): + return 'nova-secgroup-%d' % (security_group_id,) + + + def ensure_filter(self, name, xml_generator): + def _already_exists_check(filterlist, filter): + return filter in filterlist + def _define_if_not_exists(exists, xml_generator): + if not exists: + xml = xml_generator() + return threads.deferToThread(self._conn.nwfilterDefineXML, xml) + d = threads.deferToThread(self._conn.listNWFilter) + d.addCallback(_already_exists_check, name) + d.addCallback(_define_if_not_exists, xml_generator) + return d + + + def ensure_base_filter(self): + return self.ensure_filter('nova-base-filter', self.nova_base_filter) + + + def ensure_security_group_filter(self, security_group_id): + return self.ensure_filter( + self._nwfilter_name_for_security_group(security_group_id), + lambda:self.security_group_to_nwfilter_xml(security_group_id)) + + def security_group_to_nwfilter_xml(self, security_group_id): security_group = db.security_group_get({}, security_group_id) rule_xml = "" @@ -498,7 +563,8 @@ class NWFilterFirewall(object): rule_xml += "" if rule.cidr: rule_xml += ("") % \ + "dstportstart='%s' dstportend='%s' />" + + "priority='900'\n") % \ (rule.cidr, rule.protocol, rule.from_port, rule.to_port) rule_xml += "" -- cgit From 077fc783c4f94de427da98818d262aeb09a31044 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 13 Sep 2010 12:04:06 +0200 Subject: (Untested) Make changes to security group rules propagate to the relevant compute nodes. --- nova/compute/manager.py | 5 +++++ nova/endpoint/cloud.py | 20 +++++++++++++++++--- nova/virt/libvirt_conn.py | 37 ++++++++++++++++++++++++------------- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index 5f7a94106..a00fd9baa 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -61,6 +61,11 @@ class ComputeManager(manager.Manager): state = self.driver.get_info(instance_ref.name)['state'] self.db.instance_set_state(context, instance_id, state) + @defer.inlineCallbacks + @exception.wrap_exception + def refresh_security_group(self, context, security_group_id, **_kwargs): + self.driver.refresh_security_group(security_group_id) + @defer.inlineCallbacks @exception.wrap_exception def run_instance(self, context, instance_id, **_kwargs): diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 7408e02e9..1403a62f6 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -93,6 +93,14 @@ class CloudController(object): result[instance['key_name']] = [line] return result + def _refresh_security_group(self, security_group): + nodes = set([instance.host for instance in security_group.instances]) + for node in nodes: + rpc.call('%s.%s' % (FLAGS.compute_topic, node), + { "method": "refresh_security_group", + "args": { "context": None, + "security_group_id": security_group.id}}) + def get_metadata(self, address): instance_ref = db.fixed_ip_get_instance(None, address) if instance_ref is None: @@ -265,12 +273,12 @@ class CloudController(object): if source_security_group_name: source_project_id = self._get_source_project_id(context, source_security_group_owner_id) - + source_security_group = \ db.security_group_get_by_name(context, source_project_id, source_security_group_name) - + criteria['group_id'] = source_security_group.id elif cidr_ip: criteria['cidr'] = cidr_ip @@ -292,6 +300,9 @@ class CloudController(object): break # If we make it here, we have a match db.security_group_rule_destroy(context, rule.id) + + self._refresh_security_group(security_group) + return True @rbac.allow('netadmin') @@ -330,8 +341,11 @@ class CloudController(object): return None security_group_rule = db.security_group_rule_create(context, values) + + self._refresh_security_group(security_group) + return True - + def _get_source_project_id(self, context, source_security_group_owner_id): if source_security_group_owner_id: # Parse user:project for source group. diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 89ede1d1a..a343267dc 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -444,6 +444,12 @@ class LibvirtConnection(object): domain = self._conn.lookupByName(instance_name) return domain.interfaceStats(interface) + + def refresh_security_group(self, security_group_id): + fw = self.NWFilterFirewall(self._conn) + fw.ensure_security_group_filter(security_group_id, override=True) + + class NWFilterFirewall(object): """ This class implements a network filtering mechanism versatile @@ -533,27 +539,32 @@ class NWFilterFirewall(object): return 'nova-secgroup-%d' % (security_group_id,) - def ensure_filter(self, name, xml_generator): - def _already_exists_check(filterlist, filter): - return filter in filterlist - def _define_if_not_exists(exists, xml_generator): - if not exists: - xml = xml_generator() - return threads.deferToThread(self._conn.nwfilterDefineXML, xml) - d = threads.deferToThread(self._conn.listNWFilter) - d.addCallback(_already_exists_check, name) + def define_filter(self, name, xml_generator, override=False): + if not override: + def _already_exists_check(filterlist, filter): + return filter in filterlist + def _define_if_not_exists(exists, xml_generator): + if not exists: + xml = xml_generator() + return threads.deferToThread(self._conn.nwfilterDefineXML, xml) + d = threads.deferToThread(self._conn.listNWFilter) + d.addCallback(_already_exists_check, name) + else: + # Pretend we looked it up and it wasn't defined + d = defer.succeed(False) d.addCallback(_define_if_not_exists, xml_generator) return d def ensure_base_filter(self): - return self.ensure_filter('nova-base-filter', self.nova_base_filter) + return self.define_filter('nova-base-filter', self.nova_base_filter) - def ensure_security_group_filter(self, security_group_id): - return self.ensure_filter( + def ensure_security_group_filter(self, security_group_id, override=False): + return self.define_filter( self._nwfilter_name_for_security_group(security_group_id), - lambda:self.security_group_to_nwfilter_xml(security_group_id)) + lambda:self.security_group_to_nwfilter_xml(security_group_id), + override=override) def security_group_to_nwfilter_xml(self, security_group_id): -- cgit From b15bde79b71e474d96674c8eae4108ac9c063731 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 13 Sep 2010 14:18:08 +0200 Subject: Fix call to listNWFilters --- nova/tests/virt_unittest.py | 2 +- nova/virt/libvirt_conn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index a61849a72..8cafa778e 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -154,7 +154,7 @@ class NWFilterTestCase(test.TrialTestCase): def test_creates_base_rule_first(self): self.defined_filters = [] - self.fake_libvirt_connection.listNWFilter = lambda:self.defined_filters + self.fake_libvirt_connection.listNWFilters = lambda:self.defined_filters self.base_filter_defined = False self.i = 0 def _filterDefineXMLMock(xml): diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index a343267dc..2e1dfcefc 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -547,7 +547,7 @@ class NWFilterFirewall(object): if not exists: xml = xml_generator() return threads.deferToThread(self._conn.nwfilterDefineXML, xml) - d = threads.deferToThread(self._conn.listNWFilter) + d = threads.deferToThread(self._conn.listNWFilters) d.addCallback(_already_exists_check, name) else: # Pretend we looked it up and it wasn't defined -- cgit From 9c4b6612e65d548542b1bf37373200e4e6abc98d Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 13 Sep 2010 14:20:32 +0200 Subject: Correctly pass ip_address to templates. --- nova/virt/libvirt.qemu.xml.template | 4 ++-- nova/virt/libvirt.uml.xml.template | 4 ++-- nova/virt/libvirt_conn.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index cbf501f9c..d02aa9114 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -20,8 +20,8 @@ - - + + diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index 2030b87d2..bf3f2f86a 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -14,8 +14,8 @@ - - + + diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 2e1dfcefc..00a80989f 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -322,6 +322,7 @@ class LibvirtConnection(object): network = db.project_get_network(None, instance['project_id']) # FIXME(vish): stick this in db instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] + ip_address = db.instance_get_fixed_address({}, instance['id']) xml_info = {'type': FLAGS.libvirt_type, 'name': instance['name'], 'basepath': os.path.join(FLAGS.instances_path, @@ -329,7 +330,8 @@ class LibvirtConnection(object): 'memory_kb': instance_type['memory_mb'] * 1024, 'vcpus': instance_type['vcpus'], 'bridge_name': network['bridge'], - 'mac_address': instance['mac_address']} + 'mac_address': instance['mac_address'], + 'ip_address': ip_address } libvirt_xml = self.libvirt_xml % xml_info logging.debug('instance %s: finished toXML method', instance['name']) -- cgit From 3fbbc09cbe2594e816803796e22ef39bcf02b029 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 13:01:57 +0200 Subject: Multiple security group support. --- nova/endpoint/cloud.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 1403a62f6..715470f30 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -279,7 +279,7 @@ class CloudController(object): source_project_id, source_security_group_name) - criteria['group_id'] = source_security_group.id + criteria['group_id'] = source_security_group elif cidr_ip: criteria['cidr'] = cidr_ip else: @@ -682,8 +682,16 @@ class CloudController(object): kwargs['key_name']) key_data = key_pair.public_key - # TODO: Get the real security group of launch in here - security_group = "default" + security_group_arg = kwargs.get('security_group', ["default"]) + if not type(security_group_arg) is list: + security_group_arg = [security_group_arg] + + security_groups = [] + for security_group_name in security_group_arg: + group = db.security_group_get_by_project(context, + context.project.id, + security_group_name) + security_groups.append(group) reservation_id = utils.generate_uid('r') base_options = {} @@ -697,7 +705,7 @@ class CloudController(object): base_options['project_id'] = context.project.id base_options['user_data'] = kwargs.get('user_data', '') base_options['instance_type'] = kwargs.get('instance_type', 'm1.small') - base_options['security_group'] = security_group + base_options['security_groups'] = security_groups for num in range(int(kwargs['max_count'])): inst_id = db.instance_create(context, base_options) -- cgit From 757088eb394552b0aaee61673b0af5094f01c356 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 13:22:17 +0200 Subject: Add a bunch of TODO's to the API implementation. --- nova/endpoint/cloud.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 715470f30..5dd1bd340 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -305,6 +305,18 @@ class CloudController(object): return True + # TODO(soren): Lots and lots of input validation. We're accepting + # strings here (such as ipProtocol), which is put into + # filter rules verbatim. + # TODO(soren): Dupe detection. Adding the same rule twice actually + # adds the same rule twice to the rule set, which is + # pointless. + # TODO(soren): This has only been tested with Boto as the client. + # Unfortunately, it seems Boto is using an old API + # for these operations, so support for newer API versions + # is sketchy. + # TODO(soren): De-duplicate the turning method arguments into dict stuff. + # revoke_security_group_ingress uses the exact same logic. @rbac.allow('netadmin') def authorize_security_group_ingress(self, context, group_name, to_port=None, from_port=None, @@ -350,7 +362,7 @@ class CloudController(object): if source_security_group_owner_id: # Parse user:project for source group. source_parts = source_security_group_owner_id.split(':') - + # If no project name specified, assume it's same as user name. # Since we're looking up by project name, the user name is not # used here. It's only read for EC2 API compatibility. @@ -360,14 +372,14 @@ class CloudController(object): source_project_id = parts[0] else: source_project_id = context.project.id - + return source_project_id @rbac.allow('netadmin') def create_security_group(self, context, group_name, group_description): if db.securitygroup_exists(context, context.project.id, group_name): raise exception.ApiError('group %s already exists' % group_name) - + group = {'user_id' : context.user.id, 'project_id': context.project.id, 'name': group_name, -- cgit From 01a041dd732ae9c56533f6eac25f08c34917d733 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 15:17:52 +0200 Subject: Fix up rule generation. It turns out nwfilter gets very, very wonky indeed if you mix rules and rules. Setting a TCP rule adds an early rule to ebtables that ends up overriding the rules which are last in that table. --- nova/db/sqlalchemy/session.py | 11 ++++++----- nova/endpoint/cloud.py | 3 +-- nova/tests/virt_unittest.py | 6 +++--- nova/virt/libvirt_conn.py | 41 +++++++++++++++++++++++++++-------------- 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index 69a205378..fffbd3443 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -20,7 +20,7 @@ Session Handling for SQLAlchemy backend """ from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from nova import flags @@ -36,7 +36,8 @@ def get_session(autocommit=True, expire_on_commit=False): if not _MAKER: if not _ENGINE: _ENGINE = create_engine(FLAGS.sql_connection, echo=False) - _MAKER = sessionmaker(bind=_ENGINE, - autocommit=autocommit, - expire_on_commit=expire_on_commit) - return _MAKER() + _MAKER = scoped_session(sessionmaker(bind=_ENGINE, + autocommit=autocommit, + expire_on_commit=expire_on_commit)) + session = _MAKER() + return session diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 5dd1bd340..fc83a9d1c 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -326,7 +326,7 @@ class CloudController(object): security_group = db.security_group_get_by_name(context, context.project.id, group_name) - values = { 'parent_group_id' : security_group.id } + values = { 'parent_group' : security_group } if source_security_group_name: source_project_id = self._get_source_project_id(context, @@ -349,7 +349,6 @@ class CloudController(object): else: # If cidr based filtering, protocol and ports are mandatory if 'cidr' in values: - print values return None security_group_rule = db.security_group_rule_create(context, values) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 8cafa778e..d5a6d11f8 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -118,15 +118,15 @@ class NWFilterTestCase(test.TrialTestCase): self.assertEqual(len(rules), 1) # It's supposed to allow inbound traffic. - self.assertEqual(rules[0].getAttribute('action'), 'allow') + self.assertEqual(rules[0].getAttribute('action'), 'accept') self.assertEqual(rules[0].getAttribute('direction'), 'in') # Must be lower priority than the base filter (which blocks everything) self.assertTrue(int(rules[0].getAttribute('priority')) < 1000) - ip_conditions = rules[0].getElementsByTagName('ip') + ip_conditions = rules[0].getElementsByTagName('tcp') self.assertEqual(len(ip_conditions), 1) - self.assertEqual(ip_conditions[0].getAttribute('protocol'), 'tcp') + self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0') self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 00a80989f..aaa2c69b6 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -290,7 +290,6 @@ class LibvirtConnection(object): address = db.instance_get_fixed_address(None, inst['id']) with open(FLAGS.injected_network_template) as f: net = f.read() % {'address': address, - 'network': network_ref['network'], 'netmask': network_ref['netmask'], 'gateway': network_ref['gateway'], 'broadcast': network_ref['broadcast'], @@ -448,7 +447,7 @@ class LibvirtConnection(object): def refresh_security_group(self, security_group_id): - fw = self.NWFilterFirewall(self._conn) + fw = NWFilterFirewall(self._conn) fw.ensure_security_group_filter(security_group_id, override=True) @@ -541,19 +540,26 @@ class NWFilterFirewall(object): return 'nova-secgroup-%d' % (security_group_id,) + # TODO(soren): Should override be the default (and should it even + # be optional? We save a bit of processing time in + # libvirt by only defining this conditionally, but + # we still have to go and ask libvirt if the group + # is already defined, and there's the off chance of + # of inconsitencies having snuck in which would get + # fixed by just redefining the filter. def define_filter(self, name, xml_generator, override=False): if not override: def _already_exists_check(filterlist, filter): return filter in filterlist - def _define_if_not_exists(exists, xml_generator): - if not exists: - xml = xml_generator() - return threads.deferToThread(self._conn.nwfilterDefineXML, xml) d = threads.deferToThread(self._conn.listNWFilters) d.addCallback(_already_exists_check, name) else: # Pretend we looked it up and it wasn't defined d = defer.succeed(False) + def _define_if_not_exists(exists, xml_generator): + if not exists: + xml = xml_generator() + return threads.deferToThread(self._conn.nwfilterDefineXML, xml) d.addCallback(_define_if_not_exists, xml_generator) return d @@ -573,13 +579,20 @@ class NWFilterFirewall(object): security_group = db.security_group_get({}, security_group_id) rule_xml = "" for rule in security_group.rules: - rule_xml += "" + rule_xml += "" if rule.cidr: - rule_xml += ("" + - "priority='900'\n") % \ - (rule.cidr, rule.protocol, - rule.from_port, rule.to_port) - rule_xml += "" - xml = '''%s''' % (security_group_id, rule_xml,) + rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr) + if rule.protocol in ['tcp', 'udp']: + rule_xml += "dstportstart='%s' dstportend='%s' " % \ + (rule.from_port, rule.to_port) + elif rule.protocol == 'icmp': + logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port)) + if rule.from_port != -1: + rule_xml += "type='%s' " % rule.from_port + if rule.to_port != -1: + rule_xml += "code='%s' " % rule.to_port + + rule_xml += '/>\n' + rule_xml += "\n" + xml = '''%s''' % (security_group_id, rule_xml,) return xml -- cgit From 65113c4aa92fa5e803bbe1ab56f7facf57753962 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 15:20:08 +0200 Subject: Make refresh_security_groups play well with inlineCallbacks. --- nova/compute/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index a00fd9baa..1f3a181ff 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -64,7 +64,7 @@ class ComputeManager(manager.Manager): @defer.inlineCallbacks @exception.wrap_exception def refresh_security_group(self, context, security_group_id, **_kwargs): - self.driver.refresh_security_group(security_group_id) + yield self.driver.refresh_security_group(security_group_id) @defer.inlineCallbacks @exception.wrap_exception -- cgit From 85dbf6162d7b22991389db397f9aa1871464737f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 15:22:56 +0200 Subject: Cast process input to a str. It must not be unicode, but stuff that comes out of the database might very well be unicode, so using such a value in a template makes the whole thing unicode. --- nova/process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/process.py b/nova/process.py index 74725c157..bda8147d5 100644 --- a/nova/process.py +++ b/nova/process.py @@ -113,7 +113,7 @@ class BackRelayWithInput(protocol.ProcessProtocol): if self.started_deferred: self.started_deferred.callback(self) if self.process_input: - self.transport.write(self.process_input) + self.transport.write(str(self.process_input)) self.transport.closeStdin() def get_process_output(executable, args=None, env=None, path=None, -- cgit From b6932a9553e45c122af8a71f6300ac62381efb94 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 15:23:29 +0200 Subject: Network model has network_str attribute. --- nova/network/manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nova/network/manager.py b/nova/network/manager.py index 83de5d023..bca3217f0 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -193,7 +193,6 @@ class FlatManager(NetworkManager): # in the datastore? net = {} net['injected'] = True - net['network_str'] = FLAGS.flat_network_network net['netmask'] = FLAGS.flat_network_netmask net['bridge'] = FLAGS.flat_network_bridge net['gateway'] = FLAGS.flat_network_gateway -- cgit From 587b21cc00919cc29e2f815fc9de3e3ad6e6fa30 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 14 Sep 2010 15:23:58 +0200 Subject: Leave out the network setting from the interfaces template. It does not get passed anymore. --- nova/virt/interfaces.template | 1 - 1 file changed, 1 deletion(-) diff --git a/nova/virt/interfaces.template b/nova/virt/interfaces.template index 11df301f6..87b92b84a 100644 --- a/nova/virt/interfaces.template +++ b/nova/virt/interfaces.template @@ -10,7 +10,6 @@ auto eth0 iface eth0 inet static address %(address)s netmask %(netmask)s - network %(network)s broadcast %(broadcast)s gateway %(gateway)s dns-nameservers %(dns)s -- cgit From faebe1ecd4aec4e2050a12f191266beadc468134 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 15 Sep 2010 12:01:08 +0200 Subject: Clean up use of objects coming out of the ORM. --- nova/auth/manager.py | 12 ++++++------ nova/endpoint/api.py | 1 - nova/endpoint/cloud.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 5529515a6..c4f964b80 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -531,11 +531,11 @@ class AuthManager(object): except: drv.delete_project(project.id) raise - - values = {'name': 'default', - 'description': 'default', - 'user_id': User.safe_id(manager_user), - 'project_id': project.id} + + values = { 'name' : 'default', + 'description' : 'default', + 'user_id' : User.safe_id(manager_user), + 'project_id' : project['id'] } db.security_group_create({}, values) return project @@ -597,7 +597,7 @@ class AuthManager(object): groups = db.security_group_get_by_project(context={}, project_id=project_id) for group in groups: - db.security_group_destroy({}, group.id) + db.security_group_destroy({}, group['id']) except: logging.exception('Could not destroy security groups for %s', project) diff --git a/nova/endpoint/api.py b/nova/endpoint/api.py index 1f37aeb02..40be00bb7 100755 --- a/nova/endpoint/api.py +++ b/nova/endpoint/api.py @@ -135,7 +135,6 @@ class APIRequest(object): response = xml.toxml() xml.unlink() -# print response _log.debug(response) return response diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index fc83a9d1c..0289de285 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -93,7 +93,7 @@ class CloudController(object): result[instance['key_name']] = [line] return result - def _refresh_security_group(self, security_group): + def _trigger_refresh_security_group(self, security_group): nodes = set([instance.host for instance in security_group.instances]) for node in nodes: rpc.call('%s.%s' % (FLAGS.compute_topic, node), @@ -227,7 +227,7 @@ class CloudController(object): groups = db.security_group_get_all(context) else: groups = db.security_group_get_by_project(context, - context.project.id) + context.project['id']) groups = [self._format_security_group(context, g) for g in groups] if not group_name is None: groups = [g for g in groups if g.name in group_name] @@ -265,7 +265,7 @@ class CloudController(object): source_security_group_name=None, source_security_group_owner_id=None): security_group = db.security_group_get_by_name(context, - context.project.id, + context.project['id'], group_name) criteria = {} @@ -301,12 +301,12 @@ class CloudController(object): # If we make it here, we have a match db.security_group_rule_destroy(context, rule.id) - self._refresh_security_group(security_group) + self._trigger_refresh_security_group(security_group) return True # TODO(soren): Lots and lots of input validation. We're accepting - # strings here (such as ipProtocol), which is put into + # strings here (such as ipProtocol), which are put into # filter rules verbatim. # TODO(soren): Dupe detection. Adding the same rule twice actually # adds the same rule twice to the rule set, which is @@ -324,7 +324,7 @@ class CloudController(object): source_security_group_name=None, source_security_group_owner_id=None): security_group = db.security_group_get_by_name(context, - context.project.id, + context.project['id'], group_name) values = { 'parent_group' : security_group } @@ -366,11 +366,11 @@ class CloudController(object): # Since we're looking up by project name, the user name is not # used here. It's only read for EC2 API compatibility. if len(source_parts) == 2: - source_project_id = parts[1] + source_project_id = source_parts[1] else: - source_project_id = parts[0] + source_project_id = source_parts[0] else: - source_project_id = context.project.id + source_project_id = context.project['id'] return source_project_id -- cgit From 62871e83ba9b7bd8b17a7c457d8af7feb18853ea Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 15 Sep 2010 12:05:37 +0200 Subject: More ORM object cleanup. --- nova/endpoint/cloud.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 0289de285..32732e9d5 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -700,9 +700,9 @@ class CloudController(object): security_groups = [] for security_group_name in security_group_arg: group = db.security_group_get_by_project(context, - context.project.id, + context.project['id'], security_group_name) - security_groups.append(group) + security_groups.append(group['id']) reservation_id = utils.generate_uid('r') base_options = {} @@ -716,11 +716,14 @@ class CloudController(object): base_options['project_id'] = context.project.id base_options['user_data'] = kwargs.get('user_data', '') base_options['instance_type'] = kwargs.get('instance_type', 'm1.small') - base_options['security_groups'] = security_groups for num in range(int(kwargs['max_count'])): inst_id = db.instance_create(context, base_options) + for security_group_id in security_groups: + db.instance_add_security_group(context, inst_id, + security_group_id) + inst = {} inst['mac_address'] = utils.generate_mac() inst['launch_index'] = num -- cgit From 0cb25fddcad2626ce617f5c2472cea1c02f1d961 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 15 Sep 2010 13:56:17 +0200 Subject: Roll back my slightly over-zealous clean up work. --- nova/auth/manager.py | 4 ++-- nova/endpoint/cloud.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/nova/auth/manager.py b/nova/auth/manager.py index c4f964b80..323c48dd0 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -535,7 +535,7 @@ class AuthManager(object): values = { 'name' : 'default', 'description' : 'default', 'user_id' : User.safe_id(manager_user), - 'project_id' : project['id'] } + 'project_id' : project.id } db.security_group_create({}, values) return project @@ -601,7 +601,7 @@ class AuthManager(object): except: logging.exception('Could not destroy security groups for %s', project) - + with self.driver() as drv: drv.delete_project(Project.safe_id(project)) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index 32732e9d5..ab3f5b2d9 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -227,7 +227,7 @@ class CloudController(object): groups = db.security_group_get_all(context) else: groups = db.security_group_get_by_project(context, - context.project['id']) + context.project.id) groups = [self._format_security_group(context, g) for g in groups] if not group_name is None: groups = [g for g in groups if g.name in group_name] @@ -265,7 +265,7 @@ class CloudController(object): source_security_group_name=None, source_security_group_owner_id=None): security_group = db.security_group_get_by_name(context, - context.project['id'], + context.project.id, group_name) criteria = {} @@ -324,7 +324,7 @@ class CloudController(object): source_security_group_name=None, source_security_group_owner_id=None): security_group = db.security_group_get_by_name(context, - context.project['id'], + context.project.id, group_name) values = { 'parent_group' : security_group } @@ -370,7 +370,7 @@ class CloudController(object): else: source_project_id = source_parts[0] else: - source_project_id = context.project['id'] + source_project_id = context.project.id return source_project_id @@ -700,7 +700,7 @@ class CloudController(object): security_groups = [] for security_group_name in security_group_arg: group = db.security_group_get_by_project(context, - context.project['id'], + context.project.id, security_group_name) security_groups.append(group['id']) -- cgit From 9196b74080d5effd8dcfacce9de7d2dd37fcba1b Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 15 Sep 2010 14:04:07 +0200 Subject: Clean up use of ORM to remove the need for scoped_session. --- nova/db/api.py | 6 +++--- nova/db/sqlalchemy/api.py | 9 +++++---- nova/db/sqlalchemy/session.py | 4 ++-- nova/endpoint/cloud.py | 4 ++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 1d10b1987..fa937dab2 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -463,12 +463,12 @@ def security_group_get(context, security_group_id): def security_group_get_by_name(context, project_id, group_name): """Returns a security group with the specified name from a project""" - return IMPL.securitygroup_get_by_name(context, project_id, group_name) + return IMPL.security_group_get_by_name(context, project_id, group_name) def security_group_get_by_project(context, project_id): """Get all security groups belonging to a project""" - return IMPL.securitygroup_get_by_project(context, project_id) + return IMPL.security_group_get_by_project(context, project_id) def security_group_get_by_instance(context, instance_id): @@ -478,7 +478,7 @@ def security_group_get_by_instance(context, instance_id): def securitygroup_exists(context, project_id, group_name): """Indicates if a group name exists in a project""" - return IMPL.securitygroup_exists(context, project_id, group_name) + return IMPL.security_group_exists(context, project_id, group_name) def security_group_create(context, values): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 11779e30c..038bb7f23 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -646,10 +646,11 @@ def security_group_get(_context, security_group_id): ).get(security_group_id) -def securitygroup_get_by_name(context, project_id, group_name): +def security_group_get_by_name(context, project_id, group_name): session = get_session() group_ref = session.query(models.SecurityGroup ).options(eagerload('rules') + ).options(eagerload('instances') ).filter_by(project_id=project_id ).filter_by(name=group_name ).filter_by(deleted=False @@ -662,7 +663,7 @@ def securitygroup_get_by_name(context, project_id, group_name): return group_ref -def securitygroup_get_by_project(_context, project_id): +def security_group_get_by_project(_context, project_id): session = get_session() return session.query(models.SecurityGroup ).options(eagerload('rules') @@ -681,9 +682,9 @@ def security_group_get_by_instance(_context, instance_id): ).all() -def securitygroup_exists(_context, project_id, group_name): +def security_group_exists(_context, project_id, group_name): try: - group = securitygroup_get_by_name(_context, project_id, group_name) + group = security_group_get_by_name(_context, project_id, group_name) return group != None except exception.NotFound: return False diff --git a/nova/db/sqlalchemy/session.py b/nova/db/sqlalchemy/session.py index fffbd3443..826754f6a 100644 --- a/nova/db/sqlalchemy/session.py +++ b/nova/db/sqlalchemy/session.py @@ -20,7 +20,7 @@ Session Handling for SQLAlchemy backend """ from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, scoped_session +from sqlalchemy.orm import sessionmaker from nova import flags @@ -36,7 +36,7 @@ def get_session(autocommit=True, expire_on_commit=False): if not _MAKER: if not _ENGINE: _ENGINE = create_engine(FLAGS.sql_connection, echo=False) - _MAKER = scoped_session(sessionmaker(bind=_ENGINE, + _MAKER = (sessionmaker(bind=_ENGINE, autocommit=autocommit, expire_on_commit=expire_on_commit)) session = _MAKER() diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index ab3f5b2d9..d2606e3a7 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -326,7 +326,7 @@ class CloudController(object): security_group = db.security_group_get_by_name(context, context.project.id, group_name) - values = { 'parent_group' : security_group } + values = { 'parent_group_id' : security_group.id } if source_security_group_name: source_project_id = self._get_source_project_id(context, @@ -353,7 +353,7 @@ class CloudController(object): security_group_rule = db.security_group_rule_create(context, values) - self._refresh_security_group(security_group) + self._trigger_refresh_security_group(security_group) return True -- cgit From 28336ed41e0d44d7600588a6014f6253e4b87a42 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 15 Sep 2010 14:27:34 +0200 Subject: Address a couple of the TODO's: We now have half-decent input validation for AuthorizeSecurityGroupIngress and RevokeDitto. --- nova/endpoint/cloud.py | 95 +++++++++++++++++++++++--------------------------- nova/exception.py | 3 ++ 2 files changed, 46 insertions(+), 52 deletions(-) diff --git a/nova/endpoint/cloud.py b/nova/endpoint/cloud.py index d2606e3a7..d1ccf24ff 100644 --- a/nova/endpoint/cloud.py +++ b/nova/endpoint/cloud.py @@ -27,6 +27,8 @@ import logging import os import time +import IPy + from twisted.internet import defer from nova import db @@ -43,6 +45,7 @@ from nova.endpoint import images FLAGS = flags.FLAGS flags.DECLARE('storage_availability_zone', 'nova.volume.manager') +InvalidInputException = exception.InvalidInputException def _gen_key(user_id, key_name): """ Tuck this into AuthManager """ @@ -257,18 +260,14 @@ class CloudController(object): return g - @rbac.allow('netadmin') - def revoke_security_group_ingress(self, context, group_name, - to_port=None, from_port=None, - ip_protocol=None, cidr_ip=None, - user_id=None, - source_security_group_name=None, - source_security_group_owner_id=None): - security_group = db.security_group_get_by_name(context, - context.project.id, - group_name) + def _authorize_revoke_rule_args_to_dict(self, context, + to_port=None, from_port=None, + ip_protocol=None, cidr_ip=None, + user_id=None, + source_security_group_name=None, + source_security_group_owner_id=None): - criteria = {} + values = {} if source_security_group_name: source_project_id = self._get_source_project_id(context, @@ -278,21 +277,43 @@ class CloudController(object): db.security_group_get_by_name(context, source_project_id, source_security_group_name) - - criteria['group_id'] = source_security_group + values['group_id'] = source_security_group.id elif cidr_ip: - criteria['cidr'] = cidr_ip + # If this fails, it throws an exception. This is what we want. + IPy.IP(cidr_ip) + values['cidr'] = cidr_ip else: return { 'return': False } if ip_protocol and from_port and to_port: - criteria['protocol'] = ip_protocol - criteria['from_port'] = from_port - criteria['to_port'] = to_port + from_port = int(from_port) + to_port = int(to_port) + ip_protocol = str(ip_protocol) + + if ip_protocol.upper() not in ['TCP','UDP','ICMP']: + raise InvalidInputException('%s is not a valid ipProtocol' % + (ip_protocol,)) + if ((min(from_port, to_port) < -1) or + (max(from_port, to_port) > 65535)): + raise InvalidInputException('Invalid port range') + + values['protocol'] = ip_protocol + values['from_port'] = from_port + values['to_port'] = to_port else: # If cidr based filtering, protocol and ports are mandatory - if 'cidr' in criteria: - return { 'return': False } + if 'cidr' in values: + return None + + return values + + @rbac.allow('netadmin') + def revoke_security_group_ingress(self, context, group_name, **kwargs): + security_group = db.security_group_get_by_name(context, + context.project.id, + group_name) + + criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs) for rule in security_group.rules: for (k,v) in criteria.iteritems(): @@ -305,9 +326,6 @@ class CloudController(object): return True - # TODO(soren): Lots and lots of input validation. We're accepting - # strings here (such as ipProtocol), which are put into - # filter rules verbatim. # TODO(soren): Dupe detection. Adding the same rule twice actually # adds the same rule twice to the rule set, which is # pointless. @@ -315,41 +333,14 @@ class CloudController(object): # Unfortunately, it seems Boto is using an old API # for these operations, so support for newer API versions # is sketchy. - # TODO(soren): De-duplicate the turning method arguments into dict stuff. - # revoke_security_group_ingress uses the exact same logic. @rbac.allow('netadmin') - def authorize_security_group_ingress(self, context, group_name, - to_port=None, from_port=None, - ip_protocol=None, cidr_ip=None, - source_security_group_name=None, - source_security_group_owner_id=None): + def authorize_security_group_ingress(self, context, group_name, **kwargs): security_group = db.security_group_get_by_name(context, context.project.id, group_name) - values = { 'parent_group_id' : security_group.id } - if source_security_group_name: - source_project_id = self._get_source_project_id(context, - source_security_group_owner_id) - - source_security_group = \ - db.security_group_get_by_name(context, - source_project_id, - source_security_group_name) - values['group_id'] = source_security_group.id - elif cidr_ip: - values['cidr'] = cidr_ip - else: - return { 'return': False } - - if ip_protocol and from_port and to_port: - values['protocol'] = ip_protocol - values['from_port'] = from_port - values['to_port'] = to_port - else: - # If cidr based filtering, protocol and ports are mandatory - if 'cidr' in values: - return None + values = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + values['parent_group_id'] = security_group.id security_group_rule = db.security_group_rule_create(context, values) diff --git a/nova/exception.py b/nova/exception.py index 29bcb17f8..43e5c36c6 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -57,6 +57,9 @@ class NotEmpty(Error): class Invalid(Error): pass +class InvalidInputException(Error): + pass + def wrap_exception(f): def _wrap(*args, **kw): -- cgit From 169ac33d89e0721c3e5229f2c58b799b64f1b51d Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 21 Sep 2010 16:32:07 -0700 Subject: typo in instance_get --- nova/db/sqlalchemy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 5ceb4c814..90c7938df 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -380,7 +380,7 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id): session = get_session() - result = session.query(models.FixedIp + result = session.query(models.Instance ).options(eagerload('security_groups') ).filter_by(instance_id=instance_id ).filter_by(deleted=_deleted(context) -- cgit From e78273f72640eb9cbd1797d8d66dc41dcb96bee0 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 21 Sep 2010 16:43:32 -0700 Subject: typo in instance_get --- nova/db/sqlalchemy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 90c7938df..420fd4a4a 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -382,7 +382,7 @@ def instance_get(context, instance_id): session = get_session() result = session.query(models.Instance ).options(eagerload('security_groups') - ).filter_by(instance_id=instance_id + ).filter_by(id=instance_id ).filter_by(deleted=_deleted(context) ).first() if not result: -- cgit From fed57c47da49a0457fce8fec3b59c9142e62785e Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Thu, 23 Sep 2010 13:59:33 +0200 Subject: Address Vishy's comments. --- nova/api/ec2/cloud.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 046aee14a..0f0aa327c 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -725,9 +725,9 @@ class CloudController(object): security_groups = [] for security_group_name in security_group_arg: - group = db.security_group_get_by_project(context, - context.project.id, - security_group_name) + group = db.security_group_get_by_name(context, + context.project.id, + security_group_name) security_groups.append(group['id']) reservation_id = utils.generate_uid('r') @@ -744,6 +744,7 @@ class CloudController(object): base_options['user_data'] = kwargs.get('user_data', '') type_data = INSTANCE_TYPES[instance_type] + base_options['instance_type'] = instance_type base_options['memory_mb'] = type_data['memory_mb'] base_options['vcpus'] = type_data['vcpus'] base_options['local_gb'] = type_data['local_gb'] -- cgit From ab2bed9ed60c5333a0f9ba3e679df9893781b72f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 10:39:52 +0200 Subject: Apply IP configuration to bridge regardless of whether it existed before. The fixes a race condition on hosts running both compute and network where, if compute got there first, it would set up the bridge, but not do IP configuration (because that's meant to happen on the network host), and when network came around, it would see the interface already there and not configure it further. --- nova/network/linux_net.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 41aeb5da7..9d5bd8495 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -118,15 +118,16 @@ def ensure_bridge(bridge, interface, net_attrs=None): # _execute("sudo brctl setageing %s 10" % bridge) _execute("sudo brctl stp %s off" % bridge) _execute("sudo brctl addif %s %s" % (bridge, interface)) - if net_attrs: - _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ - (bridge, - net_attrs['gateway'], - net_attrs['broadcast'], - net_attrs['netmask'])) - _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge) - else: - _execute("sudo ifconfig %s up" % bridge) + + if net_attrs: + _execute("sudo ifconfig %s %s broadcast %s netmask %s up" % \ + (bridge, + net_attrs['gateway'], + net_attrs['broadcast'], + net_attrs['netmask'])) + _confirm_rule("FORWARD --in-interface %s -j ACCEPT" % bridge) + else: + _execute("sudo ifconfig %s up" % bridge) def get_dhcp_hosts(context, network_id): -- cgit From b4dbc4efa576af61ddc26d1c277237ad4bcdfcfa Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:07:55 +0200 Subject: Add db api methods for retrieving the networks for which a host is the designated network host. --- nova/db/api.py | 12 ++++++++++++ nova/db/sqlalchemy/api.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/nova/db/api.py b/nova/db/api.py index c1cb1953a..4657408db 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -554,3 +554,15 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + + +################### + + +def host_get_networks(context, host): + """Return all networks for which the given host is the designated + network host + """ + return IMPL.host_get_networks(context, host) + + diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2b0dd6ea6..6e6b0e3fc 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -848,3 +848,15 @@ def volume_update(_context, volume_id, values): for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) + + +################### + + +def host_get_networks(context, host): + session = get_session() + with session.begin(): + return session.query(models.Network + ).filter_by(deleted=False + ).filter_by(host=host + ).all() -- cgit From e70948dbec0b21664739b2b7cdb1cc3da92bd01b Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:08:40 +0200 Subject: Set up network at manager instantiation time to ensure we're ready to handle the networks we're already supposed to handle. --- nova/network/manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nova/network/manager.py b/nova/network/manager.py index 191c1d364..c17823f1e 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -80,6 +80,10 @@ class NetworkManager(manager.Manager): network_driver = FLAGS.network_driver self.driver = utils.import_object(network_driver) super(NetworkManager, self).__init__(*args, **kwargs) + # Set up networking for the projects for which we're already + # the designated network host. + for network in self.db.host_get_networks(None, host=kwargs['host']): + self._on_set_network_host(None, network['id']) def set_network_host(self, context, project_id): """Safely sets the host of the projects network""" -- cgit From 47cccfc21dfd4c1acf74b6d84ced8abba8c40e76 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 12:14:20 +0200 Subject: Ensure dnsmasq can read updates to dnsmasq conffile. --- nova/network/linux_net.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 9d5bd8495..7d708968c 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -150,9 +150,14 @@ def update_dhcp(context, network_id): signal causing it to reload, otherwise spawn a new instance """ network_ref = db.network_get(context, network_id) - with open(_dhcp_file(network_ref['vlan'], 'conf'), 'w') as f: + + conffile = _dhcp_file(network_ref['vlan'], 'conf') + with open(conffile, 'w') as f: f.write(get_dhcp_hosts(context, network_id)) + # Make sure dnsmasq can actually read it (it setuid()s to "nobody") + os.chmod(conffile, 0644) + pid = _dnsmasq_pid_for(network_ref['vlan']) # if dnsmasq is already running, then tell it to reload -- cgit From 928df580e5973bc1fd3871a0aa31886302bb9268 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 13:03:29 +0200 Subject: Add a flag the specifies where to find nova-dhcpbridge. --- nova/network/linux_net.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 7d708968c..bfa73dca0 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -28,6 +28,11 @@ from nova import flags from nova import utils +def _bin_file(script): + """Return the absolute path to scipt in the bin directory""" + return os.path.abspath(os.path.join(__file__, "../../../bin", script)) + + FLAGS = flags.FLAGS flags.DEFINE_string('dhcpbridge_flagfile', '/etc/nova/nova-dhcpbridge.conf', @@ -39,6 +44,8 @@ flags.DEFINE_string('public_interface', 'vlan1', 'Interface for public IP addresses') flags.DEFINE_string('bridge_dev', 'eth0', 'network device for bridges') +flags.DEFINE_string('dhcpbridge', _bin_file('nova-dhcpbridge'), + 'location of nova-dhcpbridge') DEFAULT_PORTS = [("tcp", 80), ("tcp", 22), ("udp", 1194), ("tcp", 443)] @@ -222,7 +229,7 @@ def _dnsmasq_cmd(net): ' --except-interface=lo', ' --dhcp-range=%s,static,120s' % net['dhcp_start'], ' --dhcp-hostsfile=%s' % _dhcp_file(net['vlan'], 'conf'), - ' --dhcp-script=%s' % _bin_file('nova-dhcpbridge'), + ' --dhcp-script=%s' % FLAGS.dhcpbridge, ' --leasefile-ro'] return ''.join(cmd) @@ -244,11 +251,6 @@ def _dhcp_file(vlan, kind): return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) -def _bin_file(script): - """Return the absolute path to scipt in the bin directory""" - return os.path.abspath(os.path.join(__file__, "../../../bin", script)) - - def _dnsmasq_pid_for(vlan): """Returns he pid for prior dnsmasq instance for a vlan -- cgit From 9dbdca83a8233110e94356415629ab9589b580d5 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 13:13:29 +0200 Subject: Allow DHCP requests through, pass the IP of the gateway as the dhcp server. --- nova/virt/libvirt.qemu.xml.template | 1 + nova/virt/libvirt.uml.xml.template | 1 + nova/virt/libvirt_conn.py | 6 +++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/nova/virt/libvirt.qemu.xml.template b/nova/virt/libvirt.qemu.xml.template index d02aa9114..2538b1ade 100644 --- a/nova/virt/libvirt.qemu.xml.template +++ b/nova/virt/libvirt.qemu.xml.template @@ -22,6 +22,7 @@ + diff --git a/nova/virt/libvirt.uml.xml.template b/nova/virt/libvirt.uml.xml.template index bf3f2f86a..bb8b47911 100644 --- a/nova/virt/libvirt.uml.xml.template +++ b/nova/virt/libvirt.uml.xml.template @@ -16,6 +16,7 @@ + diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 4c4c7980b..93f6977d4 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -319,6 +319,8 @@ class LibvirtConnection(object): # FIXME(vish): stick this in db instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] ip_address = db.instance_get_fixed_address({}, instance['id']) + # Assume that the gateway also acts as the dhcp server. + dhcp_server = network['gateway'] xml_info = {'type': FLAGS.libvirt_type, 'name': instance['name'], 'basepath': os.path.join(FLAGS.instances_path, @@ -327,7 +329,8 @@ class LibvirtConnection(object): 'vcpus': instance_type['vcpus'], 'bridge_name': network['bridge'], 'mac_address': instance['mac_address'], - 'ip_address': ip_address } + 'ip_address': ip_address, + 'dhcp_server': dhcp_server } libvirt_xml = self.libvirt_xml % xml_info logging.debug('instance %s: finished toXML method', instance['name']) @@ -498,6 +501,7 @@ class NWFilterFirewall(object): return ''' 26717364-50cf-42d1-8185-29bf893ab110 + -- cgit From 04fa25e63bf37222d2b1cf88837f1c85cf944f54 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 13:23:39 +0200 Subject: Only call _on_set_network_host on nova-network hosts. --- nova/network/manager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nova/network/manager.py b/nova/network/manager.py index c17823f1e..2530f04b7 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -80,10 +80,13 @@ class NetworkManager(manager.Manager): network_driver = FLAGS.network_driver self.driver = utils.import_object(network_driver) super(NetworkManager, self).__init__(*args, **kwargs) - # Set up networking for the projects for which we're already - # the designated network host. - for network in self.db.host_get_networks(None, host=kwargs['host']): - self._on_set_network_host(None, network['id']) + # Host only gets passed if being instantiated as part of the network + # service. + if 'host' in kwargs: + # Set up networking for the projects for which we're already + # the designated network host. + for network in self.db.host_get_networks(None, host=kwargs['host']): + self._on_set_network_host(None, network['id']) def set_network_host(self, context, project_id): """Safely sets the host of the projects network""" -- cgit From e6ada2403cb83070c270a96c7e371513d21e27f4 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 15:13:11 +0200 Subject: If an instance never got scheduled for whatever reason, its host will turn up as None. Filter those out to make sure refresh works. --- nova/api/ec2/cloud.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 0f0aa327c..7330967fa 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -116,7 +116,8 @@ class CloudController(object): return result def _trigger_refresh_security_group(self, security_group): - nodes = set([instance.host for instance in security_group.instances]) + nodes = set([instance['host'] for instance in security_group.instances + if instance['host'] is not None]) for node in nodes: rpc.call('%s.%s' % (FLAGS.compute_topic, node), { "method": "refresh_security_group", -- cgit From 523f1c95ac12ed4782476c3273b337601ad8b6ae Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 21:49:24 +0200 Subject: If neither a security group nor a cidr has been passed, assume cidr=0.0.0.0/0 --- nova/api/ec2/cloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 7330967fa..4cf2666a5 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -301,7 +301,7 @@ class CloudController(object): IPy.IP(cidr_ip) values['cidr'] = cidr_ip else: - return { 'return': False } + values['cidr'] = '0.0.0.0/0' if ip_protocol and from_port and to_port: from_port = int(from_port) -- cgit From ab31fa628f4d9148aae8d42bbb41d721716c18e3 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 21:49:53 +0200 Subject: Clean up nwfilter code. Move our filters into the ipv4 chain. --- nova/virt/libvirt_conn.py | 99 ++++++++++++++++++----------------------------- 1 file changed, 38 insertions(+), 61 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 93f6977d4..558854c38 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -497,20 +497,36 @@ class NWFilterFirewall(object): self._conn = get_connection - def nova_base_filter(self): - return ''' - 26717364-50cf-42d1-8185-29bf893ab110 - - - - - - - - -''' + nova_base_filter = ''' + 26717364-50cf-42d1-8185-29bf893ab110 + + + + + + + ''' + + nova_base_ipv4_filter = ''' + + ''' + + + nova_base_ipv6_filter = ''' + + ''' + + + def _define_filter(self, xml): + if callable(xml): + xml = xml() + d = threads.deferToThread(self._conn.nwfilterDefineXML, xml) + return d + @defer.inlineCallbacks def setup_nwfilters_for_instance(self, instance): """ Creates an NWFilter for the given instance. In the process, @@ -518,63 +534,24 @@ class NWFilterFirewall(object): the base filter are all in place. """ - d = self.ensure_base_filter() + yield self._define_filter(self.nova_base_ipv4_filter) + yield self._define_filter(self.nova_base_ipv6_filter) + yield self._define_filter(self.nova_base_filter) nwfilter_xml = ("\n" + - " \n" + " \n" ) % instance['name'] for security_group in instance.security_groups: - d.addCallback(lambda _:self.ensure_security_group_filter(security_group.id)) + yield self._define_filter( + self.security_group_to_nwfilter_xml(security_group['id'])) nwfilter_xml += (" \n" ) % security_group.id nwfilter_xml += "" - d.addCallback(lambda _: threads.deferToThread( - self._conn.nwfilterDefineXML, - nwfilter_xml)) - return d - - - def _nwfilter_name_for_security_group(self, security_group_id): - return 'nova-secgroup-%d' % (security_group_id,) - - - # TODO(soren): Should override be the default (and should it even - # be optional? We save a bit of processing time in - # libvirt by only defining this conditionally, but - # we still have to go and ask libvirt if the group - # is already defined, and there's the off chance of - # of inconsitencies having snuck in which would get - # fixed by just redefining the filter. - def define_filter(self, name, xml_generator, override=False): - if not override: - def _already_exists_check(filterlist, filter): - return filter in filterlist - d = threads.deferToThread(self._conn.listNWFilters) - d.addCallback(_already_exists_check, name) - else: - # Pretend we looked it up and it wasn't defined - d = defer.succeed(False) - def _define_if_not_exists(exists, xml_generator): - if not exists: - xml = xml_generator() - return threads.deferToThread(self._conn.nwfilterDefineXML, xml) - d.addCallback(_define_if_not_exists, xml_generator) - return d - - - def ensure_base_filter(self): - return self.define_filter('nova-base-filter', self.nova_base_filter) - - - def ensure_security_group_filter(self, security_group_id, override=False): - return self.define_filter( - self._nwfilter_name_for_security_group(security_group_id), - lambda:self.security_group_to_nwfilter_xml(security_group_id), - override=override) - + yield self._define_filter(nwfilter_xml) + return def security_group_to_nwfilter_xml(self, security_group_id): security_group = db.security_group_get({}, security_group_id) @@ -593,7 +570,7 @@ class NWFilterFirewall(object): if rule.to_port != -1: rule_xml += "code='%s' " % rule.to_port - rule_xml += '/>\n' + rule_xml += '/>\n' rule_xml += "\n" - xml = '''%s''' % (security_group_id, rule_xml,) + xml = '''%s''' % (security_group_id, rule_xml,) return xml -- cgit From e705b666679ecccfc3e91c8029f2c646849509ee Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 21:57:13 +0200 Subject: Recreate ensure_security_group_filter. Needed for refresh. --- nova/virt/libvirt_conn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 558854c38..a7370e036 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -448,7 +448,7 @@ class LibvirtConnection(object): def refresh_security_group(self, security_group_id): fw = NWFilterFirewall(self._conn) - fw.ensure_security_group_filter(security_group_id, override=True) + fw.ensure_security_group_filter(security_group_id) class NWFilterFirewall(object): @@ -543,16 +543,20 @@ class NWFilterFirewall(object): ) % instance['name'] for security_group in instance.security_groups: - yield self._define_filter( - self.security_group_to_nwfilter_xml(security_group['id'])) + yield self.ensure_security_group_filter(security_group['id']) nwfilter_xml += (" \n" - ) % security_group.id + ) % security_group['id'] nwfilter_xml += "" yield self._define_filter(nwfilter_xml) return + def ensure_security_group_filter(self, security_group_id): + return self._define_filter( + self.security_group_to_nwfilter_xml(security_group_id)) + + def security_group_to_nwfilter_xml(self, security_group_id): security_group = db.security_group_get({}, security_group_id) rule_xml = "" -- cgit From 9140cd991e5507f65ff1d6a608bd8fd4c9956dbf Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 27 Sep 2010 22:00:17 +0200 Subject: Set priority of security group rules to 300 to make sure they override the defaults. --- nova/virt/libvirt_conn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index a7370e036..d90853084 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -561,7 +561,7 @@ class NWFilterFirewall(object): security_group = db.security_group_get({}, security_group_id) rule_xml = "" for rule in security_group.rules: - rule_xml += "" + rule_xml += "" if rule.cidr: rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr) if rule.protocol in ['tcp', 'udp']: -- cgit From 574aa4bb03c6e79c204d73a8f2a146460cbdb848 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 00:21:36 +0200 Subject: This is getting ridiculous. --- nova/virt/libvirt_conn.py | 50 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index d90853084..854fa6761 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -503,20 +503,49 @@ class NWFilterFirewall(object): + ''' - nova_base_ipv4_filter = ''' - - ''' - - - nova_base_ipv6_filter = ''' - - ''' + nova_dhcp_filter = ''' + 891e4787-e5c0-d59b-cbd6-41bc3c6b36fc + + + + + + + ''' + + def nova_base_ipv4_filter(self): + retval = "" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action in [('out','accept'), + ('in','drop')]: + retval += """ + <%s /> + """ % (action, direction, protocol) + retval += '' + return retval + + + def nova_base_ipv6_filter(self): + retval = "" + for protocol in ['tcp', 'udp', 'icmp']: + for direction,action in [('out','accept'), + ('in','drop')]: + retval += """ + <%s-ipv6 /> + """ % (action, direction, protocol) + retval += '' + return retval def _define_filter(self, xml): @@ -536,6 +565,7 @@ class NWFilterFirewall(object): yield self._define_filter(self.nova_base_ipv4_filter) yield self._define_filter(self.nova_base_ipv6_filter) + yield self._define_filter(self.nova_dhcp_filter) yield self._define_filter(self.nova_base_filter) nwfilter_xml = ("\n" + -- cgit From 886534ba4d0281afc0d169546a8d55d3a5c8ece9 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 09:07:48 +0200 Subject: Make the incoming blocking rules take precedence over the output accept rules. --- nova/virt/libvirt_conn.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 854fa6761..40a921743 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -527,11 +527,11 @@ class NWFilterFirewall(object): def nova_base_ipv4_filter(self): retval = "" for protocol in ['tcp', 'udp', 'icmp']: - for direction,action in [('out','accept'), - ('in','drop')]: - retval += """ + for direction,action,priority in [('out','accept', 400), + ('in','drop', 399)]: + retval += """ <%s /> - """ % (action, direction, protocol) + """ % (action, direction, protocol, priority) retval += '' return retval @@ -539,11 +539,12 @@ class NWFilterFirewall(object): def nova_base_ipv6_filter(self): retval = "" for protocol in ['tcp', 'udp', 'icmp']: - for direction,action in [('out','accept'), - ('in','drop')]: - retval += """ + for direction,action,priority in [('out','accept',400), + ('in','drop',399)]: + retval += """ <%s-ipv6 /> - """ % (action, direction, protocol) + """ % (action, direction, + protocol, priority) retval += '' return retval -- cgit From 0dcf2e7e593cce4be1654fb4923ec4bb4524198f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 09:47:25 +0200 Subject: Make sure arguments to string format are in the correct order. --- nova/virt/libvirt_conn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 40a921743..c86f3ffb7 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -531,7 +531,8 @@ class NWFilterFirewall(object): ('in','drop', 399)]: retval += """ <%s /> - """ % (action, direction, protocol, priority) + """ % (action, direction, + priority, protocol) retval += '' return retval @@ -544,7 +545,7 @@ class NWFilterFirewall(object): retval += """ <%s-ipv6 /> """ % (action, direction, - protocol, priority) + priority, protocol) retval += '' return retval -- cgit From f09fa50fd31ded3f2f31e020b54f2d3d2b380a35 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 10:26:29 +0200 Subject: Improve unit tests for network filtering. It now tracks recursive filter dependencies, so even if we change the filter layering, it still correctly checks for the presence of the arp, mac, and ip spoofing filters. --- nova/tests/virt_unittest.py | 45 ++++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 985236edf..f9ff0f71f 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -153,37 +153,48 @@ class NWFilterTestCase(test.TrialTestCase): return db.security_group_get_by_name({}, 'fake', 'testgroup') def test_creates_base_rule_first(self): - self.defined_filters = [] - self.fake_libvirt_connection.listNWFilters = lambda:self.defined_filters - self.base_filter_defined = False - self.i = 0 + # These come pre-defined by libvirt + self.defined_filters = ['no-mac-spoofing', + 'no-ip-spoofing', + 'no-arp-spoofing', + 'allow-dhcp-server'] + + self.recursive_depends = {} + for f in self.defined_filters: + self.recursive_depends[f] = [] + def _filterDefineXMLMock(xml): dom = parseString(xml) name = dom.firstChild.getAttribute('name') - if self.i == 0: - self.assertEqual(dom.firstChild.getAttribute('name'), - 'nova-base-filter') - elif self.i == 1: - self.assertTrue(name.startswith('nova-secgroup-'), - 'unexpected name: %s' % name) - elif self.i == 2: - self.assertTrue(name.startswith('nova-instance-'), - 'unexpected name: %s' % name) + self.recursive_depends[name] = [] + for f in dom.getElementsByTagName('filterref'): + ref = f.getAttribute('filter') + self.assertTrue(ref in self.defined_filters, + ('%s referenced filter that does ' + + 'not yet exist: %s') % (name, ref)) + dependencies = [ref] + self.recursive_depends[ref] + self.recursive_depends[name] += dependencies self.defined_filters.append(name) - self.i += 1 return True def _ensure_all_called(_): - self.assertEqual(self.i, 3) + instance_filter = 'nova-instance-i-1' + secgroup_filter = 'nova-secgroup-%s' % self.security_group['id'] + for required in [secgroup_filter, 'allow-dhcp-server', + 'no-arp-spoofing', 'no-ip-spoofing', + 'no-mac-spoofing']: + self.assertTrue(required in self.recursive_depends[instance_filter], + "Instance's filter does not include %s" % required) self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock inst_id = db.instance_create({}, {'user_id': 'fake', 'project_id': 'fake'})['id'] - security_group = self.setup_and_return_security_group() - db.instance_add_security_group({}, inst_id, security_group.id) + self.security_group = self.setup_and_return_security_group() + + db.instance_add_security_group({}, inst_id, self.security_group.id) instance = db.instance_get({}, inst_id) d = self.fw.setup_nwfilters_for_instance(instance) -- cgit From afc782e0e80a71ac8d1eb2f1d70e67375ba62aca Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 10:59:55 +0200 Subject: Make sure we also start dnsmasq on startup if we're managing networks. --- nova/network/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nova/network/manager.py b/nova/network/manager.py index 2530f04b7..20d4fe0f7 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -358,6 +358,7 @@ class VlanManager(NetworkManager): self.driver.ensure_vlan_bridge(network_ref['vlan'], network_ref['bridge'], network_ref) + self.driver.update_dhcp(context, network_id) @property def _bottom_reserved_ips(self): -- cgit From 687a90d6a7ad947c4a5851b1766a19209bb5e46f Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 28 Sep 2010 11:09:40 +0200 Subject: Call out to 'sudo kill' instead of using os.kill. dnsmasq runs as root or nobody, nova may or may not be running as root, so os.kill won't work. --- nova/network/linux_net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index bfa73dca0..50d2831c3 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -172,7 +172,7 @@ def update_dhcp(context, network_id): # TODO(ja): use "/proc/%d/cmdline" % (pid) to determine if pid refers # correct dnsmasq process try: - os.kill(pid, signal.SIGHUP) + _execute('sudo kill -HUP %d' % pid) return except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Hupping dnsmasq threw %s", exc) @@ -240,7 +240,7 @@ def _stop_dnsmasq(network): if pid: try: - os.kill(pid, signal.SIGTERM) + _execute('sudo kill -TERM %d' % pid) except Exception as exc: # pylint: disable-msg=W0703 logging.debug("Killing dnsmasq threw %s", exc) -- cgit From 7c8c2f57c752cd8681eef073349f9bdcaa95c868 Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Tue, 28 Sep 2010 14:48:03 -0400 Subject: Adds --force option to run_tests.sh to clear virtualenv. Useful when dependencies change --- run_tests.sh | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/run_tests.sh b/run_tests.sh index 6ea40d95e..ec727d094 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -6,6 +6,7 @@ function usage { echo "" echo " -V, --virtual-env Always use virtualenv. Install automatically if not present" echo " -N, --no-virtual-env Don't use virtualenv. Run tests in local environment" + echo " -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added." echo " -h, --help Print this usage message" echo "" echo "Note: with no options specified, the script will try to run the tests in a virtual environment," @@ -14,20 +15,12 @@ function usage { exit } -function process_options { - array=$1 - elements=${#array[@]} - for (( x=0;x<$elements;x++)); do - process_option ${array[${x}]} - done -} - function process_option { - option=$1 - case $option in + case "$1" in -h|--help) usage;; -V|--virtual-env) let always_venv=1; let never_venv=0;; -N|--no-virtual-env) let always_venv=0; let never_venv=1;; + -f|--force) let force=1;; esac } @@ -35,9 +28,11 @@ venv=.nova-venv with_venv=tools/with_venv.sh always_venv=0 never_venv=0 -options=("$@") +force=0 -process_options $options +for arg in "$@"; do + process_option $arg +done if [ $never_venv -eq 1 ]; then # Just run the test suites in current environment @@ -45,6 +40,12 @@ if [ $never_venv -eq 1 ]; then exit fi +# Remove the virtual environment if --force used +if [ $force -eq 1 ]; then + echo "Cleaning virtualenv..." + rm -rf ${venv} +fi + if [ -e ${venv} ]; then ${with_venv} python run_tests.py $@ else -- cgit From 84fbad82d65b837d43f138e7a5acd24f182499e2 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 28 Sep 2010 12:09:17 -0700 Subject: move default group creation to api --- nova/api/ec2/cloud.py | 17 +++++++++++++++++ nova/auth/manager.py | 14 -------------- nova/db/api.py | 5 +++++ nova/db/sqlalchemy/api.py | 6 ++++++ nova/test.py | 2 ++ 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 4cf2666a5..d54562ec6 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -244,6 +244,7 @@ class CloudController(object): return True def describe_security_groups(self, context, group_name=None, **kwargs): + self._ensure_default_security_group(context) if context.user.is_admin(): groups = db.security_group_get_all(context) else: @@ -326,6 +327,7 @@ class CloudController(object): return values def revoke_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) security_group = db.security_group_get_by_name(context, context.project.id, group_name) @@ -351,6 +353,7 @@ class CloudController(object): # for these operations, so support for newer API versions # is sketchy. def authorize_security_group_ingress(self, context, group_name, **kwargs): + self._ensure_default_security_group(context) security_group = db.security_group_get_by_name(context, context.project.id, group_name) @@ -383,6 +386,7 @@ class CloudController(object): def create_security_group(self, context, group_name, group_description): + self._ensure_default_security_group(context) if db.securitygroup_exists(context, context.project.id, group_name): raise exception.ApiError('group %s already exists' % group_name) @@ -673,6 +677,18 @@ class CloudController(object): "project_id": context.project.id}}) return db.queue_get_for(context, FLAGS.network_topic, host) + def _ensure_default_security_group(self, context): + try: + db.security_group_get_by_name(context, + context.project.id, + 'default') + except exception.NotFound: + values = { 'name' : 'default', + 'description' : 'default', + 'user_id' : context.user.id, + 'project_id' : context.project.id } + group = db.security_group_create({}, values) + def run_instances(self, context, **kwargs): instance_type = kwargs.get('instance_type', 'm1.small') if instance_type not in INSTANCE_TYPES: @@ -725,6 +741,7 @@ class CloudController(object): security_group_arg = [security_group_arg] security_groups = [] + self._ensure_default_security_group(context) for security_group_name in security_group_arg: group = db.security_group_get_by_name(context, context.project.id, diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 7075070cf..bea4c7933 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -491,11 +491,6 @@ class AuthManager(object): drv.delete_project(project.id) raise - values = { 'name' : 'default', - 'description' : 'default', - 'user_id' : User.safe_id(manager_user), - 'project_id' : project.id } - db.security_group_create({}, values) return project def modify_project(self, project, manager_user=None, description=None): @@ -571,15 +566,6 @@ class AuthManager(object): except: logging.exception('Could not destroy network for %s', project) - try: - project_id = Project.safe_id(project) - groups = db.security_group_get_by_project(context={}, - project_id=project_id) - for group in groups: - db.security_group_destroy({}, group['id']) - except: - logging.exception('Could not destroy security groups for %s', - project) with self.driver() as drv: drv.delete_project(Project.safe_id(project)) diff --git a/nova/db/api.py b/nova/db/api.py index 602c3cf09..5e033b59d 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -604,6 +604,11 @@ def security_group_destroy(context, security_group_id): return IMPL.security_group_destroy(context, security_group_id) +def security_group_destroy_all(context): + """Deletes a security group""" + return IMPL.security_group_destroy_all(context) + + #################### diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d2847506e..07ea5d145 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -947,6 +947,12 @@ def security_group_destroy(_context, security_group_id): 'where group_id=:id', {'id': security_group_id}) +def security_group_destroy_all(_context): + session = get_session() + with session.begin(): + # TODO(vish): do we have to use sql here? + session.execute('update security_group set deleted=1') + session.execute('update security_group_rules set deleted=1') ################### diff --git a/nova/test.py b/nova/test.py index c392c8a84..5ed0c73d3 100644 --- a/nova/test.py +++ b/nova/test.py @@ -31,6 +31,7 @@ from tornado import ioloop from twisted.internet import defer from twisted.trial import unittest +from nova import db from nova import fakerabbit from nova import flags @@ -74,6 +75,7 @@ class TrialTestCase(unittest.TestCase): if FLAGS.fake_rabbit: fakerabbit.reset_all() + db.security_group_destroy_all(None) def flags(self, **kw): """Override flag variables for a test""" -- cgit From c53af2fc9d9803cebc7f4078b8f772476a09df81 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 28 Sep 2010 18:47:47 -0700 Subject: fix security group revoke --- nova/api/ec2/cloud.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 4cf2666a5..6eea95f84 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -295,7 +295,7 @@ class CloudController(object): db.security_group_get_by_name(context, source_project_id, source_security_group_name) - values['group_id'] = source_security_group.id + values['group_id'] = source_security_group['id'] elif cidr_ip: # If this fails, it throws an exception. This is what we want. IPy.IP(cidr_ip) @@ -331,17 +331,19 @@ class CloudController(object): group_name) criteria = self._authorize_revoke_rule_args_to_dict(context, **kwargs) + if criteria == None: + raise exception.ApiError("No rule for the specified parameters.") for rule in security_group.rules: + match = True for (k,v) in criteria.iteritems(): if getattr(rule, k, False) != v: - break - # If we make it here, we have a match - db.security_group_rule_destroy(context, rule.id) + match = False + if match: + db.security_group_rule_destroy(context, rule['id']) + self._trigger_refresh_security_group(security_group) - self._trigger_refresh_security_group(security_group) - - return True + raise exception.ApiError("No rule for the specified parameters.") # TODO(soren): Dupe detection. Adding the same rule twice actually # adds the same rule twice to the rule set, which is -- cgit From d9855ba51f53a27f5475b3c0b7f669b378ccc006 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 28 Sep 2010 20:53:24 -0700 Subject: fix eagerload to be joins that filter by deleted == False --- nova/db/sqlalchemy/api.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d2847506e..8b754c78e 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -25,7 +25,7 @@ from nova import flags from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ -from sqlalchemy.orm import eagerload, joinedload_all +from sqlalchemy.orm import contains_eager, eagerload, joinedload_all from sqlalchemy.sql import func FLAGS = flags.FLAGS @@ -711,7 +711,7 @@ def auth_create_token(_context, token): tk[k] = v tk.save() return tk - + ################### @@ -868,7 +868,9 @@ def volume_update(_context, volume_id, values): def security_group_get_all(_context): session = get_session() return session.query(models.SecurityGroup - ).options(eagerload('rules') + ).join(models.SecurityGroupIngressRule + ).options(contains_eager(models.SecurityGroup.rules) + ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(deleted=False ).all() @@ -876,7 +878,11 @@ def security_group_get_all(_context): def security_group_get(_context, security_group_id): session = get_session() result = session.query(models.SecurityGroup - ).options(eagerload('rules') + ).join(models.SecurityGroupIngressRule + ).options(contains_eager(models.SecurityGroup.rules) + ).filter(models.SecurityGroupIngressRule.deleted == False + ).filter_by(deleted=False + ).filter_by(id=security_group_id ).get(security_group_id) if not result: raise exception.NotFound("No secuity group with id %s" % @@ -887,8 +893,11 @@ def security_group_get(_context, security_group_id): def security_group_get_by_name(context, project_id, group_name): session = get_session() group_ref = session.query(models.SecurityGroup - ).options(eagerload('rules') - ).options(eagerload('instances') + ).join(models.SecurityGroupIngressRule + ).join(models.Instances + ).options(contains_eager(models.SecurityGroup.rules) + ).options(contains_eager(models.SecurityGroup.instances) + ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(project_id=project_id ).filter_by(name=group_name ).filter_by(deleted=False @@ -903,7 +912,9 @@ def security_group_get_by_name(context, project_id, group_name): def security_group_get_by_project(_context, project_id): session = get_session() return session.query(models.SecurityGroup - ).options(eagerload('rules') + ).join(models.SecurityGroupIngressRule + ).options(contains_eager(models.SecurityGroup.rules) + ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(project_id=project_id ).filter_by(deleted=False ).all() -- cgit From 3124cf70c6ab2bcab570f0ffbcbe31672a9556f8 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 28 Sep 2010 21:03:45 -0700 Subject: fix join and misnamed method --- nova/api/ec2/cloud.py | 2 +- nova/db/api.py | 2 +- nova/db/sqlalchemy/api.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 6eea95f84..a1a3960f6 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -385,7 +385,7 @@ class CloudController(object): def create_security_group(self, context, group_name, group_description): - if db.securitygroup_exists(context, context.project.id, group_name): + if db.security_group_exists(context, context.project.id, group_name): raise exception.ApiError('group %s already exists' % group_name) group = {'user_id' : context.user.id, diff --git a/nova/db/api.py b/nova/db/api.py index 602c3cf09..1e2738b99 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -589,7 +589,7 @@ def security_group_get_by_instance(context, instance_id): return IMPL.security_group_get_by_instance(context, instance_id) -def securitygroup_exists(context, project_id, group_name): +def security_group_exists(context, project_id, group_name): """Indicates if a group name exists in a project""" return IMPL.security_group_exists(context, project_id, group_name) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 8b754c78e..fcdf945eb 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -868,7 +868,7 @@ def volume_update(_context, volume_id, values): def security_group_get_all(_context): session = get_session() return session.query(models.SecurityGroup - ).join(models.SecurityGroupIngressRule + ).join(models.SecurityGroup.rules ).options(contains_eager(models.SecurityGroup.rules) ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(deleted=False @@ -878,7 +878,7 @@ def security_group_get_all(_context): def security_group_get(_context, security_group_id): session = get_session() result = session.query(models.SecurityGroup - ).join(models.SecurityGroupIngressRule + ).join(models.SecurityGroup.rules ).options(contains_eager(models.SecurityGroup.rules) ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(deleted=False @@ -893,7 +893,7 @@ def security_group_get(_context, security_group_id): def security_group_get_by_name(context, project_id, group_name): session = get_session() group_ref = session.query(models.SecurityGroup - ).join(models.SecurityGroupIngressRule + ).join(models.SecurityGroup.rules ).join(models.Instances ).options(contains_eager(models.SecurityGroup.rules) ).options(contains_eager(models.SecurityGroup.instances) @@ -912,7 +912,7 @@ def security_group_get_by_name(context, project_id, group_name): def security_group_get_by_project(_context, project_id): session = get_session() return session.query(models.SecurityGroup - ).join(models.SecurityGroupIngressRule + ).join(models.SecurityGroup.rules ).options(contains_eager(models.SecurityGroup.rules) ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(project_id=project_id -- cgit From b952e1ef61a6ed73e34c6dd0318cd4d52faf47dc Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Tue, 28 Sep 2010 21:07:26 -0700 Subject: patch for test --- nova/tests/virt_unittest.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index f9ff0f71f..5e9505374 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -178,8 +178,14 @@ class NWFilterTestCase(test.TrialTestCase): self.defined_filters.append(name) return True + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock + + instance_ref = db.instance_create({}, {'user_id': 'fake', + 'project_id': 'fake'}) + inst_id = instance_ref['id'] + def _ensure_all_called(_): - instance_filter = 'nova-instance-i-1' + instance_filter = 'nova-instance-%s' % instance_ref['str_id'] secgroup_filter = 'nova-secgroup-%s' % self.security_group['id'] for required in [secgroup_filter, 'allow-dhcp-server', 'no-arp-spoofing', 'no-ip-spoofing', @@ -187,11 +193,6 @@ class NWFilterTestCase(test.TrialTestCase): self.assertTrue(required in self.recursive_depends[instance_filter], "Instance's filter does not include %s" % required) - self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock - - inst_id = db.instance_create({}, {'user_id': 'fake', - 'project_id': 'fake'})['id'] - self.security_group = self.setup_and_return_security_group() db.instance_add_security_group({}, inst_id, self.security_group.id) -- cgit From 970114e1729c35ebcc05930659bb5dfaf5b59d3d Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Wed, 29 Sep 2010 00:30:35 -0700 Subject: fix loading to ignore deleted items --- nova/api/ec2/cloud.py | 2 +- nova/db/sqlalchemy/api.py | 65 ++++++++++++++++++++++++++------------------ nova/db/sqlalchemy/models.py | 21 ++++++++------ 3 files changed, 53 insertions(+), 35 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 4c27440dc..d85b8512a 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -342,7 +342,7 @@ class CloudController(object): if match: db.security_group_rule_destroy(context, rule['id']) self._trigger_refresh_security_group(security_group) - + return True raise exception.ApiError("No rule for the specified parameters.") # TODO(soren): Dupe detection. Adding the same rule twice actually diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index dad544cdb..fee50ec9c 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -29,7 +29,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import contains_eager, eagerload, joinedload_all +from sqlalchemy.orm import contains_eager, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -410,8 +410,17 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id): - return models.Instance().find(instance_id, deleted=_deleted(context), - options=eagerload('security_groups')) + session = get_session() + instance_ref = session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload_all('security_groups') + ).filter_by(id=instance_id + ).filter_by(deleted=_deleted(context) + ).first() + if not instance_ref: + raise exception.NotFound('Instance %s not found' % (instance_id)) + + return instance_ref def instance_get_all(context): @@ -942,25 +951,29 @@ def volume_update(_context, volume_id, values): ################### +INSTANCES_OR = or_(models.Instance.deleted == False, + models.Instance.deleted == None) + + +RULES_OR = or_(models.SecurityGroupIngressRule.deleted == False, + models.SecurityGroupIngressRule.deleted == None) + + def security_group_get_all(_context): session = get_session() return session.query(models.SecurityGroup - ).join(models.SecurityGroup.rules - ).options(contains_eager(models.SecurityGroup.rules) - ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(deleted=False + ).options(joinedload_all('rules') ).all() def security_group_get(_context, security_group_id): session = get_session() result = session.query(models.SecurityGroup - ).join(models.SecurityGroup.rules - ).options(contains_eager(models.SecurityGroup.rules) - ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(deleted=False ).filter_by(id=security_group_id - ).get(security_group_id) + ).options(joinedload_all('rules') + ).first() if not result: raise exception.NotFound("No secuity group with id %s" % security_group_id) @@ -969,41 +982,41 @@ def security_group_get(_context, security_group_id): def security_group_get_by_name(context, project_id, group_name): session = get_session() - group_ref = session.query(models.SecurityGroup - ).join(models.SecurityGroup.rules - ).join(models.Instances - ).options(contains_eager(models.SecurityGroup.rules) - ).options(contains_eager(models.SecurityGroup.instances) - ).filter(models.SecurityGroupIngressRule.deleted == False + result = session.query(models.SecurityGroup ).filter_by(project_id=project_id ).filter_by(name=group_name ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).options(joinedload_all('instances') ).first() - if not group_ref: + if not result: raise exception.NotFound( 'No security group named %s for project: %s' \ % (group_name, project_id)) - return group_ref + return result def security_group_get_by_project(_context, project_id): session = get_session() return session.query(models.SecurityGroup - ).join(models.SecurityGroup.rules - ).options(contains_eager(models.SecurityGroup.rules) - ).filter(models.SecurityGroupIngressRule.deleted == False ).filter_by(project_id=project_id ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).outerjoin(models.SecurityGroup.rules + ).options(contains_eager(models.SecurityGroup.rules) + ).filter(RULES_OR ).all() def security_group_get_by_instance(_context, instance_id): session = get_session() - with session.begin(): - return session.query(models.Instance - ).join(models.Instance.security_groups - ).filter_by(deleted=False - ).all() + return session.query(models.SecurityGroup + ).filter_by(deleted=False + ).options(joinedload_all('rules') + ).join(models.SecurityGroup.instances + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).all() def security_group_exists(_context, project_id, group_name): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index d4caf0b52..b89616ddb 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -340,12 +340,12 @@ class ExportDevice(BASE, NovaBase): uselist=False)) -security_group_instance_association = Table('security_group_instance_association', - BASE.metadata, - Column('security_group_id', Integer, - ForeignKey('security_group.id')), - Column('instance_id', Integer, - ForeignKey('instances.id'))) +class SecurityGroupInstanceAssociation(BASE, NovaBase): + __tablename__ = 'security_group_instance_association' + id = Column(Integer, primary_key=True) + security_group_id = Column(Integer, ForeignKey('security_group.id')) + instance_id = Column(Integer, ForeignKey('instances.id')) + class SecurityGroup(BASE, NovaBase): """Represents a security group""" @@ -358,7 +358,11 @@ class SecurityGroup(BASE, NovaBase): project_id = Column(String(255)) instances = relationship(Instance, - secondary=security_group_instance_association, + secondary="security_group_instance_association", + secondaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id," + "Instance.id == SecurityGroupInstanceAssociation.instance_id," + "SecurityGroup.deleted == False," + "Instance.deleted == False)", backref='security_groups') @property @@ -378,7 +382,8 @@ class SecurityGroupIngressRule(BASE, NovaBase): parent_group_id = Column(Integer, ForeignKey('security_group.id')) parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id, - primaryjoin=parent_group_id==SecurityGroup.id) + primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id," + "SecurityGroupIngressRule.deleted == False)") protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) -- cgit From c0abb5cd45314e072096e173830b2e3d379bf3e7 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Wed, 29 Sep 2010 00:42:18 -0700 Subject: removed a few extra items --- nova/db/sqlalchemy/api.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index e823bf15e..7ef92cad5 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -29,7 +29,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import contains_eager, joinedload_all +from sqlalchemy.orm import joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -951,14 +951,6 @@ def volume_update(_context, volume_id, values): ################### -INSTANCES_OR = or_(models.Instance.deleted == False, - models.Instance.deleted == None) - - -RULES_OR = or_(models.SecurityGroupIngressRule.deleted == False, - models.SecurityGroupIngressRule.deleted == None) - - def security_group_get_all(_context): session = get_session() return session.query(models.SecurityGroup @@ -1002,9 +994,6 @@ def security_group_get_by_project(_context, project_id): ).filter_by(project_id=project_id ).filter_by(deleted=False ).options(joinedload_all('rules') - ).outerjoin(models.SecurityGroup.rules - ).options(contains_eager(models.SecurityGroup.rules) - ).filter(RULES_OR ).all() -- cgit From bfb01ef2e2960803feffb2a3998810b0966e1e79 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Wed, 29 Sep 2010 09:46:37 +0200 Subject: Apply patch from Vish to fix a hardcoded id in the unit tests. --- nova/tests/virt_unittest.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index f9ff0f71f..5e9505374 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -178,8 +178,14 @@ class NWFilterTestCase(test.TrialTestCase): self.defined_filters.append(name) return True + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock + + instance_ref = db.instance_create({}, {'user_id': 'fake', + 'project_id': 'fake'}) + inst_id = instance_ref['id'] + def _ensure_all_called(_): - instance_filter = 'nova-instance-i-1' + instance_filter = 'nova-instance-%s' % instance_ref['str_id'] secgroup_filter = 'nova-secgroup-%s' % self.security_group['id'] for required in [secgroup_filter, 'allow-dhcp-server', 'no-arp-spoofing', 'no-ip-spoofing', @@ -187,11 +193,6 @@ class NWFilterTestCase(test.TrialTestCase): self.assertTrue(required in self.recursive_depends[instance_filter], "Instance's filter does not include %s" % required) - self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock - - inst_id = db.instance_create({}, {'user_id': 'fake', - 'project_id': 'fake'})['id'] - self.security_group = self.setup_and_return_security_group() db.instance_add_security_group({}, inst_id, self.security_group.id) -- cgit From fe139bbdee60aadd720cb7a83d0846f2824c078f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 00:49:04 -0700 Subject: Began wiring up context authorization --- nova/db/sqlalchemy/api.py | 50 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..b5847d299 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -19,6 +19,7 @@ Implementation of SQLAlchemy backend """ +import logging import sys from nova import db @@ -48,6 +49,24 @@ def _deleted(context): return context.get('deleted', False) +def is_admin_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return True + if not context.user: + return True + return context.user.is_admin() + + +def is_user_context(context): + if not context: + logging.warning('Use of empty request context is deprecated') + return False + if not context.user or not context.project: + return False + return True + + ################### @@ -869,14 +888,41 @@ def volume_detached(_context, volume_id): def volume_get(context, volume_id): - return models.Volume.find(volume_id, deleted=_deleted(context)) + session = get_session() + + if is_admin_context(context): + volume_ref = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + if is_user_context(context): + volume_ref = session.query(models.Volume + ).filter_by(project_id=project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + if not volume_ref: + raise exception.NotFound('No volume for id %s' % volume_id) + + raise exception.NotAuthorized() def volume_get_all(context): - return models.Volume.all(deleted=_deleted(context)) + if is_admin_context(context): + return models.Volume.all(deleted=_deleted(context)) + raise exception.NotAuthorized() def volume_get_all_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id -- cgit From 793516d14630a82bb3592f626b753736e63955ec Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Wed, 29 Sep 2010 01:33:30 -0700 Subject: autocreate the models and use security_groups --- nova/db/sqlalchemy/api.py | 4 ++-- nova/db/sqlalchemy/models.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 7ef92cad5..200fb3b3c 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1031,7 +1031,7 @@ def security_group_destroy(_context, security_group_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? - session.execute('update security_group set deleted=1 where id=:id', + session.execute('update security_groups set deleted=1 where id=:id', {'id': security_group_id}) session.execute('update security_group_rules set deleted=1 ' 'where group_id=:id', @@ -1041,7 +1041,7 @@ def security_group_destroy_all(_context): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? - session.execute('update security_group set deleted=1') + session.execute('update security_groups set deleted=1') session.execute('update security_group_rules set deleted=1') ################### diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b89616ddb..c2dbf2345 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -25,7 +25,7 @@ import datetime # TODO(vish): clean up these imports from sqlalchemy.orm import relationship, backref, exc, object_mapper -from sqlalchemy import Column, Integer, String, Table +from sqlalchemy import Column, Integer, String from sqlalchemy import ForeignKey, DateTime, Boolean, Text from sqlalchemy.ext.declarative import declarative_base @@ -343,13 +343,13 @@ class ExportDevice(BASE, NovaBase): class SecurityGroupInstanceAssociation(BASE, NovaBase): __tablename__ = 'security_group_instance_association' id = Column(Integer, primary_key=True) - security_group_id = Column(Integer, ForeignKey('security_group.id')) + security_group_id = Column(Integer, ForeignKey('security_groups.id')) instance_id = Column(Integer, ForeignKey('instances.id')) class SecurityGroup(BASE, NovaBase): """Represents a security group""" - __tablename__ = 'security_group' + __tablename__ = 'security_groups' id = Column(Integer, primary_key=True) name = Column(String(255)) @@ -379,7 +379,7 @@ class SecurityGroupIngressRule(BASE, NovaBase): __tablename__ = 'security_group_rules' id = Column(Integer, primary_key=True) - parent_group_id = Column(Integer, ForeignKey('security_group.id')) + parent_group_id = Column(Integer, ForeignKey('security_groups.id')) parent_group = relationship("SecurityGroup", backref="rules", foreign_keys=parent_group_id, primaryjoin="and_(SecurityGroupIngressRule.parent_group_id == SecurityGroup.id," @@ -392,7 +392,7 @@ class SecurityGroupIngressRule(BASE, NovaBase): # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. - group_id = Column(Integer, ForeignKey('security_group.id')) + group_id = Column(Integer, ForeignKey('security_groups.id')) class KeyPair(BASE, NovaBase): @@ -546,7 +546,7 @@ def register_models(): from sqlalchemy import create_engine models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule, - AuthToken) # , Image, Host + SecurityGroupInstanceAssociation, AuthToken) # , Image, Host engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) -- cgit From 5fa5a0b0b9e13f8f44b257eac0385730c959b92f Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Wed, 29 Sep 2010 01:58:19 -0700 Subject: fix the primary and secondary join --- nova/db/sqlalchemy/api.py | 4 ++-- nova/db/sqlalchemy/models.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 200fb3b3c..447d20b25 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -412,10 +412,10 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id): session = get_session() instance_ref = session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).options(joinedload_all('security_groups') ).filter_by(id=instance_id ).filter_by(deleted=_deleted(context) + ).options(joinedload_all('security_groups') + ).options(joinedload_all('fixed_ip.floating_ips') ).first() if not instance_ref: raise exception.NotFound('Instance %s not found' % (instance_id)) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index c2dbf2345..67142ad78 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -359,9 +359,9 @@ class SecurityGroup(BASE, NovaBase): instances = relationship(Instance, secondary="security_group_instance_association", - secondaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id," - "Instance.id == SecurityGroupInstanceAssociation.instance_id," - "SecurityGroup.deleted == False," + primaryjoin="and_(SecurityGroup.id == SecurityGroupInstanceAssociation.security_group_id," + "SecurityGroup.deleted == False)", + secondaryjoin="and_(SecurityGroupInstanceAssociation.instance_id == Instance.id," "Instance.deleted == False)", backref='security_groups') -- cgit From a86507b3224eb051fea97f65bd5653758fa91668 Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Wed, 29 Sep 2010 06:17:39 -0700 Subject: fix ordering of rules to actually allow out and drop in --- nova/virt/libvirt_conn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index c86f3ffb7..9d889cf29 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -527,8 +527,8 @@ class NWFilterFirewall(object): def nova_base_ipv4_filter(self): retval = "" for protocol in ['tcp', 'udp', 'icmp']: - for direction,action,priority in [('out','accept', 400), - ('in','drop', 399)]: + for direction,action,priority in [('out','accept', 399), + ('inout','drop', 400)]: retval += """ <%s /> """ % (action, direction, @@ -540,8 +540,8 @@ class NWFilterFirewall(object): def nova_base_ipv6_filter(self): retval = "" for protocol in ['tcp', 'udp', 'icmp']: - for direction,action,priority in [('out','accept',400), - ('in','drop',399)]: + for direction,action,priority in [('out','accept',399), + ('inout','drop',400)]: retval += """ <%s-ipv6 /> """ % (action, direction, -- cgit From e258998923b7e8fa92656aa409f875b640df930c Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 13:26:14 -0700 Subject: Progress on volumes Fixed foreign keys to respect deleted flag --- nova/db/sqlalchemy/api.py | 130 ++++++++++++++++++++++++++++++------------- nova/db/sqlalchemy/models.py | 35 +++++++++--- 2 files changed, 118 insertions(+), 47 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b5847d299..28b937233 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -30,7 +30,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -811,6 +811,7 @@ def quota_destroy(_context, project_id): def volume_allocate_shelf_and_blade(_context, volume_id): + # TODO(devcamcar): Make admin only session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -839,7 +840,7 @@ def volume_attached(_context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -def volume_create(_context, values): +def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): volume_ref[key] = value @@ -848,7 +849,7 @@ def volume_create(_context, values): with session.begin(): while volume_ref.ec2_id == None: ec2_id = utils.generate_uid(volume_ref.__prefix__) - if not volume_ec2_id_exists(_context, ec2_id, session=session): + if not volume_ec2_id_exists(context, ec2_id, session=session): volume_ref.ec2_id = ec2_id volume_ref.save(session=session) return volume_ref @@ -876,10 +877,10 @@ def volume_destroy(_context, volume_id): {'id': volume_id}) -def volume_detached(_context, volume_id): +def volume_detached(context, volume_id): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'available' volume_ref['mountpoint'] = None volume_ref['attach_status'] = 'detached' @@ -887,27 +888,29 @@ def volume_detached(_context, volume_id): volume_ref.save(session=session) -def volume_get(context, volume_id): - session = get_session() +def volume_get(context, volume_id, session=None): + if not session: + session = get_session() + result = None if is_admin_context(context): - volume_ref = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.project_id + ).filter_by(id=volume_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - if is_user_context(context): - volume_ref = session.query(models.Volume - ).filter_by(project_id=project_id - ).filter_by(id=volume_id - ).filter_by(deleted=False - ).first() - if not volume_ref: - raise exception.NotFound('No volume for id %s' % volume_id) + if not result: + raise exception.NotFound('No volume for id %s' % volume_id) - raise exception.NotAuthorized() + return result def volume_get_all(context): @@ -916,6 +919,7 @@ def volume_get_all(context): raise exception.NotAuthorized() + def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -932,42 +936,92 @@ def volume_get_all_by_project(context, project_id): def volume_get_by_ec2_id(context, ec2_id): session = get_session() - volume_ref = session.query(models.Volume + result = None + + if is_admin_context(context): + result = session.query(models.Volume ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not volume_ref: - raise exception.NotFound('Volume %s not found' % (ec2_id)) + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + else: + raise exception.NotAuthorized() - return volume_ref + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Volume.id==ec2_id)).one()[0] + + if is_admin_context(context) or is_user_context(context): + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] + else: + raise exception.NotAuthorized() -def volume_get_instance(_context, volume_id): +def volume_get_instance(context, volume_id): session = get_session() - with session.begin(): - return models.Volume.find(volume_id, session=session).instance + result = None + + if is_admin_context(context): + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=_deleted(context) + ).options(joinedload('instance') + ).first() + elif is_user_context(context): + result = session.query(models.Volume + ).filter_by(project_id=context.project.id + ).filter_by(deleted=False + ).options(joinedload('instance') + ).first() + else: + raise exception.NotAuthorized() + + if not result: + raise exception.NotFound('Volume %s not found' % ec2_id) + + return result.instance -def volume_get_shelf_and_blade(_context, volume_id): +def volume_get_shelf_and_blade(context, volume_id): session = get_session() - export_device = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - if not export_device: + result = None + + if is_admin_context(context): + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() + elif is_user_context(context): + result = session.query(models.ExportDevice + ).join(models.Volume + ).filter(models.Volume.project_id==context.project.id + ).filter_by(volume_id=volume_id + ).first() + else: + raise exception.NotAuthorized() + + if not result: raise exception.NotFound() - return (export_device.shelf_id, export_device.blade_id) + return (result.shelf_id, result.blade_id) -def volume_update(_context, volume_id, values): + +def volume_update(context, volume_id, values): session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..1b9edf475 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -282,7 +282,11 @@ class Volume(BASE, NovaBase): size = Column(Integer) availability_zone = Column(String(255)) # TODO(vish): foreign key? instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('volumes')) + instance = relationship(Instance, + backref=backref('volumes'), + foreign_keys=instance_id, + primaryjoin='and_(Volume.instance_id==Instance.id,' + 'Volume.deleted==False)') mountpoint = Column(String(255)) attach_time = Column(String(255)) # TODO(vish): datetime status = Column(String(255)) # TODO(vish): enum? @@ -333,8 +337,11 @@ class ExportDevice(BASE, NovaBase): shelf_id = Column(Integer) blade_id = Column(Integer) volume_id = Column(Integer, ForeignKey('volumes.id'), nullable=True) - volume = relationship(Volume, backref=backref('export_device', - uselist=False)) + volume = relationship(Volume, + backref=backref('export_device', uselist=False), + foreign_keys=volume_id, + primaryjoin='and_(ExportDevice.volume_id==Volume.id,' + 'ExportDevice.deleted==False)') class KeyPair(BASE, NovaBase): @@ -407,8 +414,12 @@ class NetworkIndex(BASE, NovaBase): id = Column(Integer, primary_key=True) index = Column(Integer, unique=True) network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) - network = relationship(Network, backref=backref('network_index', - uselist=False)) + network = relationship(Network, + backref=backref('network_index', uselist=False), + foreign_keys=network_id, + primaryjoin='and_(NetworkIndex.network_id==Network.id,' + 'NetworkIndex.deleted==False)') + class AuthToken(BASE, NovaBase): """Represents an authorization token for all API transactions. Fields @@ -432,8 +443,11 @@ class FixedIp(BASE, NovaBase): network_id = Column(Integer, ForeignKey('networks.id'), nullable=True) network = relationship(Network, backref=backref('fixed_ips')) instance_id = Column(Integer, ForeignKey('instances.id'), nullable=True) - instance = relationship(Instance, backref=backref('fixed_ip', - uselist=False)) + instance = relationship(Instance, + backref=backref('fixed_ip', uselist=False), + foreign_keys=instance_id, + primaryjoin='and_(FixedIp.instance_id==Instance.id,' + 'FixedIp.deleted==False)') allocated = Column(Boolean, default=False) leased = Column(Boolean, default=False) reserved = Column(Boolean, default=False) @@ -462,8 +476,11 @@ class FloatingIp(BASE, NovaBase): id = Column(Integer, primary_key=True) address = Column(String(255)) fixed_ip_id = Column(Integer, ForeignKey('fixed_ips.id'), nullable=True) - fixed_ip = relationship(FixedIp, backref=backref('floating_ips')) - + fixed_ip = relationship(FixedIp, + backref=backref('floating_ips'), + foreign_keys=fixed_ip_id, + primaryjoin='and_(FloatingIp.fixed_ip_id==FixedIp.id,' + 'FloatingIp.deleted==False)') project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) -- cgit From f4cf49ec3761bdd38dd1a6cb064875b90e65ad4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 14:27:31 -0700 Subject: Wired up context auth for services --- nova/db/sqlalchemy/api.py | 111 ++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 15 ------ 2 files changed, 89 insertions(+), 37 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 28b937233..01a5af38b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -71,16 +71,37 @@ def is_user_context(context): def service_destroy(context, service_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -def service_get(_context, service_id): - return models.Service.find(service_id) + +def service_get(context, service_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Service + ).filter_by(id=service_id + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for id %s' % service_id) + + return result def service_get_all_by_topic(context, topic): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -89,7 +110,10 @@ def service_get_all_by_topic(context, topic): ).all() -def _service_get_all_topic_subquery(_context, session, topic, subq, label): +def _service_get_all_topic_subquery(context, session, topic, subq, label): + if not is_admin_context(context): + raise exception.NotAuthorized() + sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -101,6 +125,9 @@ def _service_get_all_topic_subquery(_context, session, topic, subq, label): def service_get_all_compute_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -125,6 +152,9 @@ def service_get_all_compute_sorted(context): def service_get_all_network_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'network' @@ -142,6 +172,9 @@ def service_get_all_network_sorted(context): def service_get_all_volume_sorted(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): topic = 'volume' @@ -158,11 +191,27 @@ def service_get_all_volume_sorted(context): label) -def service_get_by_args(_context, host, binary): - return models.Service.find_by_args(host, binary) +def service_get_by_args(context, host, binary): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + result = session.query(models.Service + ).filter_by(host=host + ).filter_by(binary=binary + ).filter_by(deleted=_deleted(context) + ).first() + + if not result: + raise exception.NotFound('No service for %s, %s' % (host, binary)) + + return result + +def service_create(context, values): + if not is_admin_context(context): + return exception.NotAuthorized() -def service_create(_context, values): service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -170,10 +219,13 @@ def service_create(_context, values): return service_ref -def service_update(_context, service_id, values): +def service_update(context, service_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - service_ref = models.Service.find(service_id, session=session) + service_ref = session_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) @@ -428,8 +480,8 @@ def instance_destroy(_context, instance_id): instance_ref.delete(session=session) -def instance_get(context, instance_id): - return models.Instance.find(instance_id, deleted=_deleted(context)) +def instance_get(context, instance_id, session=None): + return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): @@ -810,8 +862,10 @@ def quota_destroy(_context, project_id): ################### -def volume_allocate_shelf_and_blade(_context, volume_id): - # TODO(devcamcar): Make admin only +def volume_allocate_shelf_and_blade(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -828,15 +882,17 @@ def volume_allocate_shelf_and_blade(_context, volume_id): return (export_device.shelf_id, export_device.blade_id) -def volume_attached(_context, volume_id, instance_id, mountpoint): +def volume_attached(context, volume_id, instance_id, mountpoint): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - volume_ref = models.Volume.find(volume_id, session=session) + volume_ref = volume_get(context, volume_id, session=session) volume_ref['status'] = 'in-use' volume_ref['mountpoint'] = mountpoint volume_ref['attach_status'] = 'attached' - volume_ref.instance = models.Instance.find(instance_id, - session=session) + volume_ref.instance = instance_get(context, instance_id, session=session) volume_ref.save(session=session) @@ -855,7 +911,10 @@ def volume_create(context, values): return volume_ref -def volume_data_get_for_project(_context, project_id): +def volume_data_get_for_project(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -866,7 +925,10 @@ def volume_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def volume_destroy(_context, volume_id): +def volume_destroy(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -878,6 +940,9 @@ def volume_destroy(_context, volume_id): def volume_detached(context, volume_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -914,10 +979,12 @@ def volume_get(context, volume_id, session=None): def volume_get_all(context): - if is_admin_context(context): - return models.Volume.all(deleted=_deleted(context)) + if not is_admin_context(context): + raise exception.NotAuthorized() - raise exception.NotAuthorized() + return session.query(models.Volume + ).filter_by(deleted=_deleted(context) + ).all() def volume_get_all_by_project(context, project_id): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1b9edf475..b9bb8e4f2 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -176,21 +176,6 @@ class Service(BASE, NovaBase): report_count = Column(Integer, nullable=False, default=0) disabled = Column(Boolean, default=False) - @classmethod - def find_by_args(cls, host, binary, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(host=host - ).filter_by(binary=binary - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for %s, %s" % (host, - binary)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Instance(BASE, NovaBase): """Represents a guest vm""" -- cgit From 734df1fbad8195e7cd7072d0d0aeb5b94841f121 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 19:09:00 -0700 Subject: Made network tests pass again --- nova/db/api.py | 1 - nova/db/sqlalchemy/api.py | 233 +++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 26 ----- nova/network/manager.py | 3 +- nova/tests/network_unittest.py | 1 + 5 files changed, 170 insertions(+), 94 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..4cfdd788c 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,7 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) -def floating_ip_get_instance(context, address): """Get an instance for a floating ip by address.""" return IMPL.floating_ip_get_instance(context, address) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 01a5af38b..d129df2be 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -234,7 +234,13 @@ def service_update(context, service_id, values): ################### -def floating_ip_allocate_address(_context, host, project_id): +def floating_ip_allocate_address(context, host, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -253,7 +259,10 @@ def floating_ip_allocate_address(_context, host, project_id): return floating_ip_ref['address'] -def floating_ip_create(_context, values): +def floating_ip_create(context, values): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -261,7 +270,13 @@ def floating_ip_create(_context, values): return floating_ip_ref['address'] -def floating_ip_count_by_project(_context, project_id): +def floating_ip_count_by_project(context, project_id): + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -269,39 +284,63 @@ def floating_ip_count_by_project(_context, project_id): ).count() -def floating_ip_fixed_ip_associate(_context, floating_address, fixed_address): +#@require_context +def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(floating_address, - session=session) - fixed_ip_ref = models.FixedIp.find_by_str(fixed_address, - session=session) + # TODO(devcamcar): How to ensure floating_id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + floating_address, + session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + fixed_address, + session=session) floating_ip_ref.fixed_ip = fixed_ip_ref floating_ip_ref.save(session=session) -def floating_ip_deallocate(_context, address): +#@require_context +def floating_ip_deallocate(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): How to ensure floating id belongs to user? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) +#@require_context +def floating_ip_destroy(context, address): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_destroy(_context, address): session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + floating_ip_ref = get_floating_ip_by_address(context, + address, + session=session) floating_ip_ref.delete(session=session) -def floating_ip_disassociate(_context, address): +def floating_ip_disassociate(context, address): + if not is_user_context(context) and is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) + # TODO(devcamcar): Ensure address belongs to user. + # Does get_floating_ip_by_address handle this? + floating_ip_ref = floating_ip_get_by_address(context, + address, + session=session) fixed_ip_ref = floating_ip_ref.fixed_ip if fixed_ip_ref: fixed_ip_address = fixed_ip_ref['address'] @@ -311,16 +350,22 @@ def floating_ip_disassociate(_context, address): floating_ip_ref.save(session=session) return fixed_ip_address +#@require_admin_context +def floating_ip_get_all(context): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all(_context): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') ).filter_by(deleted=False ).all() +#@require_admin_context +def floating_ip_get_all_by_host(context, host): + if not is_admin_context(context): + raise exception.NotAuthorized() -def floating_ip_get_all_by_host(_context, host): session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -328,7 +373,15 @@ def floating_ip_get_all_by_host(_context, host): ).filter_by(deleted=False ).all() -def floating_ip_get_all_by_project(_context, project_id): +#@require_context +def floating_ip_get_all_by_project(context, project_id): + # TODO(devcamcar): Change to decorate and check project_id separately. + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + elif not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -336,22 +389,38 @@ def floating_ip_get_all_by_project(_context, project_id): ).filter_by(deleted=False ).all() -def floating_ip_get_by_address(_context, address): - return models.FloatingIp.find_by_str(address) +#@require_context +def floating_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure the address belongs to user. + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.FloatingIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No fixed ip for address %s' % address) + return result -def floating_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - floating_ip_ref = models.FloatingIp.find_by_str(address, - session=session) - return floating_ip_ref.fixed_ip.instance + + # floating_ip_ref = get_floating_ip_by_address(context, + # address, + # session=session) + # return floating_ip_ref.fixed_ip.instance ################### +#@require_context +def fixed_ip_associate(context, address, instance_id): + if not is_user_context(context) and not is_admin_context(context): + raise exception.NotAuthorized() -def fixed_ip_associate(_context, address, instance_id): session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -364,12 +433,17 @@ def fixed_ip_associate(_context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) -def fixed_ip_associate_pool(_context, network_id, instance_id): +#@require_admin_context +def fixed_ip_associate_pool(context, network_id, instance_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -386,14 +460,16 @@ def fixed_ip_associate_pool(_context, network_id, instance_id): if not fixed_ip_ref: raise db.NoMoreAddresses() if not fixed_ip_ref.network: - fixed_ip_ref.network = models.Network.find(network_id, - session=session) - fixed_ip_ref.instance = models.Instance.find(instance_id, - session=session) + fixed_ip_ref.network = network_get(context, + network_id, + session=session) + fixed_ip_ref.instance = instance_get(context, + instance_id, + session=session) session.add(fixed_ip_ref) return fixed_ip_ref['address'] - +#@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -401,45 +477,56 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] - -def fixed_ip_disassociate(_context, address): +#@require_context +def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) fixed_ip_ref.instance = None fixed_ip_ref.save(session=session) -def fixed_ip_get_by_address(_context, address): - session = get_session() - with session.begin(): - try: - return session.query(models.FixedIp - ).options(joinedload_all('instance') - ).filter_by(address=address - ).filter_by(deleted=False - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % address) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - -def fixed_ip_get_instance(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).instance +#@require_context +def fixed_ip_get_by_address(context, address, session=None): + # TODO(devcamcar): Ensure floating ip belongs to user. + # Only possible if it is associated with an instance. + # May have to use system context for this always. + if not session: + session = get_session() + result = session.query(models.FixedIp + ).filter_by(address=address + ).filter_by(deleted=_deleted(context) + ).options(joinedload('network') + ).options(joinedload('instance') + ).first() + if not result: + raise exception.NotFound('No floating ip for address %s' % address) + + return result + + +#@require_context +def fixed_ip_get_instance(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.instance -def fixed_ip_get_network(_context, address): - session = get_session() - with session.begin(): - return models.FixedIp.find_by_str(address, session=session).network +#@require_admin_context +def fixed_ip_get_network(context, address): + fixed_ip_ref = fixed_ip_get_by_address(context, address) + return fixed_ip_ref.network -def fixed_ip_update(_context, address, values): + +#@require_context +def fixed_ip_update(context, address, values): session = get_session() with session.begin(): - fixed_ip_ref = models.FixedIp.find_by_str(address, session=session) + fixed_ip_ref = fixed_ip_get_by_address(context, + address, + session=session) for (key, value) in values.iteritems(): fixed_ip_ref[key] = value fixed_ip_ref.save(session=session) @@ -462,7 +549,9 @@ def instance_create(_context, values): instance_ref.save(session=session) return instance_ref + def instance_data_get_for_project(_context, project_id): + # TODO(devmcar): Admin only session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -474,6 +563,7 @@ def instance_data_get_for_project(_context, project_id): def instance_destroy(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -481,17 +571,21 @@ def instance_destroy(_context, instance_id): def instance_get(context, instance_id, session=None): + # TODO(devcamcar): Support user context return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) def instance_get_all(context): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(deleted=_deleted(context) ).all() + def instance_get_all_by_user(context, user_id): + # TODO(devcamcar): Admin only session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -499,7 +593,9 @@ def instance_get_all_by_user(context, user_id): ).filter_by(user_id=user_id ).all() + def instance_get_all_by_project(context, project_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -509,6 +605,7 @@ def instance_get_all_by_project(context, project_id): def instance_get_all_by_reservation(_context, reservation_id): + # TODO(devcamcar): Support user context session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -518,6 +615,7 @@ def instance_get_all_by_reservation(_context, reservation_id): def instance_get_by_ec2_id(context, ec2_id): + # TODO(devcamcar): Support user context session = get_session() instance_ref = session.query(models.Instance ).filter_by(ec2_id=ec2_id @@ -536,6 +634,7 @@ def instance_ec2_id_exists(context, ec2_id, session=None): def instance_get_fixed_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -545,6 +644,7 @@ def instance_get_fixed_address(_context, instance_id): def instance_get_floating_address(_context, instance_id): + # TODO(devcamcar): Support user context session = get_session() with session.begin(): instance_ref = models.Instance.find(instance_id, session=session) @@ -557,6 +657,7 @@ def instance_get_floating_address(_context, instance_id): def instance_is_vpn(context, instance_id): + # TODO(devcamcar): Admin only # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id @@ -683,8 +784,8 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id): - return models.Network.find(network_id) +def network_get(_context, network_id, session=None): + return models.Network.find(network_id, session=session) # NOTE(vish): pylint complains because of the long method name, but diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b9bb8e4f2..7a085c4df 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -441,19 +441,6 @@ class FixedIp(BASE, NovaBase): def str_id(self): return self.address - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" @@ -469,19 +456,6 @@ class FloatingIp(BASE, NovaBase): project_id = Column(String(255)) host = Column(String(255)) # , ForeignKey('hosts.id')) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(address=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for address %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - def register_models(): """Register Models and create metadata""" diff --git a/nova/network/manager.py b/nova/network/manager.py index a7126ea4f..d125d28d8 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -232,7 +232,8 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - address = self.db.fixed_ip_associate_pool(context, + # TODO(devcamcar) Pass system context here. + address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) self.db.fixed_ip_update(context, address, {'allocated': True}) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index da65b50a2..110e8430c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -84,6 +84,7 @@ class NetworkTestCase(test.TrialTestCase): def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips + self.context.project = self.projects[0] pubnet = IPy.IP(flags.FLAGS.public_range) address = str(pubnet[0]) try: -- cgit From d32d95e08d67084ea04ccd1565ce6faffb1766ce Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:29:55 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 185 ++++++++++++++++++++++++++++++----------- nova/tests/compute_unittest.py | 2 + nova/tests/network_unittest.py | 4 +- 3 files changed, 141 insertions(+), 50 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d129df2be..9ab53b89b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -41,9 +41,10 @@ FLAGS = flags.FLAGS # pylint: disable-msg=C0111 def _deleted(context): """Calculates whether to include deleted objects based on context. - - Currently just looks for a flag called deleted in the context dict. + Currently just looks for a flag called deleted in the context dict. """ + if is_user_context(context): + return False if not hasattr(context, 'get'): return False return context.get('deleted', False) @@ -69,7 +70,7 @@ def is_user_context(context): ################### - +#@require_admin_context def service_destroy(context, service_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -80,6 +81,7 @@ def service_destroy(context, service_id): service_ref.delete(session=session) +#@require_admin_context def service_get(context, service_id, session=None): if not is_admin_context(context): raise exception.NotAuthorized() @@ -98,6 +100,7 @@ def service_get(context, service_id, session=None): return result +#@require_admin_context def service_get_all_by_topic(context, topic): if not is_admin_context(context): raise exception.NotAuthorized() @@ -110,6 +113,7 @@ def service_get_all_by_topic(context, topic): ).all() +#@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): if not is_admin_context(context): raise exception.NotAuthorized() @@ -124,6 +128,7 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() +#@require_admin_context def service_get_all_compute_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -151,6 +156,7 @@ def service_get_all_compute_sorted(context): label) +#@require_admin_context def service_get_all_network_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -171,6 +177,7 @@ def service_get_all_network_sorted(context): label) +#@require_admin_context def service_get_all_volume_sorted(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -191,6 +198,7 @@ def service_get_all_volume_sorted(context): label) +#@require_admin_context def service_get_by_args(context, host, binary): if not is_admin_context(context): raise exception.NotAuthorized() @@ -208,6 +216,7 @@ def service_get_by_args(context, host, binary): return result +#@require_admin_context def service_create(context, values): if not is_admin_context(context): return exception.NotAuthorized() @@ -219,6 +228,7 @@ def service_create(context, values): return service_ref +#@require_admin_context def service_update(context, service_id, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -234,12 +244,11 @@ def service_update(context, service_id, values): ################### +#@require_context def floating_ip_allocate_address(context, host, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() with session.begin(): @@ -259,6 +268,7 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] +#@require_context def floating_ip_create(context, values): if not is_user_context(context) and not is_admin_context(context): raise exception.NotAuthorized() @@ -270,12 +280,11 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] +#@require_context def floating_ip_count_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() session = get_session() return session.query(models.FloatingIp @@ -316,6 +325,7 @@ def floating_ip_deallocate(context, address): floating_ip_ref['project_id'] = None floating_ip_ref.save(session=session) + #@require_context def floating_ip_destroy(context, address): if not is_user_context(context) and not is_admin_context(context): @@ -330,6 +340,7 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) +#@require_context def floating_ip_disassociate(context, address): if not is_user_context(context) and is_admin_context(context): raise exception.NotAuthorized() @@ -350,6 +361,7 @@ def floating_ip_disassociate(context, address): floating_ip_ref.save(session=session) return fixed_ip_address + #@require_admin_context def floating_ip_get_all(context): if not is_admin_context(context): @@ -361,6 +373,7 @@ def floating_ip_get_all(context): ).filter_by(deleted=False ).all() + #@require_admin_context def floating_ip_get_all_by_host(context, host): if not is_admin_context(context): @@ -373,6 +386,7 @@ def floating_ip_get_all_by_host(context, host): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_all_by_project(context, project_id): # TODO(devcamcar): Change to decorate and check project_id separately. @@ -389,6 +403,7 @@ def floating_ip_get_all_by_project(context, project_id): ).filter_by(deleted=False ).all() + #@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. @@ -408,14 +423,9 @@ def floating_ip_get_by_address(context, address, session=None): return result - # floating_ip_ref = get_floating_ip_by_address(context, - # address, - # session=session) - # return floating_ip_ref.fixed_ip.instance - - ################### + #@require_context def fixed_ip_associate(context, address, instance_id): if not is_user_context(context) and not is_admin_context(context): @@ -469,6 +479,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): session.add(fixed_ip_ref) return fixed_ip_ref['address'] + #@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() @@ -477,6 +488,7 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] + #@require_context def fixed_ip_disassociate(context, address): session = get_session() @@ -535,7 +547,8 @@ def fixed_ip_update(context, address, values): ################### -def instance_create(_context, values): +#@require_context +def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): instance_ref[key] = value @@ -544,14 +557,14 @@ def instance_create(_context, values): with session.begin(): while instance_ref.ec2_id == None: ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): + if not instance_ec2_id_exists(context, ec2_id, session=session): instance_ref.ec2_id = ec2_id instance_ref.save(session=session) return instance_ref -def instance_data_get_for_project(_context, project_id): - # TODO(devmcar): Admin only +#@require_admin_context +def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), func.sum(models.Instance.vcpus) @@ -562,21 +575,42 @@ def instance_data_get_for_project(_context, project_id): return (result[0] or 0, result[1] or 0) -def instance_destroy(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_destroy(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) instance_ref.delete(session=session) +#@require_context def instance_get(context, instance_id, session=None): - # TODO(devcamcar): Support user context - return models.Instance.find(instance_id, session=session, deleted=_deleted(context)) + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Instance + ).filter_by(id=instance_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(id=instance_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No instance for id %s' % instance_id) + + return result +#@require_admin_context def instance_get_all(context): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -584,8 +618,11 @@ def instance_get_all(context): ).all() +#@require_admin_context def instance_get_all_by_user(context, user_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -594,8 +631,12 @@ def instance_get_all_by_user(context, user_id): ).all() +#@require_context def instance_get_all_by_project(context, project_id): - # TODO(devcamcar): Support user context + if is_user_context(context): + if context.project.id != project_id: + raise exception.NotAuthorized() + session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -604,50 +645,68 @@ def instance_get_all_by_project(context, project_id): ).all() -def instance_get_all_by_reservation(_context, reservation_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_all_by_reservation(context, reservation_id): session = get_session() - return session.query(models.Instance - ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=False - ).all() + + if is_admin_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=_deleted(context) + ).all() + elif is_user_context(context): + return session.query(models.Instance + ).options(joinedload_all('fixed_ip.floating_ips') + ).filter_by(project_id=context.project.id + ).filter_by(reservation_id=reservation_id + ).filter_by(deleted=False + ).all() +#@require_context def instance_get_by_ec2_id(context, ec2_id): - # TODO(devcamcar): Support user context session = get_session() - instance_ref = session.query(models.Instance + + if is_admin_context(context): + result = session.query(models.Instance ).filter_by(ec2_id=ec2_id ).filter_by(deleted=_deleted(context) ).first() - if not instance_ref: + elif is_user_context(context): + result = session.query(models.Instance + ).filter_by(project_id=context.project.id + ).filter_by(ec2_id=ec2_id + ).filter_by(deleted=False + ).first() + if not result: raise exception.NotFound('Instance %s not found' % (ec2_id)) - return instance_ref + return result +#@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -def instance_get_fixed_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None return instance_ref.fixed_ip['address'] -def instance_get_floating_address(_context, instance_id): - # TODO(devcamcar): Support user context +#@require_context +def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) if not instance_ref.fixed_ip: return None if not instance_ref.fixed_ip.floating_ips: @@ -656,14 +715,20 @@ def instance_get_floating_address(_context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] +#@require_admin_context def instance_is_vpn(context, instance_id): - # TODO(devcamcar): Admin only + if not is_admin_context(context): + raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id +#@require_admin_context def instance_set_state(context, instance_id, state, description=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -674,10 +739,11 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -def instance_update(_context, instance_id, values): +#@require_context +def instance_update(context, instance_id, values): session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) + instance_ref = instance_get(context, instance_id, session=session) for (key, value) in values.iteritems(): instance_ref[key] = value instance_ref.save(session=session) @@ -686,6 +752,7 @@ def instance_update(_context, instance_id, values): ################### +#@require_context def key_pair_create(_context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -694,7 +761,8 @@ def key_pair_create(_context, values): return key_pair_ref -def key_pair_destroy(_context, user_id, name): +#@require_context +def key_pair_destroy(context, user_id, name): session = get_session() with session.begin(): key_pair_ref = models.KeyPair.find_by_args(user_id, @@ -784,8 +852,27 @@ def network_destroy(_context, network_id): {'id': network_id}) -def network_get(_context, network_id, session=None): - return models.Network.find(network_id, session=session) +#@require_context +def network_get(context, network_id, session=None): + if not session: + session = get_session() + result = None + + if is_admin_context(context): + result = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=_deleted(context) + ).first() + elif is_user_context(context): + result = session.query(models.Network + ).filter_by(project_id=context.project.id + ).filter_by(id=network_id + ).filter_by(deleted=False + ).first() + if not result: + raise exception.NotFound('No network for id %s' % network_id) + + return result # NOTE(vish): pylint complains because of the long method name, but @@ -1066,7 +1153,7 @@ def volume_get(context, volume_id, session=None): ).first() elif is_user_context(context): result = session.query(models.Volume - ).filter_by(project_id=context.project.project_id + ).filter_by(project_id=context.project.id ).filter_by(id=volume_id ).filter_by(deleted=False ).first() diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index f5c0f1c09..e705c2552 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -96,6 +96,8 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) + # TODO(devcamcar): Pass deleted in using system context. + # context.read_deleted ? instance_ref = db.instance_get({'deleted': True}, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 110e8430c..ca6a4bbc2 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,7 +56,9 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - self.network.set_network_host(self.context, self.projects[i].id) + user_context = context.APIRequestContext(project=self.projects[i], + user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = db.instance_create(None, {'mac_address': utils.generate_mac()}) self.instance_id = instance_ref['id'] -- cgit From ea5dcda819f2656589df177331f693f945d98f4a Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:35:24 -0700 Subject: Finished instance context auth --- nova/db/sqlalchemy/api.py | 32 +++++++++++++++++++++++++++++--- nova/tests/network_unittest.py | 1 + 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9ab53b89b..2d553d98d 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -794,11 +794,21 @@ def key_pair_get_all_by_user(_context, user_id): ################### -def network_count(_context): - return models.Network.count() +#@require_admin_context +def network_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + return session.query(models.Network + ).filter_by(deleted=deleted + ).count() + +#@require_admin_context def network_count_allocated_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -807,7 +817,11 @@ def network_count_allocated_ips(_context, network_id): ).count() +#@require_admin_context def network_count_available_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -817,7 +831,11 @@ def network_count_available_ips(_context, network_id): ).count() +#@require_admin_context def network_count_reserved_ips(_context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -826,7 +844,11 @@ def network_count_reserved_ips(_context, network_id): ).count() +#@require_admin_context def network_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -834,7 +856,11 @@ def network_create(_context, values): return network_ref -def network_destroy(_context, network_id): +#@require_admin_context +def network_destroy(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index ca6a4bbc2..e01d7cff9 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,6 +49,7 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) + # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i -- cgit From e716990fd58521f8c0166330ec9bc62c7cd91b7e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Wed, 29 Sep 2010 20:54:15 -0700 Subject: Finished context auth for network --- nova/db/sqlalchemy/api.py | 103 ++++++++++++++++++++++++++++++++-------------- nova/network/manager.py | 3 +- 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 2d553d98d..23589b7d8 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -799,13 +799,14 @@ def network_count(context): if not is_admin_context(context): raise exception.NotAuthorized() + session = get_session() return session.query(models.Network - ).filter_by(deleted=deleted + ).filter_by(deleted=_deleted(context) ).count() #@require_admin_context -def network_count_allocated_ips(_context, network_id): +def network_count_allocated_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -818,7 +819,7 @@ def network_count_allocated_ips(_context, network_id): #@require_admin_context -def network_count_available_ips(_context, network_id): +def network_count_available_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -832,7 +833,7 @@ def network_count_available_ips(_context, network_id): #@require_admin_context -def network_count_reserved_ips(_context, network_id): +def network_count_reserved_ips(context, network_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -845,7 +846,7 @@ def network_count_reserved_ips(_context, network_id): #@require_admin_context -def network_create(_context, values): +def network_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() @@ -904,7 +905,11 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -def network_get_associated_fixed_ips(_context, network_id): +#@require_admin_context +def network_get_associated_fixed_ips(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -914,18 +919,28 @@ def network_get_associated_fixed_ips(_context, network_id): ).all() -def network_get_by_bridge(_context, bridge): +#@require_admin_context +def network_get_by_bridge(context, bridge): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result = session.query(models.Network ).filter_by(bridge=bridge ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for bridge %s' % bridge) - return rv + + return result -def network_get_index(_context, network_id): +#@require_admin_context +def network_get_index(context, network_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -933,19 +948,34 @@ def network_get_index(_context, network_id): ).filter_by(deleted=False ).with_lockmode('update' ).first() + if not network_index: raise db.NoMoreNetworks() - network_index['network'] = models.Network.find(network_id, - session=session) + + network_index['network'] = network_get(context, + network_id, + session=session) session.add(network_index) + return network_index['index'] -def network_index_count(_context): - return models.NetworkIndex.count() +#@require_admin_context +def network_index_count(context): + if not is_admin_context(context): + raise exception.NotAuthorized() + + session = get_session() + return session.query(models.NetworkIndex + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def network_index_create_safe(context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def network_index_create_safe(_context, values): network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -955,29 +985,35 @@ def network_index_create_safe(_context, values): pass -def network_set_host(_context, network_id, host_id): +#@require_admin_context +def network_set_host(context, network_id, host_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - network = session.query(models.Network - ).filter_by(id=network_id - ).filter_by(deleted=False - ).with_lockmode('update' - ).first() - if not network: - raise exception.NotFound("Couldn't find network with %s" % - network_id) + network_ref = session.query(models.Network + ).filter_by(id=network_id + ).filter_by(deleted=False + ).with_lockmode('update' + ).first() + if not network_ref: + raise exception.NotFound('No network for id %s' % network_id) + # NOTE(vish): if with_lockmode isn't supported, as in sqlite, # then this has concurrency issues - if not network['host']: - network['host'] = host_id - session.add(network) - return network['host'] + if not network_ref['host']: + network_ref['host'] = host_id + session.add(network_ref) + + return network_ref['host'] -def network_update(_context, network_id, values): +#@require_context +def network_update(context, network_id, values): session = get_session() with session.begin(): - network_ref = models.Network.find(network_id, session=session) + network_ref = network_get(context, network_id, session=session) for (key, value) in values.iteritems(): network_ref[key] = value network_ref.save(session=session) @@ -985,7 +1021,10 @@ def network_update(_context, network_id, values): ################### - +# YOU ARE HERE. +# random idea for system user: +# ctx = context.system_user(on_behalf_of=user, read_deleted=False) +# TODO(devcamcar): Rename to network_get_all_by_project def project_get_network(_context, project_id): session = get_session() rv = session.query(models.Network diff --git a/nova/network/manager.py b/nova/network/manager.py index d125d28d8..ecf2fa2c2 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,8 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - host = self.db.network_set_host(context, + # TODO(devcamcar): Replace with system context + host = self.db.network_set_host(None, network_id, self.host) self._on_set_network_host(context, network_id) -- cgit From 98cac90592658773791eb15b19ed60adf0a57d96 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 00:36:10 -0700 Subject: Completed quota context auth --- nova/db/sqlalchemy/api.py | 103 +++++++++++++++++++++++++++++++------------ nova/db/sqlalchemy/models.py | 12 ----- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 23589b7d8..b225a6a88 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1021,19 +1021,22 @@ def network_update(context, network_id, values): ################### -# YOU ARE HERE. -# random idea for system user: -# ctx = context.system_user(on_behalf_of=user, read_deleted=False) -# TODO(devcamcar): Rename to network_get_all_by_project -def project_get_network(_context, project_id): + +#@require_context +def project_get_network(context, project_id): + if not is_admin_context(context) and not is_user_context(context): + raise error.NotAuthorized() + session = get_session() - rv = session.query(models.Network + result= session.query(models.Network ).filter_by(project_id=project_id ).filter_by(deleted=False ).first() - if not rv: + + if not result: raise exception.NotFound('No network for project: %s' % project_id) - return rv + + return result ################### @@ -1043,14 +1046,26 @@ def queue_get_for(_context, topic, physical_node_id): # FIXME(ja): this should be servername? return "%s.%s" % (topic, physical_node_id) + ################### -def export_device_count(_context): - return models.ExportDevice.count() +#@require_admin_context +def export_device_count(context): + if not is_admin_context(context): + raise exception.notauthorized() + + session = get_session() + return session.query(models.ExportDevice + ).filter_by(deleted=_deleted(context) + ).count() + +#@require_admin_context +def export_device_create(context, values): + if not is_admin_context(context): + raise exception.notauthorized() -def export_device_create(_context, values): export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1084,7 +1099,29 @@ def auth_create_token(_context, token): ################### +#@require_admin_context +def quota_get(context, project_id, session=None): + if not is_admin_context(context): + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.Quota + ).filter_by(project_id=project_id + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('No quota for project_id %s' % project_id) + + return result + + +#@require_admin_context def quota_create(_context, values): + if not is_admin_context(context): + raise exception.NotAuthorized() + quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1092,29 +1129,34 @@ def quota_create(_context, values): return quota_ref -def quota_get(_context, project_id): - return models.Quota.find_by_str(project_id) - +#@require_admin_context +def quota_update(context, project_id, values): + if not is_admin_context(context): + raise exception.NotAuthorized() -def quota_update(_context, project_id, values): session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) for (key, value) in values.iteritems(): quota_ref[key] = value quota_ref.save(session=session) -def quota_destroy(_context, project_id): +#@require_admin_context +def quota_destroy(context, project_id): + if not is_admin_context(context): + raise exception.NotAuthorized() + session = get_session() with session.begin(): - quota_ref = models.Quota.find_by_str(project_id, session=session) + quota_ref = quota_get(context, project_id, session=session) quota_ref.delete(session=session) ################### +#@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1135,6 +1177,7 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) +#@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1149,6 +1192,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) +#@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1164,6 +1208,7 @@ def volume_create(context, values): return volume_ref +#@require_admin_context def volume_data_get_for_project(context, project_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1178,6 +1223,7 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) +#@require_admin_context def volume_destroy(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1192,6 +1238,7 @@ def volume_destroy(context, volume_id): {'id': volume_id}) +#@require_admin_context def volume_detached(context, volume_id): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1206,6 +1253,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) +#@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1222,15 +1270,13 @@ def volume_get(context, volume_id, session=None): ).filter_by(id=volume_id ).filter_by(deleted=False ).first() - else: - raise exception.NotAuthorized() - if not result: raise exception.NotFound('No volume for id %s' % volume_id) return result +#@require_admin_context def volume_get_all(context): if not is_admin_context(context): raise exception.NotAuthorized() @@ -1239,7 +1285,7 @@ def volume_get_all(context): ).filter_by(deleted=_deleted(context) ).all() - +#@require_context def volume_get_all_by_project(context, project_id): if is_user_context(context): if context.project.id != project_id: @@ -1254,6 +1300,7 @@ def volume_get_all_by_project(context, project_id): ).all() +#@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1278,6 +1325,7 @@ def volume_get_by_ec2_id(context, ec2_id): return result +#@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() @@ -1286,10 +1334,9 @@ def volume_ec2_id_exists(context, ec2_id, session=None): return session.query(exists( ).where(models.Volume.id==ec2_id) ).one()[0] - else: - raise exception.NotAuthorized() +#@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1315,6 +1362,7 @@ def volume_get_instance(context, volume_id): return result.instance +#@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1329,15 +1377,14 @@ def volume_get_shelf_and_blade(context, volume_id): ).filter(models.Volume.project_id==context.project.id ).filter_by(volume_id=volume_id ).first() - else: - raise exception.NotAuthorized() - if not result: - raise exception.NotFound() + raise exception.NotFound('No export device found for volume %s' % + volume_id) return (result.shelf_id, result.blade_id) +#@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 7a085c4df..76444127f 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -302,18 +302,6 @@ class Quota(BASE, NovaBase): def str_id(self): return self.project_id - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(project_id=str_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for project_id %s" % str_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] class ExportDevice(BASE, NovaBase): """Represates a shelf and blade that a volume can be exported on""" -- cgit From 30541d48b17ab4626791d969388871b3a1b7758f Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:07:05 -0700 Subject: Wired up context auth for keypairs --- nova/db/sqlalchemy/api.py | 46 +++++++++++++++++++++++++++++++++++--------- nova/db/sqlalchemy/models.py | 20 ------------------- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b225a6a88..302322979 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -753,7 +753,7 @@ def instance_update(context, instance_id, values): #@require_context -def key_pair_create(_context, values): +def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): key_pair_ref[key] = value @@ -763,15 +763,22 @@ def key_pair_create(_context, values): #@require_context def key_pair_destroy(context, user_id, name): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): - key_pair_ref = models.KeyPair.find_by_args(user_id, - name, - session=session) + key_pair_ref = key_pair_get(context, user_id, name, session=session) key_pair_ref.delete(session=session) -def key_pair_destroy_all_by_user(_context, user_id): +#@require_context +def key_pair_destroy_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -779,11 +786,32 @@ def key_pair_destroy_all_by_user(_context, user_id): {'id': user_id}) -def key_pair_get(_context, user_id, name): - return models.KeyPair.find_by_args(user_id, name) +#@require_context +def key_pair_get(context, user_id, name, session=None): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() + + if not session: + session = get_session() + + result = session.query(models.KeyPair + ).filter_by(user_id=user_id + ).filter_by(name=name + ).filter_by(deleted=_deleted(context) + ).first() + if not result: + raise exception.NotFound('no keypair for user %s, name %s' % + (user_id, name)) + return result + +#@require_context +def key_pair_get_all_by_user(context, user_id): + if is_user_context(context): + if context.user.id != user_id: + raise exception.NotAuthorized() -def key_pair_get_all_by_user(_context, user_id): session = get_session() return session.query(models.KeyPair ).filter_by(user_id=user_id @@ -1118,7 +1146,7 @@ def quota_get(context, project_id, session=None): #@require_admin_context -def quota_create(_context, values): +def quota_create(context, values): if not is_admin_context(context): raise exception.NotAuthorized() diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 76444127f..1f5bdf9f5 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -332,26 +332,6 @@ class KeyPair(BASE, NovaBase): def str_id(self): return '%s.%s' % (self.user_id, self.name) - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - user_id, _sep, name = str_id.partition('.') - return cls.find_by_str(user_id, name, session, deleted) - - @classmethod - def find_by_args(cls, user_id, name, session=None, deleted=False): - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(user_id=user_id - ).filter_by(name=name - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for user %s, name %s" % - (user_id, name)) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - class Network(BASE, NovaBase): """Represents a network""" -- cgit From 336523b36ceb8f5302acd443b7f1171b67575f73 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 01:11:16 -0700 Subject: Removed deprecated bits from NovaBase --- nova/db/sqlalchemy/models.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 1f5bdf9f5..a29090c60 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -50,44 +50,6 @@ class NovaBase(object): deleted_at = Column(DateTime) deleted = Column(Boolean, default=False) - @classmethod - def all(cls, session=None, deleted=False): - """Get all objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).all() - - @classmethod - def count(cls, session=None, deleted=False): - """Count objects of this type""" - if not session: - session = get_session() - return session.query(cls - ).filter_by(deleted=deleted - ).count() - - @classmethod - def find(cls, obj_id, session=None, deleted=False): - """Find object by id""" - if not session: - session = get_session() - try: - return session.query(cls - ).filter_by(id=obj_id - ).filter_by(deleted=deleted - ).one() - except exc.NoResultFound: - new_exc = exception.NotFound("No model for id %s" % obj_id) - raise new_exc.__class__, new_exc, sys.exc_info()[2] - - @classmethod - def find_by_str(cls, str_id, session=None, deleted=False): - """Find object by str_id""" - int_id = int(str_id.rpartition('-')[2]) - return cls.find(int_id, session=session, deleted=deleted) - @property def str_id(self): """Get string id of object (generally prefix + '-' + id)""" -- cgit From 8bd81f3ec811e19f6e7faf7a4fe271f85fbc7fc7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:02:14 -0700 Subject: Simplified authorization with decorators" " --- nova/db/sqlalchemy/api.py | 408 ++++++++++++++++------------------------------ 1 file changed, 142 insertions(+), 266 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 302322979..0e7d2e664 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -51,6 +51,7 @@ def _deleted(context): def is_admin_context(context): + """Indicates if the request context is an administrator.""" if not context: logging.warning('Use of empty request context is deprecated') return True @@ -60,6 +61,7 @@ def is_admin_context(context): def is_user_context(context): + """Indicates if the request context is a normal user.""" if not context: logging.warning('Use of empty request context is deprecated') return False @@ -68,24 +70,62 @@ def is_user_context(context): return True +def authorize_project_context(context, project_id): + """Ensures that the request context has permission to access the + given project. + """ + if is_user_context(context): + if not context.project: + raise exception.NotAuthorized() + elif context.project.id != project_id: + raise exception.NotAuthorized() + + +def authorize_user_context(context, user_id): + """Ensures that the request context has permission to access the + given user. + """ + if is_user_context(context): + if not context.user: + raise exception.NotAuthorized() + elif context.user.id != user_id: + raise exception.NotAuthorized() + + +def require_admin_context(f): + """Decorator used to indicate that the method requires an + administrator context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + +def require_context(f): + """Decorator used to indicate that the method requires either + an administrator or normal user context. + """ + def wrapper(*args, **kwargs): + if not is_admin_context(args[0]) and not is_user_context(args[0]): + raise exception.NotAuthorized() + return f(*args, **kwargs) + return wrapper + + ################### -#@require_admin_context +@require_admin_context def service_destroy(context, service_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = service_get(context, service_id, session=session) service_ref.delete(session=session) -#@require_admin_context +@require_admin_context def service_get(context, service_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -100,11 +140,8 @@ def service_get(context, service_id, session=None): return result -#@require_admin_context +@require_admin_context def service_get_all_by_topic(context, topic): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Service ).filter_by(deleted=False @@ -113,11 +150,8 @@ def service_get_all_by_topic(context, topic): ).all() -#@require_admin_context +@require_admin_context def _service_get_all_topic_subquery(context, session, topic, subq, label): - if not is_admin_context(context): - raise exception.NotAuthorized() - sort_value = getattr(subq.c, label) return session.query(models.Service, func.coalesce(sort_value, 0) ).filter_by(topic=topic @@ -128,11 +162,8 @@ def _service_get_all_topic_subquery(context, session, topic, subq, label): ).all() -#@require_admin_context +@require_admin_context def service_get_all_compute_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # NOTE(vish): The intended query is below @@ -156,11 +187,8 @@ def service_get_all_compute_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_network_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'network' @@ -177,11 +205,8 @@ def service_get_all_network_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_all_volume_sorted(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): topic = 'volume' @@ -198,11 +223,8 @@ def service_get_all_volume_sorted(context): label) -#@require_admin_context +@require_admin_context def service_get_by_args(context, host, binary): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Service ).filter_by(host=host @@ -216,11 +238,8 @@ def service_get_by_args(context, host, binary): return result -#@require_admin_context +@require_admin_context def service_create(context, values): - if not is_admin_context(context): - return exception.NotAuthorized() - service_ref = models.Service() for (key, value) in values.iteritems(): service_ref[key] = value @@ -228,11 +247,8 @@ def service_create(context, values): return service_ref -#@require_admin_context +@require_admin_context def service_update(context, service_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): service_ref = session_get(context, service_id, session=session) @@ -244,11 +260,9 @@ def service_update(context, service_id, values): ################### -#@require_context +@require_context def floating_ip_allocate_address(context, host, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() with session.begin(): @@ -268,11 +282,8 @@ def floating_ip_allocate_address(context, host, project_id): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_create(context, values): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - floating_ip_ref = models.FloatingIp() for (key, value) in values.iteritems(): floating_ip_ref[key] = value @@ -280,11 +291,9 @@ def floating_ip_create(context, values): return floating_ip_ref['address'] -#@require_context +@require_context def floating_ip_count_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -293,11 +302,8 @@ def floating_ip_count_by_project(context, project_id): ).count() -#@require_context +@require_context def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating_id belongs to user? @@ -311,11 +317,8 @@ def floating_ip_fixed_ip_associate(context, floating_address, fixed_address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_deallocate(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): How to ensure floating id belongs to user? @@ -326,11 +329,8 @@ def floating_ip_deallocate(context, address): floating_ip_ref.save(session=session) -#@require_context +@require_context def floating_ip_destroy(context, address): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -340,11 +340,8 @@ def floating_ip_destroy(context, address): floating_ip_ref.delete(session=session) -#@require_context +@require_context def floating_ip_disassociate(context, address): - if not is_user_context(context) and is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(devcamcar): Ensure address belongs to user. @@ -362,11 +359,8 @@ def floating_ip_disassociate(context, address): return fixed_ip_address -#@require_admin_context +@require_admin_context def floating_ip_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -374,11 +368,8 @@ def floating_ip_get_all(context): ).all() -#@require_admin_context +@require_admin_context def floating_ip_get_all_by_host(context, host): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -387,14 +378,9 @@ def floating_ip_get_all_by_host(context, host): ).all() -#@require_context +@require_context def floating_ip_get_all_by_project(context, project_id): - # TODO(devcamcar): Change to decorate and check project_id separately. - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.FloatingIp @@ -404,12 +390,9 @@ def floating_ip_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def floating_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure the address belongs to user. - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -426,11 +409,8 @@ def floating_ip_get_by_address(context, address, session=None): ################### -#@require_context +@require_context def fixed_ip_associate(context, address, instance_id): - if not is_user_context(context) and not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): fixed_ip_ref = session.query(models.FixedIp @@ -449,11 +429,8 @@ def fixed_ip_associate(context, address, instance_id): session.add(fixed_ip_ref) -#@require_admin_context +@require_admin_context def fixed_ip_associate_pool(context, network_id, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_or_none = or_(models.FixedIp.network_id == network_id, @@ -480,7 +457,7 @@ def fixed_ip_associate_pool(context, network_id, instance_id): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_create(_context, values): fixed_ip_ref = models.FixedIp() for (key, value) in values.iteritems(): @@ -489,7 +466,7 @@ def fixed_ip_create(_context, values): return fixed_ip_ref['address'] -#@require_context +@require_context def fixed_ip_disassociate(context, address): session = get_session() with session.begin(): @@ -500,7 +477,7 @@ def fixed_ip_disassociate(context, address): fixed_ip_ref.save(session=session) -#@require_context +@require_context def fixed_ip_get_by_address(context, address, session=None): # TODO(devcamcar): Ensure floating ip belongs to user. # Only possible if it is associated with an instance. @@ -520,19 +497,19 @@ def fixed_ip_get_by_address(context, address, session=None): return result -#@require_context +@require_context def fixed_ip_get_instance(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.instance -#@require_admin_context +@require_admin_context def fixed_ip_get_network(context, address): fixed_ip_ref = fixed_ip_get_by_address(context, address) return fixed_ip_ref.network -#@require_context +@require_context def fixed_ip_update(context, address, values): session = get_session() with session.begin(): @@ -547,7 +524,7 @@ def fixed_ip_update(context, address, values): ################### -#@require_context +@require_context def instance_create(context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): @@ -563,7 +540,7 @@ def instance_create(context, values): return instance_ref -#@require_admin_context +@require_admin_context def instance_data_get_for_project(context, project_id): session = get_session() result = session.query(func.count(models.Instance.id), @@ -575,7 +552,7 @@ def instance_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_context +@require_context def instance_destroy(context, instance_id): session = get_session() with session.begin(): @@ -583,7 +560,7 @@ def instance_destroy(context, instance_id): instance_ref.delete(session=session) -#@require_context +@require_context def instance_get(context, instance_id, session=None): if not session: session = get_session() @@ -606,11 +583,8 @@ def instance_get(context, instance_id, session=None): return result -#@require_admin_context +@require_admin_context def instance_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -618,11 +592,8 @@ def instance_get_all(context): ).all() -#@require_admin_context +@require_admin_context def instance_get_all_by_user(context, user_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') @@ -631,11 +602,9 @@ def instance_get_all_by_user(context, user_id): ).all() -#@require_context +@require_context def instance_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Instance @@ -645,7 +614,7 @@ def instance_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def instance_get_all_by_reservation(context, reservation_id): session = get_session() @@ -664,7 +633,7 @@ def instance_get_all_by_reservation(context, reservation_id): ).all() -#@require_context +@require_context def instance_get_by_ec2_id(context, ec2_id): session = get_session() @@ -685,14 +654,14 @@ def instance_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def instance_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] -#@require_context +@require_context def instance_get_fixed_address(context, instance_id): session = get_session() with session.begin(): @@ -702,7 +671,7 @@ def instance_get_fixed_address(context, instance_id): return instance_ref.fixed_ip['address'] -#@require_context +@require_context def instance_get_floating_address(context, instance_id): session = get_session() with session.begin(): @@ -715,20 +684,15 @@ def instance_get_floating_address(context, instance_id): return instance_ref.fixed_ip.floating_ips[0]['address'] -#@require_admin_context +@require_admin_context def instance_is_vpn(context, instance_id): - if not is_admin_context(context): - raise exception.NotAuthorized() # TODO(vish): Move this into image code somewhere instance_ref = instance_get(context, instance_id) return instance_ref['image_id'] == FLAGS.vpn_image_id -#@require_admin_context +@require_admin_context def instance_set_state(context, instance_id, state, description=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - # TODO(devcamcar): Move this out of models and into driver from nova.compute import power_state if not description: @@ -739,7 +703,7 @@ def instance_set_state(context, instance_id, state, description=None): 'state_description': description}) -#@require_context +@require_context def instance_update(context, instance_id, values): session = get_session() with session.begin(): @@ -752,7 +716,7 @@ def instance_update(context, instance_id, values): ################### -#@require_context +@require_context def key_pair_create(context, values): key_pair_ref = models.KeyPair() for (key, value) in values.iteritems(): @@ -761,11 +725,9 @@ def key_pair_create(context, values): return key_pair_ref -#@require_context +@require_context def key_pair_destroy(context, user_id, name): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -773,11 +735,9 @@ def key_pair_destroy(context, user_id, name): key_pair_ref.delete(session=session) -#@require_context +@require_context def key_pair_destroy_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() with session.begin(): @@ -786,11 +746,9 @@ def key_pair_destroy_all_by_user(context, user_id): {'id': user_id}) -#@require_context +@require_context def key_pair_get(context, user_id, name, session=None): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) if not session: session = get_session() @@ -806,11 +764,9 @@ def key_pair_get(context, user_id, name, session=None): return result -#@require_context +@require_context def key_pair_get_all_by_user(context, user_id): - if is_user_context(context): - if context.user.id != user_id: - raise exception.NotAuthorized() + authorize_user_context(context, user_id) session = get_session() return session.query(models.KeyPair @@ -822,22 +778,16 @@ def key_pair_get_all_by_user(context, user_id): ################### -#@require_admin_context +@require_admin_context def network_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.Network ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_count_allocated_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -846,11 +796,8 @@ def network_count_allocated_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_available_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -860,11 +807,8 @@ def network_count_available_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_count_reserved_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).filter_by(network_id=network_id @@ -873,11 +817,8 @@ def network_count_reserved_ips(context, network_id): ).count() -#@require_admin_context +@require_admin_context def network_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_ref = models.Network() for (key, value) in values.iteritems(): network_ref[key] = value @@ -885,11 +826,8 @@ def network_create(context, values): return network_ref -#@require_admin_context +@require_admin_context def network_destroy(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -907,7 +845,7 @@ def network_destroy(context, network_id): {'id': network_id}) -#@require_context +@require_context def network_get(context, network_id, session=None): if not session: session = get_session() @@ -933,11 +871,8 @@ def network_get(context, network_id, session=None): # NOTE(vish): pylint complains because of the long method name, but # it fits with the names of the rest of the methods # pylint: disable-msg=C0103 -#@require_admin_context +@require_admin_context def network_get_associated_fixed_ips(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.FixedIp ).options(joinedload_all('instance') @@ -947,11 +882,8 @@ def network_get_associated_fixed_ips(context, network_id): ).all() -#@require_admin_context +@require_admin_context def network_get_by_bridge(context, bridge): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(models.Network ).filter_by(bridge=bridge @@ -964,11 +896,8 @@ def network_get_by_bridge(context, bridge): return result -#@require_admin_context +@require_admin_context def network_get_index(context, network_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_index = session.query(models.NetworkIndex @@ -988,22 +917,16 @@ def network_get_index(context, network_id): return network_index['index'] -#@require_admin_context +@require_admin_context def network_index_count(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() return session.query(models.NetworkIndex ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def network_index_create_safe(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - network_index_ref = models.NetworkIndex() for (key, value) in values.iteritems(): network_index_ref[key] = value @@ -1013,11 +936,8 @@ def network_index_create_safe(context, values): pass -#@require_admin_context +@require_admin_context def network_set_host(context, network_id, host_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): network_ref = session.query(models.Network @@ -1037,7 +957,7 @@ def network_set_host(context, network_id, host_id): return network_ref['host'] -#@require_context +@require_context def network_update(context, network_id, values): session = get_session() with session.begin(): @@ -1050,11 +970,8 @@ def network_update(context, network_id, values): ################### -#@require_context +@require_context def project_get_network(context, project_id): - if not is_admin_context(context) and not is_user_context(context): - raise error.NotAuthorized() - session = get_session() result= session.query(models.Network ).filter_by(project_id=project_id @@ -1078,22 +995,16 @@ def queue_get_for(_context, topic, physical_node_id): ################### -#@require_admin_context +@require_admin_context def export_device_count(context): - if not is_admin_context(context): - raise exception.notauthorized() - session = get_session() return session.query(models.ExportDevice ).filter_by(deleted=_deleted(context) ).count() -#@require_admin_context +@require_admin_context def export_device_create(context, values): - if not is_admin_context(context): - raise exception.notauthorized() - export_device_ref = models.ExportDevice() for (key, value) in values.iteritems(): export_device_ref[key] = value @@ -1127,11 +1038,8 @@ def auth_create_token(_context, token): ################### -#@require_admin_context +@require_admin_context def quota_get(context, project_id, session=None): - if not is_admin_context(context): - raise exception.NotAuthorized() - if not session: session = get_session() @@ -1145,11 +1053,8 @@ def quota_get(context, project_id, session=None): return result -#@require_admin_context +@require_admin_context def quota_create(context, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - quota_ref = models.Quota() for (key, value) in values.iteritems(): quota_ref[key] = value @@ -1157,11 +1062,8 @@ def quota_create(context, values): return quota_ref -#@require_admin_context +@require_admin_context def quota_update(context, project_id, values): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1170,11 +1072,8 @@ def quota_update(context, project_id, values): quota_ref.save(session=session) -#@require_admin_context +@require_admin_context def quota_destroy(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): quota_ref = quota_get(context, project_id, session=session) @@ -1184,11 +1083,8 @@ def quota_destroy(context, project_id): ################### -#@require_admin_context +@require_admin_context def volume_allocate_shelf_and_blade(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): export_device = session.query(models.ExportDevice @@ -1205,11 +1101,8 @@ def volume_allocate_shelf_and_blade(context, volume_id): return (export_device.shelf_id, export_device.blade_id) -#@require_admin_context +@require_admin_context def volume_attached(context, volume_id, instance_id, mountpoint): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1220,7 +1113,7 @@ def volume_attached(context, volume_id, instance_id, mountpoint): volume_ref.save(session=session) -#@require_context +@require_context def volume_create(context, values): volume_ref = models.Volume() for (key, value) in values.iteritems(): @@ -1236,11 +1129,8 @@ def volume_create(context, values): return volume_ref -#@require_admin_context +@require_admin_context def volume_data_get_for_project(context, project_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() result = session.query(func.count(models.Volume.id), func.sum(models.Volume.size) @@ -1251,11 +1141,8 @@ def volume_data_get_for_project(context, project_id): return (result[0] or 0, result[1] or 0) -#@require_admin_context +@require_admin_context def volume_destroy(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -1266,11 +1153,8 @@ def volume_destroy(context, volume_id): {'id': volume_id}) -#@require_admin_context +@require_admin_context def volume_detached(context, volume_id): - if not is_admin_context(context): - raise exception.NotAuthorized() - session = get_session() with session.begin(): volume_ref = volume_get(context, volume_id, session=session) @@ -1281,7 +1165,7 @@ def volume_detached(context, volume_id): volume_ref.save(session=session) -#@require_context +@require_context def volume_get(context, volume_id, session=None): if not session: session = get_session() @@ -1304,22 +1188,15 @@ def volume_get(context, volume_id, session=None): return result -#@require_admin_context +@require_admin_context def volume_get_all(context): - if not is_admin_context(context): - raise exception.NotAuthorized() - return session.query(models.Volume ).filter_by(deleted=_deleted(context) ).all() -#@require_context +@require_context def volume_get_all_by_project(context, project_id): - if is_user_context(context): - if context.project.id != project_id: - raise exception.NotAuthorized() - elif not is_admin_context(context): - raise exception.NotAuthorized() + authorize_project_context(context, project_id) session = get_session() return session.query(models.Volume @@ -1328,7 +1205,7 @@ def volume_get_all_by_project(context, project_id): ).all() -#@require_context +@require_context def volume_get_by_ec2_id(context, ec2_id): session = get_session() result = None @@ -1353,18 +1230,17 @@ def volume_get_by_ec2_id(context, ec2_id): return result -#@require_context +@require_context def volume_ec2_id_exists(context, ec2_id, session=None): if not session: session = get_session() - if is_admin_context(context) or is_user_context(context): - return session.query(exists( - ).where(models.Volume.id==ec2_id) - ).one()[0] + return session.query(exists( + ).where(models.Volume.id==ec2_id) + ).one()[0] -#@require_context +@require_context def volume_get_instance(context, volume_id): session = get_session() result = None @@ -1390,7 +1266,7 @@ def volume_get_instance(context, volume_id): return result.instance -#@require_context +@require_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() result = None @@ -1412,7 +1288,7 @@ def volume_get_shelf_and_blade(context, volume_id): return (result.shelf_id, result.blade_id) -#@require_context +@require_context def volume_update(context, volume_id, values): session = get_session() with session.begin(): -- cgit From cf456bdb2a767644d95599aa1c8f580279959a4e Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 02:47:05 -0700 Subject: Refactored APIRequestContext --- nova/api/context.py | 46 +++++++++++++++++++++++++++ nova/api/ec2/__init__.py | 8 ++--- nova/api/ec2/context.py | 33 -------------------- nova/db/sqlalchemy/api.py | 71 +++++++++++++++++++----------------------- nova/network/manager.py | 2 -- nova/tests/compute_unittest.py | 8 ++--- 6 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 nova/api/context.py delete mode 100644 nova/api/ec2/context.py diff --git a/nova/api/context.py b/nova/api/context.py new file mode 100644 index 000000000..b66cfe468 --- /dev/null +++ b/nova/api/context.py @@ -0,0 +1,46 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + + +class APIRequestContext(object): + def __init__(self, user, project): + self.user = user + self.project = project + self.request_id = ''.join( + [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') + for x in xrange(20)] + ) + if user: + self.is_admin = user.is_admin() + else: + self.is_admin = False + self.read_deleted = False + + +def get_admin_context(user=None, read_deleted=False): + context_ref = APIRequestContext(user=user, project=None) + context_ref.is_admin = True + context_ref.read_deleted = read_deleted + return context_ref + diff --git a/nova/api/ec2/__init__.py b/nova/api/ec2/__init__.py index 7a958f841..6b538a7f1 100644 --- a/nova/api/ec2/__init__.py +++ b/nova/api/ec2/__init__.py @@ -27,8 +27,8 @@ import webob.exc from nova import exception from nova import flags from nova import wsgi +from nova.api import context from nova.api.ec2 import apirequest -from nova.api.ec2 import context from nova.api.ec2 import admin from nova.api.ec2 import cloud from nova.auth import manager @@ -193,15 +193,15 @@ class Authorizer(wsgi.Middleware): return True if 'none' in roles: return False - return any(context.project.has_role(context.user.id, role) + return any(context.project.has_role(context.user.id, role) for role in roles) - + class Executor(wsgi.Application): """Execute an EC2 API request. - Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and + Executes 'ec2.action' upon 'ec2.controller', passing 'ec2.context' and 'ec2.action_args' (all variables in WSGI environ.) Returns an XML response, or a 400 upon failure. """ diff --git a/nova/api/ec2/context.py b/nova/api/ec2/context.py deleted file mode 100644 index c53ba98d9..000000000 --- a/nova/api/ec2/context.py +++ /dev/null @@ -1,33 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -APIRequestContext -""" - -import random - - -class APIRequestContext(object): - def __init__(self, user, project): - self.user = user - self.project = project - self.request_id = ''.join( - [random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-') - for x in xrange(20)] - ) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 0e7d2e664..fc5ee2235 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -21,6 +21,7 @@ Implementation of SQLAlchemy backend import logging import sys +import warnings from nova import db from nova import exception @@ -36,28 +37,13 @@ from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS -# NOTE(vish): disabling docstring pylint because the docstrings are -# in the interface definition -# pylint: disable-msg=C0111 -def _deleted(context): - """Calculates whether to include deleted objects based on context. - Currently just looks for a flag called deleted in the context dict. - """ - if is_user_context(context): - return False - if not hasattr(context, 'get'): - return False - return context.get('deleted', False) - - def is_admin_context(context): """Indicates if the request context is an administrator.""" if not context: - logging.warning('Use of empty request context is deprecated') - return True - if not context.user: + warnings.warn('Use of empty request context is deprecated', + DeprecationWarning) return True - return context.user.is_admin() + return context.is_admin def is_user_context(context): @@ -92,6 +78,13 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() +def use_deleted(context): + """Indicates if the context has access to deleted objects.""" + if not context: + return False + return context.read_deleted + + def require_admin_context(f): """Decorator used to indicate that the method requires an administrator context. @@ -131,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -229,7 +222,7 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: @@ -398,7 +391,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -487,7 +480,7 @@ def fixed_ip_get_by_address(context, address, session=None): result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() @@ -569,7 +562,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -588,7 +581,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -597,7 +590,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).filter_by(user_id=user_id ).all() @@ -610,7 +603,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -622,7 +615,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -640,7 +633,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -756,7 +749,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -782,7 +775,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -854,7 +847,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -921,7 +914,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -999,7 +992,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).count() @@ -1045,7 +1038,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1174,7 +1167,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1191,7 +1184,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @require_context @@ -1201,7 +1194,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).all() @@ -1213,7 +1206,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1248,7 +1241,7 @@ def volume_get_instance(context, volume_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=use_deleted(context) ).options(joinedload('instance') ).first() elif is_user_context(context): diff --git a/nova/network/manager.py b/nova/network/manager.py index ecf2fa2c2..265c0d742 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -88,7 +88,6 @@ class NetworkManager(manager.Manager): # TODO(vish): can we minimize db access by just getting the # id here instead of the ref? network_id = network_ref['id'] - # TODO(devcamcar): Replace with system context host = self.db.network_set_host(None, network_id, self.host) @@ -233,7 +232,6 @@ class VlanManager(NetworkManager): address = network_ref['vpn_private_address'] self.db.fixed_ip_associate(context, address, instance_id) else: - # TODO(devcamcar) Pass system context here. address = self.db.fixed_ip_associate_pool(None, network_ref['id'], instance_id) diff --git a/nova/tests/compute_unittest.py b/nova/tests/compute_unittest.py index e705c2552..1e2bb113b 100644 --- a/nova/tests/compute_unittest.py +++ b/nova/tests/compute_unittest.py @@ -30,7 +30,7 @@ from nova import flags from nova import test from nova import utils from nova.auth import manager - +from nova.api import context FLAGS = flags.FLAGS @@ -96,9 +96,9 @@ class ComputeTestCase(test.TrialTestCase): self.assertEqual(instance_ref['deleted_at'], None) terminate = datetime.datetime.utcnow() yield self.compute.terminate_instance(self.context, instance_id) - # TODO(devcamcar): Pass deleted in using system context. - # context.read_deleted ? - instance_ref = db.instance_get({'deleted': True}, instance_id) + self.context = context.get_admin_context(user=self.user, + read_deleted=True) + instance_ref = db.instance_get(self.context, instance_id) self.assert_(instance_ref['launched_at'] < terminate) self.assert_(instance_ref['deleted_at'] > terminate) -- cgit From ab948224a5c6ea976def30927ac7668dd765dbca Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Thu, 30 Sep 2010 03:13:47 -0700 Subject: Cleaned up db/api.py --- nova/db/api.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 4cfdd788c..290b460a6 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -175,10 +175,6 @@ def floating_ip_get_by_address(context, address): return IMPL.floating_ip_get_by_address(context, address) - """Get an instance for a floating ip by address.""" - return IMPL.floating_ip_get_instance(context, address) - - #################### -- cgit From b40696640b13e0974a29c23240f7faa79ad00912 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Fri, 1 Oct 2010 00:42:09 +0200 Subject: Add a DB backend for auth manager. --- nova/auth/dbdriver.py | 236 +++++++++++++++++++++++++++++++++++++++++++ nova/auth/manager.py | 2 +- nova/db/api.py | 113 +++++++++++++++++++++ nova/db/sqlalchemy/api.py | 199 ++++++++++++++++++++++++++++++++++++ nova/db/sqlalchemy/models.py | 73 ++++++++++++- nova/tests/auth_unittest.py | 9 +- nova/tests/fake_flags.py | 2 +- 7 files changed, 629 insertions(+), 5 deletions(-) create mode 100644 nova/auth/dbdriver.py diff --git a/nova/auth/dbdriver.py b/nova/auth/dbdriver.py new file mode 100644 index 000000000..09d15018b --- /dev/null +++ b/nova/auth/dbdriver.py @@ -0,0 +1,236 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Auth driver using the DB as its backend. +""" + +import logging +import sys + +from nova import exception +from nova import db + + +class DbDriver(object): + """DB Auth driver + + Defines enter and exit and therefore supports the with/as syntax. + """ + + def __init__(self): + """Imports the LDAP module""" + pass + db + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def get_user(self, uid): + """Retrieve user by id""" + return self._db_user_to_auth_user(db.user_get({}, uid)) + + def get_user_from_access_key(self, access): + """Retrieve user by access key""" + return self._db_user_to_auth_user(db.user_get_by_access_key({}, access)) + + def get_project(self, pid): + """Retrieve project by id""" + return self._db_project_to_auth_projectuser(db.project_get({}, pid)) + + def get_users(self): + """Retrieve list of users""" + return [self._db_user_to_auth_user(user) for user in db.user_get_all({})] + + def get_projects(self, uid=None): + """Retrieve list of projects""" + if uid: + result = db.project_get_by_user({}, uid) + else: + result = db.project_get_all({}) + return [self._db_project_to_auth_projectuser(proj) for proj in result] + + def create_user(self, name, access_key, secret_key, is_admin): + """Create a user""" + values = { 'id' : name, + 'access_key' : access_key, + 'secret_key' : secret_key, + 'is_admin' : is_admin + } + try: + user_ref = db.user_create({}, values) + return self._db_user_to_auth_user(user_ref) + except exception.Duplicate, e: + raise exception.Duplicate('User %s already exists' % name) + + def _db_user_to_auth_user(self, user_ref): + return { 'id' : user_ref['id'], + 'name' : user_ref['id'], + 'access' : user_ref['access_key'], + 'secret' : user_ref['secret_key'], + 'admin' : user_ref['is_admin'] } + + def _db_project_to_auth_projectuser(self, project_ref): + return { 'id' : project_ref['id'], + 'name' : project_ref['name'], + 'project_manager_id' : project_ref['project_manager'], + 'description' : project_ref['description'], + 'member_ids' : [member['id'] for member in project_ref['members']] } + + def create_project(self, name, manager_uid, + description=None, member_uids=None): + """Create a project""" + manager = db.user_get({}, manager_uid) + if not manager: + raise exception.NotFound("Project can't be created because " + "manager %s doesn't exist" % manager_uid) + + # description is a required attribute + if description is None: + description = name + + # First, we ensure that all the given users exist before we go + # on to create the project. This way we won't have to destroy + # the project again because a user turns out to be invalid. + members = set([manager]) + if member_uids != None: + for member_uid in member_uids: + member = db.user_get({}, member_uid) + if not member: + raise exception.NotFound("Project can't be created " + "because user %s doesn't exist" + % member_uid) + members.add(member) + + values = { 'id' : name, + 'name' : name, + 'project_manager' : manager['id'], + 'description': description } + + try: + project = db.project_create({}, values) + except exception.Duplicate: + raise exception.Duplicate("Project can't be created because " + "project %s already exists" % name) + + for member in members: + db.project_add_member({}, project['id'], member['id']) + + # This looks silly, but ensures that the members element has been + # correctly populated + project_ref = db.project_get({}, project['id']) + return self._db_project_to_auth_projectuser(project_ref) + + def modify_project(self, project_id, manager_uid=None, description=None): + """Modify an existing project""" + if not manager_uid and not description: + return + values = {} + if manager_uid: + manager = db.user_get({}, manager_uid) + if not manager: + raise exception.NotFound("Project can't be modified because " + "manager %s doesn't exist" % + manager_uid) + values['project_manager'] = manager['id'] + if description: + values['description'] = description + + db.project_update({}, project_id, values) + + def add_to_project(self, uid, project_id): + """Add user to project""" + user, project = self._validate_user_and_project(uid, project_id) + db.project_add_member({}, project['id'], user['id']) + + def remove_from_project(self, uid, project_id): + """Remove user from project""" + user, project = self._validate_user_and_project(uid, project_id) + db.project_remove_member({}, project['id'], user['id']) + + def is_in_project(self, uid, project_id): + """Check if user is in project""" + user, project = self._validate_user_and_project(uid, project_id) + return user in project.members + + def has_role(self, uid, role, project_id=None): + """Check if user has role + + If project is specified, it checks for local role, otherwise it + checks for global role + """ + + return role in self.get_user_roles(uid, project_id) + + def add_role(self, uid, role, project_id=None): + """Add role for user (or user and project)""" + if not project_id: + db.user_add_role({}, uid, role) + return + db.user_add_project_role({}, uid, project_id, role) + + def remove_role(self, uid, role, project_id=None): + """Remove role for user (or user and project)""" + if not project_id: + db.user_remove_role({}, uid, role) + return + db.user_remove_project_role({}, uid, project_id, role) + + def get_user_roles(self, uid, project_id=None): + """Retrieve list of roles for user (or user and project)""" + if project_id is None: + roles = db.user_get_roles({}, uid) + return roles + else: + roles = db.user_get_roles_for_project({}, uid, project_id) + return roles + + def delete_user(self, id): + """Delete a user""" + user = db.user_get({}, id) + db.user_delete({}, user['id']) + + def delete_project(self, project_id): + """Delete a project""" + db.project_delete({}, project_id) + + def modify_user(self, uid, access_key=None, secret_key=None, admin=None): + """Modify an existing user""" + if not access_key and not secret_key and admin is None: + return + values = {} + if access_key: + values['access_key'] = access_key + if secret_key: + values['secret_key'] = secret_key + if admin is not None: + values['is_admin'] = admin + db.user_update({}, uid, values) + + def _validate_user_and_project(self, user_id, project_id): + user = db.user_get({}, user_id) + if not user: + raise exception.NotFound('User "%s" not found' % user_id) + project = db.project_get({}, project_id) + if not project: + raise exception.NotFound('Project "%s" not found' % project_id) + return user, project + diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 0bc12c80f..ce8a294df 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -69,7 +69,7 @@ flags.DEFINE_string('credential_cert_subject', '/C=US/ST=California/L=MountainView/O=AnsoLabs/' 'OU=NovaDev/CN=%s-%s', 'Subject for certificate for users') -flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver', +flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver', 'Driver that auth manager uses') diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..703936002 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -565,3 +565,116 @@ def volume_update(context, volume_id, values): """ return IMPL.volume_update(context, volume_id, values) + + +################### + + +def user_get(context, id): + """Get user by id""" + return IMPL.user_get(context, id) + + +def user_get_by_uid(context, uid): + """Get user by uid""" + return IMPL.user_get_by_uid(context, uid) + + +def user_get_by_access_key(context, access_key): + """Get user by access key""" + return IMPL.user_get_by_access_key(context, access_key) + + +def user_create(context, values): + """Create a new user""" + return IMPL.user_create(context, values) + + +def user_delete(context, id): + """Delete a user""" + return IMPL.user_delete(context, id) + + +def user_get_all(context): + """Create a new user""" + return IMPL.user_get_all(context) + + +def user_add_role(context, user_id, role): + """Add another global role for user""" + return IMPL.user_add_role(context, user_id, role) + + +def user_remove_role(context, user_id, role): + """Remove global role from user""" + return IMPL.user_remove_role(context, user_id, role) + + +def user_get_roles(context, user_id): + """Get global roles for user""" + return IMPL.user_get_roles(context, user_id) + + +def user_add_project_role(context, user_id, project_id, role): + """Add project role for user""" + return IMPL.user_add_project_role(context, user_id, project_id, role) + + +def user_remove_project_role(context, user_id, project_id, role): + """Remove project role from user""" + return IMPL.user_remove_project_role(context, user_id, project_id, role) + + +def user_get_roles_for_project(context, user_id, project_id): + """Return list of roles a user holds on project""" + return IMPL.user_get_roles_for_project(context, user_id, project_id) + + +def user_update(context, user_id, values): + """Update user""" + return IMPL.user_update(context, user_id, values) + + +def project_get(context, id): + """Get project by id""" + return IMPL.project_get(context, id) + + +#def project_get_by_uid(context, uid): +# """Get project by uid""" +# return IMPL.project_get_by_uid(context, uid) +# + +def project_create(context, values): + """Create a new project""" + return IMPL.project_create(context, values) + + +def project_add_member(context, project_id, user_id): + """Add user to project""" + return IMPL.project_add_member(context, project_id, user_id) + + +def project_get_all(context): + """Get all projects""" + return IMPL.project_get_all(context) + + +def project_get_by_user(context, user_id): + """Get all projects of which the given user is a member""" + return IMPL.project_get_by_user(context, user_id) + + +def project_remove_member(context, project_id, user_id): + """Remove the given user from the given project""" + return IMPL.project_remove_member(context, project_id, user_id) + + +def project_update(context, project_id, values): + """Update Remove the given user from the given project""" + return IMPL.project_update(context, project_id, values) + + +def project_delete(context, project_id): + """Delete project""" + return IMPL.project_delete(context, project_id) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..bd5c285d8 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -925,3 +925,202 @@ def volume_update(_context, volume_id, values): for (key, value) in values.iteritems(): volume_ref[key] = value volume_ref.save(session=session) + + +################### + + +def user_get(context, id): + return models.User.find(id, deleted=_deleted(context)) + + +def user_get_by_access_key(context, access_key): + session = get_session() + return session.query(models.User + ).filter_by(access_key=access_key + ).filter_by(deleted=_deleted(context) + ).first() + + +def user_create(_context, values): + user_ref = models.User() + for (key, value) in values.iteritems(): + user_ref[key] = value + user_ref.save() + return user_ref + + +def user_delete(context, id): + session = get_session() + with session.begin(): + session.execute('delete from user_project_association where user_id=:id', + {'id': id}) + session.execute('delete from user_role_association where user_id=:id', + {'id': id}) + session.execute('delete from user_project_role_association where user_id=:id', + {'id': id}) + user_ref = models.User.find(id, session=session) + session.delete(user_ref) + + +def user_get_all(context): + session = get_session() + return session.query(models.User + ).filter_by(deleted=_deleted(context) + ).all() + + +def project_create(_context, values): + project_ref = models.Project() + for (key, value) in values.iteritems(): + project_ref[key] = value + project_ref.save() + return project_ref + + +def project_add_member(context, project_id, user_id): + session = get_session() + with session.begin(): + project_ref = models.Project.find(project_id, session=session) + user_ref = models.User.find(user_id, session=session) + + project_ref.members += [user_ref] + project_ref.save(session=session) + + +def project_get(context, id): + session = get_session() + result = session.query(models.Project + ).filter_by(deleted=False + ).filter_by(id=id + ).options(joinedload_all('members') + ).first() + if not result: + raise exception.NotFound("No project with id %s" % id) + return result + + +def project_get_by_uid(context, uid): + session = get_session() + return session.query(models.Project + ).filter_by(uid=uid + ).filter_by(deleted=_deleted(context) + ).first() + + +def project_get_all(context): + session = get_session() + return session.query(models.Project + ).filter_by(deleted=_deleted(context) + ).options(joinedload_all('members') + ).all() + + +def project_get_by_user(context, user_id): + session = get_session() + user = session.query(models.User + ).filter_by(deleted=_deleted(context) + ).options(joinedload_all('projects') + ).first() + return user.projects + + +def project_remove_member(context, project_id, user_id): + session = get_session() + project = models.Project.find(project_id, session=session) + user = models.User.find(user_id, session=session) + if not project: + raise exception.NotFound('Project id "%s" not found' % (project_id,)) + + if not user: + raise exception.NotFound('User id "%s" not found' % (user_id,)) + + if user in project.members: + project.members.remove(user) + project.save(session=session) + + +def user_update(_context, user_id, values): + session = get_session() + with session.begin(): + user_ref = models.User.find(user_id, session=session) + for (key, value) in values.iteritems(): + user_ref[key] = value + user_ref.save(session=session) + + +def project_update(_context, project_id, values): + session = get_session() + with session.begin(): + project_ref = models.Project.find(project_id, session=session) + for (key, value) in values.iteritems(): + project_ref[key] = value + project_ref.save(session=session) + + +def project_delete(context, id): + session = get_session() + with session.begin(): + session.execute('delete from user_project_association where project_id=:id', + {'id': id}) + session.execute('delete from user_project_role_association where project_id=:id', + {'id': id}) + project_ref = models.Project.find(id, session=session) + session.delete(project_ref) + + +def user_get_roles(context, user_id): + session = get_session() + with session.begin(): + user_ref = models.User.find(user_id, session=session) + return [role.role for role in user_ref['roles']] + + +def user_get_roles_for_project(context, user_id, project_id): + session = get_session() + with session.begin(): + res = session.query(models.UserProjectRoleAssociation + ).filter_by(user_id=user_id + ).filter_by(project_id=project_id + ).all() + return [association.role for association in res] + +def user_remove_project_role(context, user_id, project_id, role): + session = get_session() + with session.begin(): + session.execute('delete from user_project_role_association where ' + \ + 'user_id=:user_id and project_id=:project_id and ' + \ + 'role=:role', { 'user_id' : user_id, + 'project_id' : project_id, + 'role' : role }) + + +def user_remove_role(context, user_id, role): + session = get_session() + with session.begin(): + res = session.query(models.UserRoleAssociation + ).filter_by(user_id=user_id + ).filter_by(role=role + ).all() + for role in res: + session.delete(role) + + +def user_add_role(context, user_id, role): + session = get_session() + with session.begin(): + user_ref = models.User.find(user_id, session=session) + models.UserRoleAssociation(user=user_ref, role=role).save(session=session) + + +def user_add_project_role(context, user_id, project_id, role): + session = get_session() + with session.begin(): + user_ref = models.User.find(user_id, session=session) + project_ref = models.Project.find(project_id, session=session) + models.UserProjectRoleAssociation(user_id=user_ref['id'], + project_id=project_ref['id'], + role=role).save(session=session) + + +################### diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 01e58b05e..b247eb416 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -27,7 +27,9 @@ import datetime from sqlalchemy.orm import relationship, backref, exc, object_mapper from sqlalchemy import Column, Integer, String from sqlalchemy import ForeignKey, DateTime, Boolean, Text +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.schema import ForeignKeyConstraint from nova.db.sqlalchemy.session import get_session @@ -98,7 +100,13 @@ class NovaBase(object): if not session: session = get_session() session.add(self) - session.flush() + try: + session.flush() + except IntegrityError, e: + if str(e).endswith('is not unique'): + raise Exception.Duplicate(str(e)) + else: + raise def delete(self, session=None): """Delete this object""" @@ -456,6 +464,67 @@ class FixedIp(BASE, NovaBase): raise new_exc.__class__, new_exc, sys.exc_info()[2] +class User(BASE, NovaBase): + """Represents a user""" + __tablename__ = 'users' + id = Column(String(255), primary_key=True) + + name = Column(String(255)) + access_key = Column(String(255)) + secret_key = Column(String(255)) + + is_admin = Column(Boolean) + + +class Project(BASE, NovaBase): + """Represents a project""" + __tablename__ = 'projects' + id = Column(String(255), primary_key=True) + name = Column(String(255)) + description = Column(String(255)) + + project_manager = Column(String(255), ForeignKey(User.id)) + + members = relationship(User, + secondary='user_project_association', + backref='projects') + + +class UserProjectRoleAssociation(BASE, NovaBase): + __tablename__ = 'user_project_role_association' + user_id = Column(String(255), primary_key=True) + user = relationship(User, + primaryjoin=user_id==User.id, + foreign_keys=[User.id], + uselist=False) + + project_id = Column(String(255), primary_key=True) + project = relationship(Project, + primaryjoin=project_id==Project.id, + foreign_keys=[Project.id], + uselist=False) + + role = Column(String(255), primary_key=True) + ForeignKeyConstraint(['user_id', + 'project_id'], + ['user_project_association.user_id', + 'user_project_association.project_id']) + + +class UserRoleAssociation(BASE, NovaBase): + __tablename__ = 'user_role_association' + user_id = Column(String(255), ForeignKey('users.id'), primary_key=True) + user = relationship(User, backref='roles') + role = Column(String(255), primary_key=True) + + +class UserProjectAssociation(BASE, NovaBase): + __tablename__ = 'user_project_association' + user_id = Column(String(255), ForeignKey(User.id), primary_key=True) + project_id = Column(String(255), ForeignKey(Project.id), primary_key=True) + + + class FloatingIp(BASE, NovaBase): """Represents a floating ip that dynamically forwards to a fixed ip""" __tablename__ = 'floating_ips' @@ -486,7 +555,7 @@ def register_models(): from sqlalchemy import create_engine models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, Network, NetworkIndex, - AuthToken) # , Image, Host) + AuthToken, UserProjectAssociation, User, Project) # , Image, Host) engine = create_engine(FLAGS.sql_connection, echo=False) for model in models: model.metadata.create_all(engine) diff --git a/nova/tests/auth_unittest.py b/nova/tests/auth_unittest.py index 1955bb417..99f7ab599 100644 --- a/nova/tests/auth_unittest.py +++ b/nova/tests/auth_unittest.py @@ -75,8 +75,9 @@ class user_and_project_generator(object): self.manager.delete_user(self.user) self.manager.delete_project(self.project) -class AuthManagerTestCase(test.TrialTestCase): +class AuthManagerTestCase(object): def setUp(self): + FLAGS.auth_driver = self.auth_driver super(AuthManagerTestCase, self).setUp() self.flags(connection_type='fake') self.manager = manager.AuthManager() @@ -320,6 +321,12 @@ class AuthManagerTestCase(test.TrialTestCase): self.assertEqual('secret', user.secret) self.assertTrue(user.is_admin()) +class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase): + auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' + +class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase): + auth_driver = 'nova.auth.dbdriver.DbDriver' + if __name__ == "__main__": # TODO: Implement use_fake as an option diff --git a/nova/tests/fake_flags.py b/nova/tests/fake_flags.py index 8f4754650..4bbef8832 100644 --- a/nova/tests/fake_flags.py +++ b/nova/tests/fake_flags.py @@ -24,7 +24,7 @@ flags.DECLARE('volume_driver', 'nova.volume.manager') FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver' FLAGS.connection_type = 'fake' FLAGS.fake_rabbit = True -FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver' +FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver' flags.DECLARE('network_size', 'nova.network.manager') flags.DECLARE('num_networks', 'nova.network.manager') flags.DECLARE('fake_network', 'nova.network.manager') -- cgit From 58ae192764b11b19f5676f9496f287a4ea2a71bd Mon Sep 17 00:00:00 2001 From: Cerberus Date: Thu, 30 Sep 2010 20:07:26 -0500 Subject: refactoring --- nova/api/cloud.py | 2 +- nova/api/rackspace/servers.py | 25 ++++++------------------- nova/tests/api/rackspace/servers.py | 2 +- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/nova/api/cloud.py b/nova/api/cloud.py index 345677d4f..57e94a17a 100644 --- a/nova/api/cloud.py +++ b/nova/api/cloud.py @@ -34,7 +34,7 @@ def reboot(instance_id, context=None): #TODO(gundlach) not actually sure what context is used for by ec2 here -- I think we can just remove it and use None all the time. """ - instance_ref = db.instance_get_by_ec2_id(None, instance_id) + instance_ref = db.instance_get_by_internal_id(None, instance_id) host = instance_ref['host'] rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), {"method": "reboot_instance", diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index 11efd8aef..39e784be2 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -35,9 +35,6 @@ import nova.image.service FLAGS = flags.FLAGS -flags.DEFINE_string('rs_network_manager', 'nova.network.manager.FlatManager', - 'Networking for rackspace') - def _instance_id_translator(): """ Helper method for initializing an id translator for Rackspace instance ids """ @@ -131,11 +128,8 @@ class Controller(wsgi.Controller): def show(self, req, id): """ Returns server details by server id """ - inst_id_trans = _instance_id_translator() - inst_id = inst_id_trans.from_rs_id(id) - user_id = req.environ['nova.context']['user']['id'] - inst = self.db_driver.instance_get_by_ec2_id(None, inst_id) + inst = self.db_driver.instance_get_by_instance_id(None, id) if inst: if inst.user_id == user_id: return _entity_detail(inst) @@ -143,11 +137,8 @@ class Controller(wsgi.Controller): def delete(self, req, id): """ Destroys a server """ - inst_id_trans = _instance_id_translator() - inst_id = inst_id_trans.from_rs_id(id) - user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) + instance = self.db_driver.instance_get_by_internal_id(None, id) if instance and instance['user_id'] == user_id: self.db_driver.instance_destroy(None, id) return faults.Fault(exc.HTTPAccepted()) @@ -173,8 +164,6 @@ class Controller(wsgi.Controller): def update(self, req, id): """ Updates the server name or password """ - inst_id_trans = _instance_id_translator() - inst_id = inst_id_trans.from_rs_id(id) user_id = req.environ['nova.context']['user']['id'] inst_dict = self._deserialize(req.body, req) @@ -182,7 +171,7 @@ class Controller(wsgi.Controller): if not inst_dict: return faults.Fault(exc.HTTPUnprocessableEntity()) - instance = self.db_driver.instance_get_by_ec2_id(None, inst_id) + instance = self.db_driver.instance_get_by_internal_id(None, id) if not instance or instance.user_id != user_id: return faults.Fault(exc.HTTPNotFound()) @@ -206,8 +195,6 @@ class Controller(wsgi.Controller): ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) inst = {} - inst_id_trans = _instance_id_translator() - user_id = req.environ['nova.context']['user']['id'] flavor_id = env['server']['flavorId'] @@ -258,7 +245,7 @@ class Controller(wsgi.Controller): inst['local_gb'] = flavor['local_gb'] ref = self.db_driver.instance_create(None, inst) - inst['id'] = inst_id_trans.to_rs_id(ref.ec2_id) + inst['id'] = ref.internal_id # TODO(dietz): this isn't explicitly necessary, but the networking # calls depend on an object with a project_id property, and therefore @@ -270,10 +257,10 @@ class Controller(wsgi.Controller): #TODO(dietz) is this necessary? inst['launch_index'] = 0 - inst['hostname'] = ref.ec2_id + inst['hostname'] = ref.internal_id self.db_driver.instance_update(None, inst['id'], inst) - network_manager = utils.import_object(FLAGS.rs_network_manager) + network_manager = utils.import_object(FLAGS.network_manager) address = network_manager.allocate_fixed_ip(api_context, inst['id']) diff --git a/nova/tests/api/rackspace/servers.py b/nova/tests/api/rackspace/servers.py index 69ad2c1d3..ee60cfcbc 100644 --- a/nova/tests/api/rackspace/servers.py +++ b/nova/tests/api/rackspace/servers.py @@ -57,7 +57,7 @@ class ServersTest(unittest.TestCase): test_helper.stub_out_key_pair_funcs(self.stubs) test_helper.stub_out_image_service(self.stubs) self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) - self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server) + self.stubs.Set(nova.db.api, 'instance_get_by_internal_id', return_server) self.stubs.Set(nova.db.api, 'instance_get_all_by_user', return_servers) -- cgit From c58acf2c59420a78f6b7195e3c1ef25e84f12e20 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Thu, 30 Sep 2010 21:19:35 -0400 Subject: Replace database instance 'ec2_id' with 'internal_id' throughout the nova.db package. internal_id is now an integer -- we need to figure out how to make this a bigint or something. --- nova/db/api.py | 4 ++-- nova/db/sqlalchemy/api.py | 20 +++++++++++--------- nova/db/sqlalchemy/models.py | 4 ++-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index b68a0fe8f..c3c29c2fc 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -280,9 +280,9 @@ def instance_get_floating_address(context, instance_id): return IMPL.instance_get_floating_address(context, instance_id) -def instance_get_by_ec2_id(context, ec2_id): +def instance_get_by_internal_id(context, internal_id): """Get an instance by ec2 id.""" - return IMPL.instance_get_by_ec2_id(context, ec2_id) + return IMPL.instance_get_by_internal_id(context, internal_id) def instance_is_vpn(context, instance_id): diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 9c3caf9af..6dd6b545a 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -384,10 +384,11 @@ def instance_create(_context, values): session = get_session() with session.begin(): - while instance_ref.ec2_id == None: - ec2_id = utils.generate_uid(instance_ref.__prefix__) - if not instance_ec2_id_exists(_context, ec2_id, session=session): - instance_ref.ec2_id = ec2_id + while instance_ref.internal_id == None: + internal_id = utils.generate_uid(instance_ref.__prefix__) + if not instance_internal_id_exists(_context, internal_id, + session=session): + instance_ref.internal_id = internal_id instance_ref.save(session=session) return instance_ref @@ -446,22 +447,23 @@ def instance_get_all_by_reservation(_context, reservation_id): ).all() -def instance_get_by_ec2_id(context, ec2_id): +def instance_get_by_internal_id(context, internal_id): session = get_session() instance_ref = session.query(models.Instance - ).filter_by(ec2_id=ec2_id + ).filter_by(internal_id=internal_id ).filter_by(deleted=_deleted(context) ).first() if not instance_ref: - raise exception.NotFound('Instance %s not found' % (ec2_id)) + raise exception.NotFound('Instance %s not found' % (internal_id)) return instance_ref -def instance_ec2_id_exists(context, ec2_id, session=None): +def instance_internal_id_exists(context, internal_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Instance.id==ec2_id)).one()[0] + return session.query(exists().where(models.Instance.id==internal_id) + ).one()[0] def instance_get_fixed_address(_context, instance_id): diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 6cb377476..5c93c92a0 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -197,7 +197,7 @@ class Instance(BASE, NovaBase): __tablename__ = 'instances' __prefix__ = 'i' id = Column(Integer, primary_key=True) - ec2_id = Column(String(10), unique=True) + internal_id = Column(Integer, unique=True) admin_pass = Column(String(255)) @@ -214,7 +214,7 @@ class Instance(BASE, NovaBase): @property def name(self): - return self.ec2_id + return self.internal_id image_id = Column(String(255)) kernel_id = Column(String(255)) -- cgit From 58773e16ddd6f3aaa4aafefde55a3ae631e806dd Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Thu, 30 Sep 2010 21:59:52 -0400 Subject: Convert EC2 cloud.py from assuming that EC2 IDs are stored directly in the database, to assuming that EC2 IDs should be converted to internal IDs. The conversion between the internal ID and the EC2 ID is imperfect -- right now it turns internal IDs like 408 into EC2 IDs like i-408, and vice versa. Instead, EC2 IDs are supposed to be i-[base 36 of the integer]. --- nova/api/ec2/cloud.py | 58 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 79c95788b..2fec49da8 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -113,6 +113,16 @@ class CloudController(object): result[key] = [line] return result + def ec2_id_to_internal_id(ec2_id): + """Convert an ec2 ID (i-[base 36 number]) to an internal id (int)""" + # TODO(gundlach): Maybe this should actually work? + return ec2_id[2:] + + def internal_id_to_ec2_id(internal_id): + """Convert an internal ID (int) to an ec2 ID (i-[base 36 number])""" + # TODO(gundlach): Yo maybe this should actually convert to base 36 + return "i-%d" % internal_id + def get_metadata(self, address): instance_ref = db.fixed_ip_get_instance(None, address) if instance_ref is None: @@ -144,7 +154,7 @@ class CloudController(object): }, 'hostname': hostname, 'instance-action': 'none', - 'instance-id': instance_ref['ec2_id'], + 'instance-id': internal_id_to_ec2_id(instance_ref['internal_id']), 'instance-type': instance_ref['instance_type'], 'local-hostname': hostname, 'local-ipv4': address, @@ -244,9 +254,11 @@ class CloudController(object): def delete_security_group(self, context, group_name, **kwargs): return True - def get_console_output(self, context, instance_id, **kwargs): - # instance_id is passed in as a list of instances - instance_ref = db.instance_get_by_ec2_id(context, instance_id[0]) + def get_console_output(self, context, ec2_id_list, **kwargs): + # ec2_id_list is passed in as a list of instances + ec2_id = ec2_id_list[0] + internal_id = ec2_id_to_internal_id(ec2_id) + instance_ref = db.instance_get_by_ec2_id(context, internal_id) return rpc.call('%s.%s' % (FLAGS.compute_topic, instance_ref['host']), {"method": "get_console_output", @@ -326,7 +338,8 @@ class CloudController(object): raise exception.ApiError("Volume status must be available") if volume_ref['attach_status'] == "attached": raise exception.ApiError("Volume is already attached") - instance_ref = db.instance_get_by_ec2_id(context, instance_id) + internal_id = ec2_id_to_internal_id(instance_id) + instance_ref = db.instance_get_by_internal_id(context, internal_id) host = instance_ref['host'] rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host), {"method": "attach_volume", @@ -360,9 +373,11 @@ class CloudController(object): # If the instance doesn't exist anymore, # then we need to call detach blind db.volume_detached(context) + internal_id = instance_ref['internal_id'] + ec2_id = internal_id_to_ec2_id(internal_id) return {'attachTime': volume_ref['attach_time'], 'device': volume_ref['mountpoint'], - 'instanceId': instance_ref['ec2_id'], + 'instanceId': internal_id, 'requestId': context.request_id, 'status': volume_ref['attach_status'], 'volumeId': volume_ref['id']} @@ -411,7 +426,9 @@ class CloudController(object): if instance['image_id'] == FLAGS.vpn_image_id: continue i = {} - i['instanceId'] = instance['ec2_id'] + internal_id = instance['internal_id'] + ec2_id = internal_id_to_ec2_id(internal_id) + i['instanceId'] = ec2_id i['imageId'] = instance['image_id'] i['instanceState'] = { 'code': instance['state'], @@ -464,9 +481,10 @@ class CloudController(object): instance_id = None if (floating_ip_ref['fixed_ip'] and floating_ip_ref['fixed_ip']['instance']): - instance_id = floating_ip_ref['fixed_ip']['instance']['ec2_id'] + internal_id = floating_ip_ref['fixed_ip']['instance']['ec2_id'] + ec2_id = internal_id_to_ec2_id(internal_id) address_rv = {'public_ip': address, - 'instance_id': instance_id} + 'instance_id': ec2_id} if context.user.is_admin(): details = "%s (%s)" % (address_rv['instance_id'], floating_ip_ref['project_id']) @@ -498,8 +516,9 @@ class CloudController(object): "floating_address": floating_ip_ref['address']}}) return {'releaseResponse': ["Address released."]} - def associate_address(self, context, instance_id, public_ip, **kwargs): - instance_ref = db.instance_get_by_ec2_id(context, instance_id) + def associate_address(self, context, ec2_id, public_ip, **kwargs): + internal_id = ec2_id_to_internal_id(ec2_id) + instance_ref = db.instance_get_by_internal_id(context, internal_id) fixed_address = db.instance_get_fixed_address(context, instance_ref['id']) floating_ip_ref = db.floating_ip_get_by_address(context, public_ip) @@ -610,7 +629,9 @@ class CloudController(object): inst = {} inst['mac_address'] = utils.generate_mac() inst['launch_index'] = num - inst['hostname'] = instance_ref['ec2_id'] + internal_id = instance_ref['internal_id'] + ec2_id = internal_id_to_ec2_id(internal_id) + inst['hostname'] = ec2_id db.instance_update(context, inst_id, inst) address = self.network_manager.allocate_fixed_ip(context, inst_id, @@ -634,12 +655,14 @@ class CloudController(object): return self._format_run_instances(context, reservation_id) - def terminate_instances(self, context, instance_id, **kwargs): + def terminate_instances(self, context, ec2_id_list, **kwargs): logging.debug("Going to start terminating instances") - for id_str in instance_id: + for id_str in ec2_id_list: + internal_id = ec2_id_to_internal_id(id_str) logging.debug("Going to try and terminate %s" % id_str) try: - instance_ref = db.instance_get_by_ec2_id(context, id_str) + instance_ref = db.instance_get_by_internal_id(context, + internal_id) except exception.NotFound: logging.warning("Instance %s was not found during terminate" % id_str) @@ -688,7 +711,7 @@ class CloudController(object): cloud.reboot(id_str, context=context) return True - def update_instance(self, context, instance_id, **kwargs): + def update_instance(self, context, ec2_id, **kwargs): updatable_fields = ['display_name', 'display_description'] changes = {} for field in updatable_fields: @@ -696,7 +719,8 @@ class CloudController(object): changes[field] = kwargs[field] if changes: db_context = {} - inst = db.instance_get_by_ec2_id(db_context, instance_id) + internal_id = ec2_id_to_internal_id(ec2_id) + inst = db.instance_get_by_internal_id(db_context, internal_id) db.instance_update(db_context, inst['id'], kwargs) return True -- cgit From 39080e5f5000e0f401ff19f3fd9dd8cfbffffe69 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Thu, 30 Sep 2010 22:05:16 -0400 Subject: Find other places in the code that used ec2_id or get_instance_by_ec2_id and use internal_id as appropriate --- nova/compute/manager.py | 6 +++--- nova/tests/cloud_unittest.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index f370ede8b..131fac406 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -67,7 +67,7 @@ class ComputeManager(manager.Manager): def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" instance_ref = self.db.instance_get(context, instance_id) - if instance_ref['ec2_id'] in self.driver.list_instances(): + if instance_ref['internal_id'] in self.driver.list_instances(): raise exception.Error("Instance has already been created") logging.debug("instance %s: starting...", instance_id) project_id = instance_ref['project_id'] @@ -129,7 +129,7 @@ class ComputeManager(manager.Manager): raise exception.Error( 'trying to reboot a non-running' 'instance: %s (state: %s excepted: %s)' % - (instance_ref['ec2_id'], + (instance_ref['internal_id'], instance_ref['state'], power_state.RUNNING)) @@ -151,7 +151,7 @@ class ComputeManager(manager.Manager): if FLAGS.connection_type == 'libvirt': fname = os.path.abspath(os.path.join(FLAGS.instances_path, - instance_ref['ec2_id'], + instance_ref['internal_id'], 'console.log')) with open(fname, 'r') as f: output = f.read() diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py index ae7dea1db..d316db153 100644 --- a/nova/tests/cloud_unittest.py +++ b/nova/tests/cloud_unittest.py @@ -236,7 +236,7 @@ class CloudTestCase(test.TrialTestCase): def test_update_of_instance_display_fields(self): inst = db.instance_create({}, {}) - self.cloud.update_instance(self.context, inst['ec2_id'], + self.cloud.update_instance(self.context, inst['internal_id'], display_name='c00l 1m4g3') inst = db.instance_get({}, inst['id']) self.assertEqual('c00l 1m4g3', inst['display_name']) -- cgit From 06cdef056b508e15869623da28ad18cc817e6848 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Thu, 30 Sep 2010 22:09:46 -0400 Subject: First attempt at a uuid generator -- but we've lost a 'topic' input so i don't know what that did. --- nova/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nova/utils.py b/nova/utils.py index d18dd9843..86ff3d22e 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -126,7 +126,9 @@ def runthis(prompt, cmd, check_exit_code = True): def generate_uid(topic, size=8): - return '%s-%s' % (topic, ''.join([random.choice('01234567890abcdefghijklmnopqrstuvwxyz') for x in xrange(size)])) + #TODO(gundlach): we want internal ids to just be ints now. i just dropped + #off a topic prefix, so what have I broken? + return random.randint(0, 2**64-1) def generate_mac(): -- cgit From ddaaebb28649811d723f93a89ee46d69cc3ecabc Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Thu, 30 Sep 2010 20:24:42 -0700 Subject: show project ids for groups instead of user ids --- nova/api/ec2/cloud.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 6c67db28d..8aa76a787 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -260,7 +260,7 @@ class CloudController(object): g = {} g['groupDescription'] = group.description g['groupName'] = group.name - g['ownerId'] = context.user.id + g['ownerId'] = group.project_id g['ipPermissions'] = [] for rule in group.rules: r = {} @@ -272,7 +272,7 @@ class CloudController(object): if rule.group_id: source_group = db.security_group_get(context, rule.group_id) r['groups'] += [{'groupName': source_group.name, - 'userId': source_group.user_id}] + 'userId': source_group.project_id}] else: r['ipRanges'] += [{'cidrIp': rule.cidr}] g['ipPermissions'] += [r] -- cgit From c9e14d6257f0b488bd892c09d284091c0f612dd7 Mon Sep 17 00:00:00 2001 From: Devin Carlen Date: Fri, 1 Oct 2010 01:44:17 -0700 Subject: Locked down fixed ips and improved network tests --- nova/db/sqlalchemy/api.py | 98 ++++++++++++++++-------------------------- nova/tests/network_unittest.py | 44 ++++++++++--------- 2 files changed, 60 insertions(+), 82 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index fc5ee2235..860723516 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -78,7 +78,7 @@ def authorize_user_context(context, user_id): raise exception.NotAuthorized() -def use_deleted(context): +def can_read_deleted(context): """Indicates if the context has access to deleted objects.""" if not context: return False @@ -124,7 +124,7 @@ def service_get(context, service_id, session=None): result = session.query(models.Service ).filter_by(id=service_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: @@ -222,9 +222,8 @@ def service_get_by_args(context, host, binary): result = session.query(models.Service ).filter_by(host=host ).filter_by(binary=binary - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() - if not result: raise exception.NotFound('No service for %s, %s' % (host, binary)) @@ -256,7 +255,6 @@ def service_update(context, service_id, values): @require_context def floating_ip_allocate_address(context, host, project_id): authorize_project_context(context, project_id) - session = get_session() with session.begin(): floating_ip_ref = session.query(models.FloatingIp @@ -287,7 +285,6 @@ def floating_ip_create(context, values): @require_context def floating_ip_count_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).filter_by(project_id=project_id @@ -374,7 +371,6 @@ def floating_ip_get_all_by_host(context, host): @require_context def floating_ip_get_all_by_project(context, project_id): authorize_project_context(context, project_id) - session = get_session() return session.query(models.FloatingIp ).options(joinedload_all('fixed_ip.instance') @@ -391,7 +387,7 @@ def floating_ip_get_by_address(context, address, session=None): result = session.query(models.FloatingIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No fixed ip for address %s' % address) @@ -406,6 +402,7 @@ def floating_ip_get_by_address(context, address, session=None): def fixed_ip_associate(context, address, instance_id): session = get_session() with session.begin(): + instance = instance_get(context, instance_id, session=session) fixed_ip_ref = session.query(models.FixedIp ).filter_by(address=address ).filter_by(deleted=False @@ -416,9 +413,7 @@ def fixed_ip_associate(context, address, instance_id): # then this has concurrency issues if not fixed_ip_ref: raise db.NoMoreAddresses() - fixed_ip_ref.instance = instance_get(context, - instance_id, - session=session) + fixed_ip_ref.instance = instance session.add(fixed_ip_ref) @@ -472,21 +467,21 @@ def fixed_ip_disassociate(context, address): @require_context def fixed_ip_get_by_address(context, address, session=None): - # TODO(devcamcar): Ensure floating ip belongs to user. - # Only possible if it is associated with an instance. - # May have to use system context for this always. if not session: session = get_session() result = session.query(models.FixedIp ).filter_by(address=address - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload('network') ).options(joinedload('instance') ).first() if not result: raise exception.NotFound('No floating ip for address %s' % address) + if is_user_context(context): + authorize_project_context(context, result.instance.project_id) + return result @@ -562,7 +557,7 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance ).filter_by(id=instance_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -581,7 +576,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -590,7 +585,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -603,7 +598,7 @@ def instance_get_all_by_project(context, project_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -615,7 +610,7 @@ def instance_get_all_by_reservation(context, reservation_id): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') ).filter_by(reservation_id=reservation_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance @@ -633,7 +628,7 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance @@ -749,7 +744,7 @@ def key_pair_get(context, user_id, name, session=None): result = session.query(models.KeyPair ).filter_by(user_id=user_id ).filter_by(name=name - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('no keypair for user %s, name %s' % @@ -775,7 +770,7 @@ def key_pair_get_all_by_user(context, user_id): def network_count(context): session = get_session() return session.query(models.Network - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -847,7 +842,7 @@ def network_get(context, network_id, session=None): if is_admin_context(context): result = session.query(models.Network ).filter_by(id=network_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Network @@ -914,7 +909,7 @@ def network_get_index(context, network_id): def network_index_count(context): session = get_session() return session.query(models.NetworkIndex - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -992,7 +987,7 @@ def queue_get_for(_context, topic, physical_node_id): def export_device_count(context): session = get_session() return session.query(models.ExportDevice - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).count() @@ -1038,7 +1033,7 @@ def quota_get(context, project_id, session=None): result = session.query(models.Quota ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() if not result: raise exception.NotFound('No quota for project_id %s' % project_id) @@ -1167,7 +1162,7 @@ def volume_get(context, volume_id, session=None): if is_admin_context(context): result = session.query(models.Volume ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1184,7 +1179,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): return session.query(models.Volume - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @require_context @@ -1194,7 +1189,7 @@ def volume_get_all_by_project(context, project_id): session = get_session() return session.query(models.Volume ).filter_by(project_id=project_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1206,7 +1201,7 @@ def volume_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Volume ).filter_by(ec2_id=ec2_id - ).filter_by(deleted=use_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Volume @@ -1233,47 +1228,26 @@ def volume_ec2_id_exists(context, ec2_id, session=None): ).one()[0] -@require_context +@require_admin_context def volume_get_instance(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.Volume - ).filter_by(id=volume_id - ).filter_by(deleted=use_deleted(context) - ).options(joinedload('instance') - ).first() - elif is_user_context(context): - result = session.query(models.Volume - ).filter_by(project_id=context.project.id - ).filter_by(deleted=False - ).options(joinedload('instance') - ).first() - else: - raise exception.NotAuthorized() - + result = session.query(models.Volume + ).filter_by(id=volume_id + ).filter_by(deleted=can_read_deleted(context) + ).options(joinedload('instance') + ).first() if not result: raise exception.NotFound('Volume %s not found' % ec2_id) return result.instance -@require_context +@require_admin_context def volume_get_shelf_and_blade(context, volume_id): session = get_session() - result = None - - if is_admin_context(context): - result = session.query(models.ExportDevice - ).filter_by(volume_id=volume_id - ).first() - elif is_user_context(context): - result = session.query(models.ExportDevice - ).join(models.Volume - ).filter(models.Volume.project_id==context.project.id - ).filter_by(volume_id=volume_id - ).first() + result = session.query(models.ExportDevice + ).filter_by(volume_id=volume_id + ).first() if not result: raise exception.NotFound('No export device found for volume %s' % volume_id) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index e01d7cff9..e601c480c 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -49,7 +49,6 @@ class NetworkTestCase(test.TrialTestCase): self.user = self.manager.create_user('netuser', 'netuser', 'netuser') self.projects = [] self.network = utils.import_object(FLAGS.network_manager) - # TODO(devcamcar): Passing project=None is Bad(tm). self.context = context.APIRequestContext(project=None, user=self.user) for i in range(5): name = 'project%s' % i @@ -60,11 +59,9 @@ class NetworkTestCase(test.TrialTestCase): user_context = context.APIRequestContext(project=self.projects[i], user=self.user) self.network.set_network_host(user_context, self.projects[i].id) - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] - instance_ref = db.instance_create(None, - {'mac_address': utils.generate_mac()}) + instance_ref = self._create_instance(1) self.instance2_id = instance_ref['id'] def tearDown(self): # pylint: disable-msg=C0103 @@ -77,6 +74,15 @@ class NetworkTestCase(test.TrialTestCase): self.manager.delete_project(project) self.manager.delete_user(self.user) + def _create_instance(self, project_num, mac=None): + if not mac: + mac = utils.generate_mac() + project = self.projects[project_num] + self.context.project = project + return db.instance_create(self.context, + {'project_id': project.id, + 'mac_address': mac}) + def _create_address(self, project_num, instance_id=None): """Create an address in given project num""" if instance_id is None: @@ -84,6 +90,11 @@ class NetworkTestCase(test.TrialTestCase): self.context.project = self.projects[project_num] return self.network.allocate_fixed_ip(self.context, instance_id) + def _deallocate_address(self, project_num, address): + self.context.project = self.projects[project_num] + self.network.deallocate_fixed_ip(self.context, address) + + def test_public_network_association(self): """Makes sure that we can allocaate a public ip""" # TODO(vish): better way of adding floating ips @@ -134,14 +145,14 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(address) lease_ip(address2) - self.network.deallocate_fixed_ip(self.context, address) + self._deallocate_address(0, address) release_ip(address) self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) # First address release shouldn't affect the second self.assertTrue(is_allocated_in_project(address2, self.projects[1].id)) - self.network.deallocate_fixed_ip(self.context, address2) + self._deallocate_address(1, address2) release_ip(address2) self.assertFalse(is_allocated_in_project(address2, self.projects[1].id)) @@ -152,24 +163,19 @@ class NetworkTestCase(test.TrialTestCase): lease_ip(first) instance_ids = [] for i in range(1, 5): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address2 = self._create_address(i, instance_ref['id']) - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(i, mac=utils.generate_mac()) instance_ids.append(instance_ref['id']) address3 = self._create_address(i, instance_ref['id']) lease_ip(address) lease_ip(address2) lease_ip(address3) + self.context.project = self.projects[i] self.assertFalse(is_allocated_in_project(address, self.projects[0].id)) self.assertFalse(is_allocated_in_project(address2, @@ -185,7 +191,7 @@ class NetworkTestCase(test.TrialTestCase): for instance_id in instance_ids: db.instance_destroy(None, instance_id) release_ip(first) - self.network.deallocate_fixed_ip(self.context, first) + self._deallocate_address(0, first) def test_vpn_ip_and_port_looks_valid(self): """Ensure the vpn ip and port are reasonable""" @@ -246,9 +252,7 @@ class NetworkTestCase(test.TrialTestCase): addresses = [] instance_ids = [] for i in range(num_available_ips): - mac = utils.generate_mac() - instance_ref = db.instance_create(None, - {'mac_address': mac}) + instance_ref = self._create_instance(0) instance_ids.append(instance_ref['id']) address = self._create_address(0, instance_ref['id']) addresses.append(address) -- cgit From 7e020e743c138d542e957c24ea53c1ca7fbc757c Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Fri, 1 Oct 2010 13:03:57 +0200 Subject: Address a few comments from Todd. --- nova/db/api.py | 5 ----- nova/db/sqlalchemy/api.py | 8 -------- nova/db/sqlalchemy/models.py | 2 +- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/nova/db/api.py b/nova/db/api.py index 703936002..eb4eee782 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -640,11 +640,6 @@ def project_get(context, id): return IMPL.project_get(context, id) -#def project_get_by_uid(context, uid): -# """Get project by uid""" -# return IMPL.project_get_by_uid(context, uid) -# - def project_create(context, values): """Create a new project""" return IMPL.project_create(context, values) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index bd5c285d8..5cd2d6d51 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1000,14 +1000,6 @@ def project_get(context, id): return result -def project_get_by_uid(context, uid): - session = get_session() - return session.query(models.Project - ).filter_by(uid=uid - ).filter_by(deleted=_deleted(context) - ).first() - - def project_get_all(context): session = get_session() return session.query(models.Project diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index b247eb416..92a68ab68 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -104,7 +104,7 @@ class NovaBase(object): session.flush() except IntegrityError, e: if str(e).endswith('is not unique'): - raise Exception.Duplicate(str(e)) + raise exception.Duplicate(str(e)) else: raise -- cgit From bf22bbd2d4f4364255a306e024d1a7d316b89014 Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Fri, 1 Oct 2010 14:02:51 -0400 Subject: Cleans up the unit tests that are meant to be run with nosetests * Renames all test modules to start with test_ so that nosetests does not need to be run with the --all-modules flag in order to pick them up * Renames test_helper to fakes and removes imports in unit tests that did not reference the fakes * Adds nose to pip-requires so that run_tests.sh -V will install nose into the virtualenv instead of having to manually install it after running into import errors :) --- nova/tests/api/__init__.py | 3 +- nova/tests/api/fakes.py | 8 + nova/tests/api/rackspace/__init__.py | 4 +- nova/tests/api/rackspace/auth.py | 108 ---------- nova/tests/api/rackspace/fakes.py | 148 ++++++++++++++ nova/tests/api/rackspace/flavors.py | 46 ----- nova/tests/api/rackspace/images.py | 40 ---- nova/tests/api/rackspace/servers.py | 245 ----------------------- nova/tests/api/rackspace/sharedipgroups.py | 41 ---- nova/tests/api/rackspace/test_auth.py | 108 ++++++++++ nova/tests/api/rackspace/test_faults.py | 40 ++++ nova/tests/api/rackspace/test_flavors.py | 48 +++++ nova/tests/api/rackspace/test_helper.py | 134 ------------- nova/tests/api/rackspace/test_images.py | 39 ++++ nova/tests/api/rackspace/test_servers.py | 250 ++++++++++++++++++++++++ nova/tests/api/rackspace/test_sharedipgroups.py | 39 ++++ nova/tests/api/rackspace/testfaults.py | 40 ---- nova/tests/api/test_helper.py | 8 - nova/tests/api/test_wsgi.py | 147 ++++++++++++++ nova/tests/api/wsgi_test.py | 147 -------------- tools/pip-requires | 1 + 21 files changed, 832 insertions(+), 812 deletions(-) create mode 100644 nova/tests/api/fakes.py delete mode 100644 nova/tests/api/rackspace/auth.py create mode 100644 nova/tests/api/rackspace/fakes.py delete mode 100644 nova/tests/api/rackspace/flavors.py delete mode 100644 nova/tests/api/rackspace/images.py delete mode 100644 nova/tests/api/rackspace/servers.py delete mode 100644 nova/tests/api/rackspace/sharedipgroups.py create mode 100644 nova/tests/api/rackspace/test_auth.py create mode 100644 nova/tests/api/rackspace/test_faults.py create mode 100644 nova/tests/api/rackspace/test_flavors.py delete mode 100644 nova/tests/api/rackspace/test_helper.py create mode 100644 nova/tests/api/rackspace/test_images.py create mode 100644 nova/tests/api/rackspace/test_servers.py create mode 100644 nova/tests/api/rackspace/test_sharedipgroups.py delete mode 100644 nova/tests/api/rackspace/testfaults.py delete mode 100644 nova/tests/api/test_helper.py create mode 100644 nova/tests/api/test_wsgi.py delete mode 100644 nova/tests/api/wsgi_test.py diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py index fc1ab9ae2..ec76aa827 100644 --- a/nova/tests/api/__init__.py +++ b/nova/tests/api/__init__.py @@ -27,7 +27,8 @@ import webob.dec import nova.exception from nova import api -from nova.tests.api.test_helper import * +from nova.tests.api.fakes import APIStub + class Test(unittest.TestCase): diff --git a/nova/tests/api/fakes.py b/nova/tests/api/fakes.py new file mode 100644 index 000000000..d0a2cc027 --- /dev/null +++ b/nova/tests/api/fakes.py @@ -0,0 +1,8 @@ +import webob.dec +from nova import wsgi + +class APIStub(object): + """Class to verify request and mark it was called.""" + @webob.dec.wsgify + def __call__(self, req): + return req.path_info diff --git a/nova/tests/api/rackspace/__init__.py b/nova/tests/api/rackspace/__init__.py index bfd0f87a7..1834f91b1 100644 --- a/nova/tests/api/rackspace/__init__.py +++ b/nova/tests/api/rackspace/__init__.py @@ -19,7 +19,7 @@ import unittest from nova.api.rackspace import limited from nova.api.rackspace import RateLimitingMiddleware -from nova.tests.api.test_helper import * +from nova.tests.api.fakes import APIStub from webob import Request @@ -82,7 +82,7 @@ class RateLimitingMiddlewareTest(unittest.TestCase): class LimiterTest(unittest.TestCase): - def testLimiter(self): + def test_limiter(self): items = range(2000) req = Request.blank('/') self.assertEqual(limited(items, req), items[ :1000]) diff --git a/nova/tests/api/rackspace/auth.py b/nova/tests/api/rackspace/auth.py deleted file mode 100644 index 56677c2f4..000000000 --- a/nova/tests/api/rackspace/auth.py +++ /dev/null @@ -1,108 +0,0 @@ -import datetime -import unittest - -import stubout -import webob -import webob.dec - -import nova.api -import nova.api.rackspace.auth -from nova import auth -from nova.tests.api.rackspace import test_helper - -class Test(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - self.stubs.Set(nova.api.rackspace.auth.BasicApiAuthManager, - '__init__', test_helper.fake_auth_init) - test_helper.FakeAuthManager.auth_data = {} - test_helper.FakeAuthDatabase.data = {} - test_helper.stub_out_rate_limiting(self.stubs) - test_helper.stub_for_testing(self.stubs) - - def tearDown(self): - self.stubs.UnsetAll() - test_helper.fake_data_store = {} - - def test_authorize_user(self): - f = test_helper.FakeAuthManager() - f.add_user('derp', { 'uid': 1, 'name':'herp' } ) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '204 No Content') - self.assertEqual(len(result.headers['X-Auth-Token']), 40) - self.assertEqual(result.headers['X-CDN-Management-Url'], - "") - self.assertEqual(result.headers['X-Storage-Url'], "") - - def test_authorize_token(self): - f = test_helper.FakeAuthManager() - f.add_user('derp', { 'uid': 1, 'name':'herp' } ) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '204 No Content') - self.assertEqual(len(result.headers['X-Auth-Token']), 40) - self.assertEqual(result.headers['X-Server-Management-Url'], - "https://foo/v1.0/") - self.assertEqual(result.headers['X-CDN-Management-Url'], - "") - self.assertEqual(result.headers['X-Storage-Url'], "") - - token = result.headers['X-Auth-Token'] - self.stubs.Set(nova.api.rackspace, 'APIRouter', - test_helper.FakeRouter) - req = webob.Request.blank('/v1.0/fake') - req.headers['X-Auth-Token'] = token - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '200 OK') - self.assertEqual(result.headers['X-Test-Success'], 'True') - - def test_token_expiry(self): - self.destroy_called = False - token_hash = 'bacon' - - def destroy_token_mock(meh, context, token): - self.destroy_called = True - - def bad_token(meh, context, token_hash): - return { 'token_hash':token_hash, - 'created_at':datetime.datetime(1990, 1, 1) } - - self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_destroy_token', - destroy_token_mock) - - self.stubs.Set(test_helper.FakeAuthDatabase, 'auth_get_token', - bad_token) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-Token'] = 'bacon' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - self.assertEqual(self.destroy_called, True) - - def test_bad_user(self): - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - - def test_no_user(self): - req = webob.Request.blank('/v1.0/') - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - - def test_bad_token(self): - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-Token'] = 'baconbaconbacon' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py new file mode 100644 index 000000000..2c4447920 --- /dev/null +++ b/nova/tests/api/rackspace/fakes.py @@ -0,0 +1,148 @@ +import datetime +import json + +import webob +import webob.dec + +from nova import auth +from nova import utils +from nova import flags +import nova.api.rackspace.auth +import nova.api.rackspace._id_translator +from nova.image import service +from nova.wsgi import Router + + +FLAGS = flags.FLAGS + + +class Context(object): + pass + + +class FakeRouter(Router): + def __init__(self): + pass + + @webob.dec.wsgify + def __call__(self, req): + res = webob.Response() + res.status = '200' + res.headers['X-Test-Success'] = 'True' + return res + + +def fake_auth_init(self): + self.db = FakeAuthDatabase() + self.context = Context() + self.auth = FakeAuthManager() + self.host = 'foo' + + +@webob.dec.wsgify +def fake_wsgi(self, req): + req.environ['nova.context'] = dict(user=dict(id=1)) + if req.body: + req.environ['inst_dict'] = json.loads(req.body) + return self.application + + +def stub_out_key_pair_funcs(stubs): + def key_pair(context, user_id): + return [dict(name='key', public_key='public_key')] + stubs.Set(nova.db.api, 'key_pair_get_all_by_user', + key_pair) + + +def stub_out_image_service(stubs): + def fake_image_show(meh, id): + return dict(kernelId=1, ramdiskId=1) + + stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) + + +def stub_out_id_translator(stubs): + class FakeTranslator(object): + def __init__(self, id_type, service_name): + pass + + def to_rs_id(self, id): + return id + + def from_rs_id(self, id): + return id + + stubs.Set(nova.api.rackspace._id_translator, + 'RackspaceAPIIdTranslator', FakeTranslator) + + +def stub_out_auth(stubs): + def fake_auth_init(self, app): + self.application = app + + stubs.Set(nova.api.rackspace.AuthMiddleware, + '__init__', fake_auth_init) + stubs.Set(nova.api.rackspace.AuthMiddleware, + '__call__', fake_wsgi) + + +def stub_out_rate_limiting(stubs): + def fake_rate_init(self, app): + super(nova.api.rackspace.RateLimitingMiddleware, self).__init__(app) + self.application = app + + stubs.Set(nova.api.rackspace.RateLimitingMiddleware, + '__init__', fake_rate_init) + + stubs.Set(nova.api.rackspace.RateLimitingMiddleware, + '__call__', fake_wsgi) + + +def stub_out_networking(stubs): + def get_my_ip(): + return '127.0.0.1' + stubs.Set(nova.utils, 'get_my_ip', get_my_ip) + FLAGS.FAKE_subdomain = 'rs' + + +class FakeAuthDatabase(object): + data = {} + + @staticmethod + def auth_get_token(context, token_hash): + return FakeAuthDatabase.data.get(token_hash, None) + + @staticmethod + def auth_create_token(context, token): + token['created_at'] = datetime.datetime.now() + FakeAuthDatabase.data[token['token_hash']] = token + + @staticmethod + def auth_destroy_token(context, token): + if FakeAuthDatabase.data.has_key(token['token_hash']): + del FakeAuthDatabase.data['token_hash'] + + +class FakeAuthManager(object): + auth_data = {} + + def add_user(self, key, user): + FakeAuthManager.auth_data[key] = user + + def get_user(self, uid): + for k, v in FakeAuthManager.auth_data.iteritems(): + if v['uid'] == uid: + return v + return None + + def get_user_from_access_key(self, key): + return FakeAuthManager.auth_data.get(key, None) + + +class FakeRateLimiter(object): + def __init__(self, application): + self.application = application + + @webob.dec.wsgify + def __call__(self, req): + return self.application diff --git a/nova/tests/api/rackspace/flavors.py b/nova/tests/api/rackspace/flavors.py deleted file mode 100644 index d25a2e2be..000000000 --- a/nova/tests/api/rackspace/flavors.py +++ /dev/null @@ -1,46 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import unittest -import stubout - -import nova.api -from nova.api.rackspace import flavors -from nova.tests.api.rackspace import test_helper -from nova.tests.api.test_helper import * - -class FlavorsTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - test_helper.FakeAuthManager.auth_data = {} - test_helper.FakeAuthDatabase.data = {} - test_helper.stub_for_testing(self.stubs) - test_helper.stub_out_rate_limiting(self.stubs) - test_helper.stub_out_auth(self.stubs) - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_flavor_list(self): - req = webob.Request.blank('/v1.0/flavors') - res = req.get_response(nova.api.API()) - - def test_get_flavor_by_id(self): - pass - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/rackspace/images.py b/nova/tests/api/rackspace/images.py deleted file mode 100644 index 4c9987e8b..000000000 --- a/nova/tests/api/rackspace/images.py +++ /dev/null @@ -1,40 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import stubout -import unittest - -from nova.api.rackspace import images -from nova.tests.api.test_helper import * - -class ImagesTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_image_list(self): - pass - - def test_delete_image(self): - pass - - def test_create_image(self): - pass - - diff --git a/nova/tests/api/rackspace/servers.py b/nova/tests/api/rackspace/servers.py deleted file mode 100644 index 69ad2c1d3..000000000 --- a/nova/tests/api/rackspace/servers.py +++ /dev/null @@ -1,245 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import json -import unittest - -import stubout - -from nova import db -from nova import flags -import nova.api.rackspace -from nova.api.rackspace import servers -import nova.db.api -from nova.db.sqlalchemy.models import Instance -import nova.rpc -from nova.tests.api.test_helper import * -from nova.tests.api.rackspace import test_helper - -FLAGS = flags.FLAGS - -def return_server(context, id): - return stub_instance(id) - -def return_servers(context, user_id=1): - return [stub_instance(i, user_id) for i in xrange(5)] - - -def stub_instance(id, user_id=1): - return Instance( - id=id, state=0, image_id=10, server_name='server%s'%id, - user_id=user_id - ) - -class ServersTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - test_helper.FakeAuthManager.auth_data = {} - test_helper.FakeAuthDatabase.data = {} - test_helper.stub_for_testing(self.stubs) - test_helper.stub_out_rate_limiting(self.stubs) - test_helper.stub_out_auth(self.stubs) - test_helper.stub_out_id_translator(self.stubs) - test_helper.stub_out_key_pair_funcs(self.stubs) - test_helper.stub_out_image_service(self.stubs) - self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) - self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server) - self.stubs.Set(nova.db.api, 'instance_get_all_by_user', - return_servers) - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_server_by_id(self): - req = webob.Request.blank('/v1.0/servers/1') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - self.assertEqual(res_dict['server']['id'], '1') - self.assertEqual(res_dict['server']['name'], 'server1') - - def test_get_server_list(self): - req = webob.Request.blank('/v1.0/servers') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - - i = 0 - for s in res_dict['servers']: - self.assertEqual(s['id'], i) - self.assertEqual(s['name'], 'server%d'%i) - self.assertEqual(s.get('imageId', None), None) - i += 1 - - def test_create_instance(self): - def server_update(context, id, params): - pass - - def instance_create(context, inst): - class Foo(object): - ec2_id = 1 - return Foo() - - def fake_method(*args, **kwargs): - pass - - def project_get_network(context, user_id): - return dict(id='1', host='localhost') - - def queue_get_for(context, *args): - return 'network_topic' - - self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) - self.stubs.Set(nova.db.api, 'instance_create', instance_create) - self.stubs.Set(nova.rpc, 'cast', fake_method) - self.stubs.Set(nova.rpc, 'call', fake_method) - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) - self.stubs.Set(nova.network.manager.FlatManager, 'allocate_fixed_ip', - fake_method) - - test_helper.stub_out_id_translator(self.stubs) - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers') - req.method = 'POST' - req.body = json.dumps(body) - - res = req.get_response(nova.api.API()) - - self.assertEqual(res.status_int, 200) - - def test_update_no_body(self): - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status_int, 422) - - def test_update_bad_params(self): - """ Confirm that update is filtering params """ - inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') - self.body = json.dumps(dict(server=inst_dict)) - - def server_update(context, id, params): - self.update_called = True - filtered_dict = dict(name='server_test', admin_pass='bacon') - self.assertEqual(params, filtered_dict) - - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - req.body = self.body - req.get_response(nova.api.API()) - - def test_update_server(self): - inst_dict = dict(name='server_test', adminPass='bacon') - self.body = json.dumps(dict(server=inst_dict)) - - def server_update(context, id, params): - filtered_dict = dict(name='server_test', admin_pass='bacon') - self.assertEqual(params, filtered_dict) - - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - req.body = self.body - req.get_response(nova.api.API()) - - def test_create_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - req.method = 'POST' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_delete_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - req.method = 'DELETE' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_get_server_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_get_all_server_details(self): - req = webob.Request.blank('/v1.0/servers/detail') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - - i = 0 - for s in res_dict['servers']: - self.assertEqual(s['id'], i) - self.assertEqual(s['name'], 'server%d'%i) - self.assertEqual(s['imageId'], 10) - i += 1 - - def test_server_reboot(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_server_rebuild(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_server_resize(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_delete_server_instance(self): - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'DELETE' - - self.server_delete_called = False - def instance_destroy_mock(context, id): - self.server_delete_called = True - - self.stubs.Set(nova.db.api, 'instance_destroy', - instance_destroy_mock) - - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '202 Accepted') - self.assertEqual(self.server_delete_called, True) - -if __name__ == "__main__": - unittest.main() diff --git a/nova/tests/api/rackspace/sharedipgroups.py b/nova/tests/api/rackspace/sharedipgroups.py deleted file mode 100644 index 1906b54f5..000000000 --- a/nova/tests/api/rackspace/sharedipgroups.py +++ /dev/null @@ -1,41 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import stubout -import unittest - -from nova.api.rackspace import sharedipgroups -from nova.tests.api.test_helper import * - -class SharedIpGroupsTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_shared_ip_groups(self): - pass - - def test_create_shared_ip_group(self): - pass - - def test_delete_shared_ip_group(self): - pass - - - diff --git a/nova/tests/api/rackspace/test_auth.py b/nova/tests/api/rackspace/test_auth.py new file mode 100644 index 000000000..374cfe42b --- /dev/null +++ b/nova/tests/api/rackspace/test_auth.py @@ -0,0 +1,108 @@ +import datetime +import unittest + +import stubout +import webob +import webob.dec + +import nova.api +import nova.api.rackspace.auth +from nova import auth +from nova.tests.api.rackspace import fakes + +class Test(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + self.stubs.Set(nova.api.rackspace.auth.BasicApiAuthManager, + '__init__', fakes.fake_auth_init) + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_networking(self.stubs) + + def tearDown(self): + self.stubs.UnsetAll() + fakes.fake_data_store = {} + + def test_authorize_user(self): + f = fakes.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + def test_authorize_token(self): + f = fakes.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-Server-Management-Url'], + "https://foo/v1.0/") + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + token = result.headers['X-Auth-Token'] + self.stubs.Set(nova.api.rackspace, 'APIRouter', + fakes.FakeRouter) + req = webob.Request.blank('/v1.0/fake') + req.headers['X-Auth-Token'] = token + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '200 OK') + self.assertEqual(result.headers['X-Test-Success'], 'True') + + def test_token_expiry(self): + self.destroy_called = False + token_hash = 'bacon' + + def destroy_token_mock(meh, context, token): + self.destroy_called = True + + def bad_token(meh, context, token_hash): + return { 'token_hash':token_hash, + 'created_at':datetime.datetime(1990, 1, 1) } + + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token', + destroy_token_mock) + + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_get_token', + bad_token) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'bacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + self.assertEqual(self.destroy_called, True) + + def test_bad_user(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_no_user(self): + req = webob.Request.blank('/v1.0/') + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_bad_token(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'baconbaconbacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/rackspace/test_faults.py b/nova/tests/api/rackspace/test_faults.py new file mode 100644 index 000000000..b2931bc98 --- /dev/null +++ b/nova/tests/api/rackspace/test_faults.py @@ -0,0 +1,40 @@ +import unittest +import webob +import webob.dec +import webob.exc + +from nova.api.rackspace import faults + +class TestFaults(unittest.TestCase): + + def test_fault_parts(self): + req = webob.Request.blank('/.xml') + f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) + resp = req.get_response(f) + + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['']) + body_without_spaces = ''.join(resp.body.split()) + self.assertTrue('scram' in body_without_spaces) + + def test_retry_header(self): + req = webob.Request.blank('/.xml') + exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', + headers={'Retry-After': 4}) + f = faults.Fault(exc) + resp = req.get_response(f) + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['']) + body_sans_spaces = ''.join(resp.body.split()) + self.assertTrue('sorry' in body_sans_spaces) + self.assertTrue('4' in body_sans_spaces) + self.assertEqual(resp.headers['Retry-After'], 4) + + def test_raise(self): + @webob.dec.wsgify + def raiser(req): + raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) + req = webob.Request.blank('/.xml') + resp = req.get_response(raiser) + self.assertEqual(resp.status_int, 404) + self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/api/rackspace/test_flavors.py b/nova/tests/api/rackspace/test_flavors.py new file mode 100644 index 000000000..affdd2406 --- /dev/null +++ b/nova/tests/api/rackspace/test_flavors.py @@ -0,0 +1,48 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import stubout +import webob + +import nova.api +from nova.api.rackspace import flavors +from nova.tests.api.rackspace import fakes + + +class FlavorsTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_networking(self.stubs) + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_auth(self.stubs) + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_flavor_list(self): + req = webob.Request.blank('/v1.0/flavors') + res = req.get_response(nova.api.API()) + + def test_get_flavor_by_id(self): + pass + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/rackspace/test_helper.py b/nova/tests/api/rackspace/test_helper.py deleted file mode 100644 index 2cf154f63..000000000 --- a/nova/tests/api/rackspace/test_helper.py +++ /dev/null @@ -1,134 +0,0 @@ -import datetime -import json - -import webob -import webob.dec - -from nova import auth -from nova import utils -from nova import flags -import nova.api.rackspace.auth -import nova.api.rackspace._id_translator -from nova.image import service -from nova.wsgi import Router - -FLAGS = flags.FLAGS - -class Context(object): - pass - -class FakeRouter(Router): - def __init__(self): - pass - - @webob.dec.wsgify - def __call__(self, req): - res = webob.Response() - res.status = '200' - res.headers['X-Test-Success'] = 'True' - return res - -def fake_auth_init(self): - self.db = FakeAuthDatabase() - self.context = Context() - self.auth = FakeAuthManager() - self.host = 'foo' - -@webob.dec.wsgify -def fake_wsgi(self, req): - req.environ['nova.context'] = dict(user=dict(id=1)) - if req.body: - req.environ['inst_dict'] = json.loads(req.body) - return self.application - -def stub_out_key_pair_funcs(stubs): - def key_pair(context, user_id): - return [dict(name='key', public_key='public_key')] - stubs.Set(nova.db.api, 'key_pair_get_all_by_user', - key_pair) - -def stub_out_image_service(stubs): - def fake_image_show(meh, id): - return dict(kernelId=1, ramdiskId=1) - - stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) - -def stub_out_id_translator(stubs): - class FakeTranslator(object): - def __init__(self, id_type, service_name): - pass - - def to_rs_id(self, id): - return id - - def from_rs_id(self, id): - return id - - stubs.Set(nova.api.rackspace._id_translator, - 'RackspaceAPIIdTranslator', FakeTranslator) - -def stub_out_auth(stubs): - def fake_auth_init(self, app): - self.application = app - - stubs.Set(nova.api.rackspace.AuthMiddleware, - '__init__', fake_auth_init) - stubs.Set(nova.api.rackspace.AuthMiddleware, - '__call__', fake_wsgi) - -def stub_out_rate_limiting(stubs): - def fake_rate_init(self, app): - super(nova.api.rackspace.RateLimitingMiddleware, self).__init__(app) - self.application = app - - stubs.Set(nova.api.rackspace.RateLimitingMiddleware, - '__init__', fake_rate_init) - - stubs.Set(nova.api.rackspace.RateLimitingMiddleware, - '__call__', fake_wsgi) - -def stub_for_testing(stubs): - def get_my_ip(): - return '127.0.0.1' - stubs.Set(nova.utils, 'get_my_ip', get_my_ip) - FLAGS.FAKE_subdomain = 'rs' - -class FakeAuthDatabase(object): - data = {} - - @staticmethod - def auth_get_token(context, token_hash): - return FakeAuthDatabase.data.get(token_hash, None) - - @staticmethod - def auth_create_token(context, token): - token['created_at'] = datetime.datetime.now() - FakeAuthDatabase.data[token['token_hash']] = token - - @staticmethod - def auth_destroy_token(context, token): - if FakeAuthDatabase.data.has_key(token['token_hash']): - del FakeAuthDatabase.data['token_hash'] - -class FakeAuthManager(object): - auth_data = {} - - def add_user(self, key, user): - FakeAuthManager.auth_data[key] = user - - def get_user(self, uid): - for k, v in FakeAuthManager.auth_data.iteritems(): - if v['uid'] == uid: - return v - return None - - def get_user_from_access_key(self, key): - return FakeAuthManager.auth_data.get(key, None) - -class FakeRateLimiter(object): - def __init__(self, application): - self.application = application - - @webob.dec.wsgify - def __call__(self, req): - return self.application diff --git a/nova/tests/api/rackspace/test_images.py b/nova/tests/api/rackspace/test_images.py new file mode 100644 index 000000000..489e35052 --- /dev/null +++ b/nova/tests/api/rackspace/test_images.py @@ -0,0 +1,39 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import stubout + +from nova.api.rackspace import images + + +class ImagesTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_image_list(self): + pass + + def test_delete_image(self): + pass + + def test_create_image(self): + pass diff --git a/nova/tests/api/rackspace/test_servers.py b/nova/tests/api/rackspace/test_servers.py new file mode 100644 index 000000000..9c1860879 --- /dev/null +++ b/nova/tests/api/rackspace/test_servers.py @@ -0,0 +1,250 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import json +import unittest + +import stubout +import webob + +from nova import db +from nova import flags +import nova.api.rackspace +from nova.api.rackspace import servers +import nova.db.api +from nova.db.sqlalchemy.models import Instance +import nova.rpc +from nova.tests.api.rackspace import fakes + + +FLAGS = flags.FLAGS + + +def return_server(context, id): + return stub_instance(id) + + +def return_servers(context, user_id=1): + return [stub_instance(i, user_id) for i in xrange(5)] + + +def stub_instance(id, user_id=1): + return Instance( + id=id, state=0, image_id=10, server_name='server%s'%id, + user_id=user_id + ) + + +class ServersTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_networking(self.stubs) + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_auth(self.stubs) + fakes.stub_out_id_translator(self.stubs) + fakes.stub_out_key_pair_funcs(self.stubs) + fakes.stub_out_image_service(self.stubs) + self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) + self.stubs.Set(nova.db.api, 'instance_get_by_ec2_id', return_server) + self.stubs.Set(nova.db.api, 'instance_get_all_by_user', + return_servers) + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_server_by_id(self): + req = webob.Request.blank('/v1.0/servers/1') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['server']['id'], '1') + self.assertEqual(res_dict['server']['name'], 'server1') + + def test_get_server_list(self): + req = webob.Request.blank('/v1.0/servers') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s.get('imageId', None), None) + i += 1 + + def test_create_instance(self): + def server_update(context, id, params): + pass + + def instance_create(context, inst): + class Foo(object): + ec2_id = 1 + return Foo() + + def fake_method(*args, **kwargs): + pass + + def project_get_network(context, user_id): + return dict(id='1', host='localhost') + + def queue_get_for(context, *args): + return 'network_topic' + + self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) + self.stubs.Set(nova.db.api, 'instance_create', instance_create) + self.stubs.Set(nova.rpc, 'cast', fake_method) + self.stubs.Set(nova.rpc, 'call', fake_method) + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) + self.stubs.Set(nova.network.manager.FlatManager, 'allocate_fixed_ip', + fake_method) + + fakes.stub_out_id_translator(self.stubs) + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers') + req.method = 'POST' + req.body = json.dumps(body) + + res = req.get_response(nova.api.API()) + + self.assertEqual(res.status_int, 200) + + def test_update_no_body(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status_int, 422) + + def test_update_bad_params(self): + """ Confirm that update is filtering params """ + inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + self.update_called = True + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_update_server(self): + inst_dict = dict(name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_create_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'POST' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_delete_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'DELETE' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_server_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_all_server_details(self): + req = webob.Request.blank('/v1.0/servers/detail') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s['imageId'], 10) + i += 1 + + def test_server_reboot(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_server_rebuild(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_server_resize(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_delete_server_instance(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'DELETE' + + self.server_delete_called = False + def instance_destroy_mock(context, id): + self.server_delete_called = True + + self.stubs.Set(nova.db.api, 'instance_destroy', + instance_destroy_mock) + + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '202 Accepted') + self.assertEqual(self.server_delete_called, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/nova/tests/api/rackspace/test_sharedipgroups.py b/nova/tests/api/rackspace/test_sharedipgroups.py new file mode 100644 index 000000000..31ce967d0 --- /dev/null +++ b/nova/tests/api/rackspace/test_sharedipgroups.py @@ -0,0 +1,39 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import stubout + +from nova.api.rackspace import sharedipgroups + + +class SharedIpGroupsTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_shared_ip_groups(self): + pass + + def test_create_shared_ip_group(self): + pass + + def test_delete_shared_ip_group(self): + pass diff --git a/nova/tests/api/rackspace/testfaults.py b/nova/tests/api/rackspace/testfaults.py deleted file mode 100644 index b2931bc98..000000000 --- a/nova/tests/api/rackspace/testfaults.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -import webob -import webob.dec -import webob.exc - -from nova.api.rackspace import faults - -class TestFaults(unittest.TestCase): - - def test_fault_parts(self): - req = webob.Request.blank('/.xml') - f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) - resp = req.get_response(f) - - first_two_words = resp.body.strip().split()[:2] - self.assertEqual(first_two_words, ['']) - body_without_spaces = ''.join(resp.body.split()) - self.assertTrue('scram' in body_without_spaces) - - def test_retry_header(self): - req = webob.Request.blank('/.xml') - exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', - headers={'Retry-After': 4}) - f = faults.Fault(exc) - resp = req.get_response(f) - first_two_words = resp.body.strip().split()[:2] - self.assertEqual(first_two_words, ['']) - body_sans_spaces = ''.join(resp.body.split()) - self.assertTrue('sorry' in body_sans_spaces) - self.assertTrue('4' in body_sans_spaces) - self.assertEqual(resp.headers['Retry-After'], 4) - - def test_raise(self): - @webob.dec.wsgify - def raiser(req): - raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) - req = webob.Request.blank('/.xml') - resp = req.get_response(raiser) - self.assertEqual(resp.status_int, 404) - self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/api/test_helper.py b/nova/tests/api/test_helper.py deleted file mode 100644 index d0a2cc027..000000000 --- a/nova/tests/api/test_helper.py +++ /dev/null @@ -1,8 +0,0 @@ -import webob.dec -from nova import wsgi - -class APIStub(object): - """Class to verify request and mark it was called.""" - @webob.dec.wsgify - def __call__(self, req): - return req.path_info diff --git a/nova/tests/api/test_wsgi.py b/nova/tests/api/test_wsgi.py new file mode 100644 index 000000000..9425b01d0 --- /dev/null +++ b/nova/tests/api/test_wsgi.py @@ -0,0 +1,147 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Test WSGI basics and provide some helper functions for other WSGI tests. +""" + +import unittest + +import routes +import webob + +from nova import wsgi + + +class Test(unittest.TestCase): + + def test_debug(self): + + class Application(wsgi.Application): + """Dummy application to test debug.""" + + def __call__(self, environ, start_response): + start_response("200", [("X-Test", "checking")]) + return ['Test result'] + + application = wsgi.Debug(Application()) + result = webob.Request.blank('/').get_response(application) + self.assertEqual(result.body, "Test result") + + def test_router(self): + + class Application(wsgi.Application): + """Test application to call from router.""" + + def __call__(self, environ, start_response): + start_response("200", []) + return ['Router result'] + + class Router(wsgi.Router): + """Test router.""" + + def __init__(self): + mapper = routes.Mapper() + mapper.connect("/test", controller=Application()) + super(Router, self).__init__(mapper) + + result = webob.Request.blank('/test').get_response(Router()) + self.assertEqual(result.body, "Router result") + result = webob.Request.blank('/bad').get_response(Router()) + self.assertNotEqual(result.body, "Router result") + + def test_controller(self): + + class Controller(wsgi.Controller): + """Test controller to call from router.""" + test = self + + def show(self, req, id): # pylint: disable-msg=W0622,C0103 + """Default action called for requests with an ID.""" + self.test.assertEqual(req.path_info, '/tests/123') + self.test.assertEqual(id, '123') + return id + + class Router(wsgi.Router): + """Test router.""" + + def __init__(self): + mapper = routes.Mapper() + mapper.resource("test", "tests", controller=Controller()) + super(Router, self).__init__(mapper) + + result = webob.Request.blank('/tests/123').get_response(Router()) + self.assertEqual(result.body, "123") + result = webob.Request.blank('/test/123').get_response(Router()) + self.assertNotEqual(result.body, "123") + + +class SerializerTest(unittest.TestCase): + + def match(self, url, accept, expect): + input_dict = dict(servers=dict(a=(2,3))) + expected_xml = '(2,3)' + expected_json = '{"servers":{"a":[2,3]}}' + req = webob.Request.blank(url, headers=dict(Accept=accept)) + result = wsgi.Serializer(req.environ).to_content_type(input_dict) + result = result.replace('\n', '').replace(' ', '') + if expect == 'xml': + self.assertEqual(result, expected_xml) + elif expect == 'json': + self.assertEqual(result, expected_json) + else: + raise "Bad expect value" + + def test_basic(self): + self.match('/servers/4.json', None, expect='json') + self.match('/servers/4', 'application/json', expect='json') + self.match('/servers/4', 'application/xml', expect='xml') + self.match('/servers/4.xml', None, expect='xml') + + def test_defaults_to_json(self): + self.match('/servers/4', None, expect='json') + self.match('/servers/4', 'text/html', expect='json') + + def test_suffix_takes_precedence_over_accept_header(self): + self.match('/servers/4.xml', 'application/json', expect='xml') + self.match('/servers/4.xml.', 'application/json', expect='json') + + def test_deserialize(self): + xml = """ + + 123 + 1 + 1 + + """.strip() + as_dict = dict(a={ + 'a1': '1', + 'a2': '2', + 'bs': ['1', '2', '3', {'c': dict(c1='1')}], + 'd': {'e': '1'}, + 'f': '1'}) + metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})} + serializer = wsgi.Serializer({}, metadata) + self.assertEqual(serializer.deserialize(xml), as_dict) + + def test_deserialize_empty_xml(self): + xml = """""" + as_dict = {"a": {}} + serializer = wsgi.Serializer({}) + self.assertEqual(serializer.deserialize(xml), as_dict) diff --git a/nova/tests/api/wsgi_test.py b/nova/tests/api/wsgi_test.py deleted file mode 100644 index 9425b01d0..000000000 --- a/nova/tests/api/wsgi_test.py +++ /dev/null @@ -1,147 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Test WSGI basics and provide some helper functions for other WSGI tests. -""" - -import unittest - -import routes -import webob - -from nova import wsgi - - -class Test(unittest.TestCase): - - def test_debug(self): - - class Application(wsgi.Application): - """Dummy application to test debug.""" - - def __call__(self, environ, start_response): - start_response("200", [("X-Test", "checking")]) - return ['Test result'] - - application = wsgi.Debug(Application()) - result = webob.Request.blank('/').get_response(application) - self.assertEqual(result.body, "Test result") - - def test_router(self): - - class Application(wsgi.Application): - """Test application to call from router.""" - - def __call__(self, environ, start_response): - start_response("200", []) - return ['Router result'] - - class Router(wsgi.Router): - """Test router.""" - - def __init__(self): - mapper = routes.Mapper() - mapper.connect("/test", controller=Application()) - super(Router, self).__init__(mapper) - - result = webob.Request.blank('/test').get_response(Router()) - self.assertEqual(result.body, "Router result") - result = webob.Request.blank('/bad').get_response(Router()) - self.assertNotEqual(result.body, "Router result") - - def test_controller(self): - - class Controller(wsgi.Controller): - """Test controller to call from router.""" - test = self - - def show(self, req, id): # pylint: disable-msg=W0622,C0103 - """Default action called for requests with an ID.""" - self.test.assertEqual(req.path_info, '/tests/123') - self.test.assertEqual(id, '123') - return id - - class Router(wsgi.Router): - """Test router.""" - - def __init__(self): - mapper = routes.Mapper() - mapper.resource("test", "tests", controller=Controller()) - super(Router, self).__init__(mapper) - - result = webob.Request.blank('/tests/123').get_response(Router()) - self.assertEqual(result.body, "123") - result = webob.Request.blank('/test/123').get_response(Router()) - self.assertNotEqual(result.body, "123") - - -class SerializerTest(unittest.TestCase): - - def match(self, url, accept, expect): - input_dict = dict(servers=dict(a=(2,3))) - expected_xml = '(2,3)' - expected_json = '{"servers":{"a":[2,3]}}' - req = webob.Request.blank(url, headers=dict(Accept=accept)) - result = wsgi.Serializer(req.environ).to_content_type(input_dict) - result = result.replace('\n', '').replace(' ', '') - if expect == 'xml': - self.assertEqual(result, expected_xml) - elif expect == 'json': - self.assertEqual(result, expected_json) - else: - raise "Bad expect value" - - def test_basic(self): - self.match('/servers/4.json', None, expect='json') - self.match('/servers/4', 'application/json', expect='json') - self.match('/servers/4', 'application/xml', expect='xml') - self.match('/servers/4.xml', None, expect='xml') - - def test_defaults_to_json(self): - self.match('/servers/4', None, expect='json') - self.match('/servers/4', 'text/html', expect='json') - - def test_suffix_takes_precedence_over_accept_header(self): - self.match('/servers/4.xml', 'application/json', expect='xml') - self.match('/servers/4.xml.', 'application/json', expect='json') - - def test_deserialize(self): - xml = """ - - 123 - 1 - 1 - - """.strip() - as_dict = dict(a={ - 'a1': '1', - 'a2': '2', - 'bs': ['1', '2', '3', {'c': dict(c1='1')}], - 'd': {'e': '1'}, - 'f': '1'}) - metadata = {'application/xml': dict(plurals={'bs': 'b', 'ts': 't'})} - serializer = wsgi.Serializer({}, metadata) - self.assertEqual(serializer.deserialize(xml), as_dict) - - def test_deserialize_empty_xml(self): - xml = """""" - as_dict = {"a": {}} - serializer = wsgi.Serializer({}) - self.assertEqual(serializer.deserialize(xml), as_dict) diff --git a/tools/pip-requires b/tools/pip-requires index 1e2707be7..6c3940372 100644 --- a/tools/pip-requires +++ b/tools/pip-requires @@ -20,3 +20,4 @@ zope.interface==3.6.1 mox==0.5.0 -f http://pymox.googlecode.com/files/mox-0.5.0.tar.gz greenlet==0.3.1 +nose -- cgit From 0ef621d47eeea421820a2191de53dee9e83d8c44 Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Fri, 1 Oct 2010 16:06:14 -0400 Subject: Adds BaseImageService and flag to control image service loading. Adds unit test for local image service. --- nova/api/rackspace/images.py | 7 +- nova/flags.py | 4 ++ nova/image/service.py | 114 ++++++++++++++++++++++++++++---- nova/tests/api/rackspace/test_images.py | 102 +++++++++++++++++++++++++--- 4 files changed, 206 insertions(+), 21 deletions(-) diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py index 4a7dd489c..d4ab8ce3c 100644 --- a/nova/api/rackspace/images.py +++ b/nova/api/rackspace/images.py @@ -17,12 +17,17 @@ from webob import exc +from nova import flags +from nova import utils from nova import wsgi from nova.api.rackspace import _id_translator import nova.api.rackspace import nova.image.service from nova.api.rackspace import faults + +FLAGS = flags.FLAGS + class Controller(wsgi.Controller): _serialization_metadata = { @@ -35,7 +40,7 @@ class Controller(wsgi.Controller): } def __init__(self): - self._service = nova.image.service.ImageService.load() + self._service = utils.import_object(FLAGS.image_service) self._id_translator = _id_translator.RackspaceAPIIdTranslator( "image", self._service.__class__.__name__) diff --git a/nova/flags.py b/nova/flags.py index c32cdd7a4..ab80e83fb 100644 --- a/nova/flags.py +++ b/nova/flags.py @@ -222,6 +222,10 @@ DEFINE_string('volume_manager', 'nova.volume.manager.AOEManager', DEFINE_string('scheduler_manager', 'nova.scheduler.manager.SchedulerManager', 'Manager for scheduler') +# The service to use for image search and retrieval +DEFINE_string('image_service', 'nova.image.service.LocalImageService', + 'The service to use for retrieving and searching for images.') + DEFINE_string('host', socket.gethostname(), 'name of this node') diff --git a/nova/image/service.py b/nova/image/service.py index 1a7a258b7..4bceab6ee 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -20,34 +20,117 @@ import os.path import random import string -class ImageService(object): - """Provides storage and retrieval of disk image objects.""" +from nova import utils +from nova import flags - @staticmethod - def load(): - """Factory method to return image service.""" - #TODO(gundlach): read from config. - class_ = LocalImageService - return class_() + +FLAGS = flags.FLAGS + + +flags.DEFINE_string('glance_teller_address', '127.0.0.1', + 'IP address or URL where Glance\'s Teller service resides') +flags.DEFINE_string('glance_teller_port', '9191', + 'Port for Glance\'s Teller service') +flags.DEFINE_string('glance_parallax_address', '127.0.0.1', + 'IP address or URL where Glance\'s Parallax service resides') +flags.DEFINE_string('glance_parallax_port', '9191', + 'Port for Glance\'s Parallax service') + + +class BaseImageService(object): + + """Base class for providing image search and retrieval services""" def index(self): """ Return a dict from opaque image id to image data. """ + raise NotImplementedError def show(self, id): """ Returns a dict containing image data for the given opaque image id. """ + raise NotImplementedError + + def create(self, data): + """ + Store the image data and return the new image id. + + :raises AlreadyExists if the image already exist. + """ + raise NotImplementedError + + def update(self, image_id, data): + """Replace the contents of the given image with the new data. -class GlanceImageService(ImageService): + :raises NotFound if the image does not exist. + + """ + raise NotImplementedError + + def delete(self, image_id): + """ + Delete the given image. + + :raises NotFound if the image does not exist. + + """ + raise NotImplementedError + + +class GlanceImageService(BaseImageService): + """Provides storage and retrieval of disk image objects within Glance.""" - # TODO(gundlach): once Glance has an API, build this. - pass + def index(self): + """ + Calls out to Parallax for a list of images available + """ + raise NotImplementedError + + def show(self, id): + """ + Returns a dict containing image data for the given opaque image id. + """ + raise NotImplementedError + + def create(self, data): + """ + Store the image data and return the new image id. + + :raises AlreadyExists if the image already exist. + + """ + raise NotImplementedError + + def update(self, image_id, data): + """Replace the contents of the given image with the new data. + + :raises NotFound if the image does not exist. + + """ + raise NotImplementedError + + def delete(self, image_id): + """ + Delete the given image. + + :raises NotFound if the image does not exist. + + """ + raise NotImplementedError + + def delete_all(self): + """ + Clears out all images + """ + pass + + +class LocalImageService(BaseImageService): -class LocalImageService(ImageService): """Image service storing images to local disk.""" def __init__(self): @@ -88,3 +171,10 @@ class LocalImageService(ImageService): Delete the given image. Raises OSError if the image does not exist. """ os.unlink(self._path_to(image_id)) + + def delete_all(self): + """ + Clears out all images in local directory + """ + for f in os.listdir(self._path): + os.unlink(self._path_to(f)) diff --git a/nova/tests/api/rackspace/test_images.py b/nova/tests/api/rackspace/test_images.py index 489e35052..21dad7648 100644 --- a/nova/tests/api/rackspace/test_images.py +++ b/nova/tests/api/rackspace/test_images.py @@ -15,25 +15,111 @@ # License for the specific language governing permissions and limitations # under the License. +import logging import unittest import stubout +from nova import utils from nova.api.rackspace import images -class ImagesTest(unittest.TestCase): +#{ Fixtures + + +fixture_images = [ + { + 'name': 'image #1', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}, + { + 'name': 'image #2', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}, + { + 'name': 'image #3', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}] + + +#} + + +class BaseImageServiceTests(): + + """Tasks to test for all image services""" + + def test_create_and_index(self): + for i in fixture_images: + self.service.create(i) + + self.assertEquals(len(fixture_images), len(self.service.index())) + + def test_create_and_update(self): + ids = {} + temp = 0 + for i in fixture_images: + ids[self.service.create(i)] = temp + temp += 1 + + self.assertEquals(len(fixture_images), len(self.service.index())) + + for image_id, num in ids.iteritems(): + new_data = fixture_images[num] + new_data['updated'] = 'test' + str(num) + self.service.update(image_id, new_data) + + images = self.service.index() + + for i in images: + self.assertEquals('test' + str(ids[i['id']]), + i['updated']) + + def test_create_and_show(self): + ids = {} + temp = 0 + for i in fixture_images: + ids[self.service.create(i)] = temp + temp += 1 + + for i in fixture_images: + image = self.service.show(i['id']) + index = ids[i['id']] + self.assertEquals(image, fixture_images[index]) + + +class LocalImageServiceTest(unittest.TestCase, + BaseImageServiceTests): + + """Tests the local image service""" + def setUp(self): self.stubs = stubout.StubOutForTesting() + self.service = utils.import_object('nova.image.service.LocalImageService') def tearDown(self): + self.service.delete_all() self.stubs.UnsetAll() - def test_get_image_list(self): - pass - def test_delete_image(self): - pass - - def test_create_image(self): - pass +#class GlanceImageServiceTest(unittest.TestCase, +# BaseImageServiceTests): +# +# """Tests the local image service""" +# +# def setUp(self): +# self.stubs = stubout.StubOutForTesting() +# self.service = utils.import_object('nova.image.service.GlanceImageService') +# +# def tearDown(self): +# self.service.delete_all() +# self.stubs.UnsetAll() -- cgit From 4b11351aba7e71154f82a6a76590c786b4d7a53a Mon Sep 17 00:00:00 2001 From: Todd Willey Date: Fri, 1 Oct 2010 21:46:36 -0400 Subject: Keep handles to loggers open after daemonizing. --- nova/server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nova/server.py b/nova/server.py index d4563bfe0..c58a15041 100644 --- a/nova/server.py +++ b/nova/server.py @@ -106,6 +106,7 @@ def serve(name, main): def daemonize(args, name, main): """Does the work of daemonizing the process""" logging.getLogger('amqplib').setLevel(logging.WARN) + files_to_keep = [] if FLAGS.daemonize: logger = logging.getLogger() formatter = logging.Formatter( @@ -114,12 +115,14 @@ def daemonize(args, name, main): syslog = logging.handlers.SysLogHandler(address='/dev/log') syslog.setFormatter(formatter) logger.addHandler(syslog) + files_to_keep.append(syslog.socket) else: if not FLAGS.logfile: FLAGS.logfile = '%s.log' % name logfile = logging.FileHandler(FLAGS.logfile) logfile.setFormatter(formatter) logger.addHandler(logfile) + files_to_keep.append(logfile.stream) stdin, stdout, stderr = None, None, None else: stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr @@ -139,6 +142,7 @@ def daemonize(args, name, main): stdout=stdout, stderr=stderr, uid=FLAGS.uid, - gid=FLAGS.gid + gid=FLAGS.gid, + files_preserve=files_to_keep ): main(args) -- cgit From 033c464882c3d74ecd863abde767f37e7ad6a956 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Sat, 2 Oct 2010 12:39:47 +0200 Subject: Make _dhcp_file ensure the existence of the directory containing the files it returns. --- nova/network/linux_net.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nova/network/linux_net.py b/nova/network/linux_net.py index 95f7fe2d0..37f9c8253 100644 --- a/nova/network/linux_net.py +++ b/nova/network/linux_net.py @@ -274,6 +274,9 @@ def _stop_dnsmasq(network): def _dhcp_file(vlan, kind): """Return path to a pid, leases or conf file for a vlan""" + if not os.path.exists(FLAGS.networks_path): + os.makedirs(FLAGS.networks_path) + return os.path.abspath("%s/nova-%s.%s" % (FLAGS.networks_path, vlan, kind)) -- cgit From 50fc372c1f4b5924b73de5c25100ce42166c4f12 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Sat, 2 Oct 2010 12:56:54 +0200 Subject: Adjust db api usage according to recent refactoring. --- nova/db/sqlalchemy/api.py | 89 +++++++++++++++++++++++++++++------------------ 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index b70e7dd4a..49d015716 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -1278,18 +1278,39 @@ def volume_update(context, volume_id, values): ################### -def user_get(context, id): - return models.User.find(id, deleted=_deleted(context)) +@require_admin_context +def user_get(context, id, session=None): + if not session: + session = get_session() + + result = session.query(models.User + ).filter_by(id=id + ).filter_by(deleted=can_read_deleted(context) + ).first() + if not result: + raise exception.NotFound('No user for id %s' % id) -def user_get_by_access_key(context, access_key): - session = get_session() - return session.query(models.User + return result + + +@require_admin_context +def user_get_by_access_key(context, access_key, session=None): + if not session: + session = get_session() + + result = session.query(models.User ).filter_by(access_key=access_key - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).first() + if not result: + raise exception.NotFound('No user for id %s' % id) + + return result + +@require_admin_context def user_create(_context, values): user_ref = models.User() for (key, value) in values.iteritems(): @@ -1298,6 +1319,7 @@ def user_create(_context, values): return user_ref +@require_admin_context def user_delete(context, id): session = get_session() with session.begin(): @@ -1307,14 +1329,14 @@ def user_delete(context, id): {'id': id}) session.execute('delete from user_project_role_association where user_id=:id', {'id': id}) - user_ref = models.User.find(id, session=session) + user_ref = user_get(context, id, session=session) session.delete(user_ref) def user_get_all(context): session = get_session() return session.query(models.User - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1329,29 +1351,33 @@ def project_create(_context, values): def project_add_member(context, project_id, user_id): session = get_session() with session.begin(): - project_ref = models.Project.find(project_id, session=session) - user_ref = models.User.find(user_id, session=session) + project_ref = project_get(context, project_id, session=session) + user_ref = user_get(context, user_id, session=session) project_ref.members += [user_ref] project_ref.save(session=session) -def project_get(context, id): - session = get_session() +def project_get(context, id, session=None): + if not session: + session = get_session() + result = session.query(models.Project ).filter_by(deleted=False ).filter_by(id=id ).options(joinedload_all('members') ).first() + if not result: raise exception.NotFound("No project with id %s" % id) + return result def project_get_all(context): session = get_session() return session.query(models.Project - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload_all('members') ).all() @@ -1359,7 +1385,7 @@ def project_get_all(context): def project_get_by_user(context, user_id): session = get_session() user = session.query(models.User - ).filter_by(deleted=_deleted(context) + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload_all('projects') ).first() return user.projects @@ -1367,32 +1393,27 @@ def project_get_by_user(context, user_id): def project_remove_member(context, project_id, user_id): session = get_session() - project = models.Project.find(project_id, session=session) - user = models.User.find(user_id, session=session) - if not project: - raise exception.NotFound('Project id "%s" not found' % (project_id,)) - - if not user: - raise exception.NotFound('User id "%s" not found' % (user_id,)) + project = project_get(context, project_id, session=session) + user = user_get(context, user_id, session=session) if user in project.members: project.members.remove(user) project.save(session=session) -def user_update(_context, user_id, values): +def user_update(context, user_id, values): session = get_session() with session.begin(): - user_ref = models.User.find(user_id, session=session) + user_ref = user_get(context, user_id, session=session) for (key, value) in values.iteritems(): user_ref[key] = value user_ref.save(session=session) -def project_update(_context, project_id, values): +def project_update(context, project_id, values): session = get_session() with session.begin(): - project_ref = models.Project.find(project_id, session=session) + project_ref = project_get(context, project_id, session=session) for (key, value) in values.iteritems(): project_ref[key] = value project_ref.save(session=session) @@ -1405,17 +1426,17 @@ def project_delete(context, id): {'id': id}) session.execute('delete from user_project_role_association where project_id=:id', {'id': id}) - project_ref = models.Project.find(id, session=session) + project_ref = project_get(context, id, session=session) session.delete(project_ref) def user_get_roles(context, user_id): session = get_session() with session.begin(): - user_ref = models.User.find(user_id, session=session) + user_ref = user_get(context, user_id, session=session) return [role.role for role in user_ref['roles']] - + def user_get_roles_for_project(context, user_id, project_id): session = get_session() with session.begin(): @@ -1434,7 +1455,7 @@ def user_remove_project_role(context, user_id, project_id, role): 'project_id' : project_id, 'role' : role }) - + def user_remove_role(context, user_id, role): session = get_session() with session.begin(): @@ -1449,18 +1470,18 @@ def user_remove_role(context, user_id, role): def user_add_role(context, user_id, role): session = get_session() with session.begin(): - user_ref = models.User.find(user_id, session=session) + user_ref = user_get(context, user_id, session=session) models.UserRoleAssociation(user=user_ref, role=role).save(session=session) - + def user_add_project_role(context, user_id, project_id, role): session = get_session() with session.begin(): - user_ref = models.User.find(user_id, session=session) - project_ref = models.Project.find(project_id, session=session) + user_ref = user_get(context, user_id, session=session) + project_ref = project_get(context, project_id, session=session) models.UserProjectRoleAssociation(user_id=user_ref['id'], project_id=project_ref['id'], role=role).save(session=session) - + ################### -- cgit From 5945291281f239bd928cea1833ee5a5b6c3df523 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sat, 2 Oct 2010 12:42:09 +0100 Subject: Bug #653534: NameError on session_get in sqlalchemy.api.service_update Fix function call: session_get was meant to be service_get. --- nova/db/sqlalchemy/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 7f72f66b9..9a7c71a70 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -240,7 +240,7 @@ def service_create(context, values): def service_update(context, service_id, values): session = get_session() with session.begin(): - service_ref = session_get(context, service_id, session=session) + service_ref = service_get(context, service_id, session=session) for (key, value) in values.iteritems(): service_ref[key] = value service_ref.save(session=session) -- cgit From c66d550d208544799fdaf4646a846e9f9c0b6bc5 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sat, 2 Oct 2010 13:11:33 +0100 Subject: Bug #653560: AttributeError in VlanManager.periodic_tasks Pass the correct context to db.fixed_ip_disassociate_all_by_timeout. --- nova/network/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/network/manager.py b/nova/network/manager.py index ef1d01138..9580479e5 100644 --- a/nova/network/manager.py +++ b/nova/network/manager.py @@ -230,7 +230,7 @@ class VlanManager(NetworkManager): now = datetime.datetime.utcnow() timeout = FLAGS.fixed_ip_disassociate_timeout time = now - datetime.timedelta(seconds=timeout) - num = self.db.fixed_ip_disassociate_all_by_timeout(self, + num = self.db.fixed_ip_disassociate_all_by_timeout(context, self.host, time) if num: -- cgit From 12e43d9deb3984d2b7ccc91490ffa4c13eedbe2b Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sat, 2 Oct 2010 16:55:57 +0100 Subject: Bug #653651: XenAPI support completely broken by orm-refactor merge Matches changes in the database / model layer with corresponding fixes to nova.virt.xenapi. --- nova/virt/xenapi.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/nova/virt/xenapi.py b/nova/virt/xenapi.py index 0d06b1fce..118e0b687 100644 --- a/nova/virt/xenapi.py +++ b/nova/virt/xenapi.py @@ -42,10 +42,12 @@ from twisted.internet import defer from twisted.internet import reactor from twisted.internet import task +from nova import db from nova import flags from nova import process from nova import utils from nova.auth.manager import AuthManager +from nova.compute import instance_types from nova.compute import power_state from nova.virt import images @@ -113,32 +115,24 @@ class XenAPIConnection(object): raise Exception('Attempted to create non-unique name %s' % instance.name) - if 'bridge_name' in instance.datamodel: - network_ref = \ - yield self._find_network_with_bridge( - instance.datamodel['bridge_name']) - else: - network_ref = None - - if 'mac_address' in instance.datamodel: - mac_address = instance.datamodel['mac_address'] - else: - mac_address = '' + network = db.project_get_network(None, instance.project_id) + network_ref = \ + yield self._find_network_with_bridge(network.bridge) - user = AuthManager().get_user(instance.datamodel['user_id']) - project = AuthManager().get_project(instance.datamodel['project_id']) + user = AuthManager().get_user(instance.user_id) + project = AuthManager().get_project(instance.project_id) vdi_uuid = yield self._fetch_image( - instance.datamodel['image_id'], user, project, True) + instance.image_id, user, project, True) kernel = yield self._fetch_image( - instance.datamodel['kernel_id'], user, project, False) + instance.kernel_id, user, project, False) ramdisk = yield self._fetch_image( - instance.datamodel['ramdisk_id'], user, project, False) + instance.ramdisk_id, user, project, False) vdi_ref = yield self._call_xenapi('VDI.get_by_uuid', vdi_uuid) vm_ref = yield self._create_vm(instance, kernel, ramdisk) yield self._create_vbd(vm_ref, vdi_ref, 0, True) if network_ref: - yield self._create_vif(vm_ref, network_ref, mac_address) + yield self._create_vif(vm_ref, network_ref, instance.mac_address) logging.debug('Starting VM %s...', vm_ref) yield self._call_xenapi('VM.start', vm_ref, False, False) logging.info('Spawning VM %s created %s.', instance.name, vm_ref) @@ -148,8 +142,9 @@ class XenAPIConnection(object): """Create a VM record. Returns a Deferred that gives the new VM reference.""" - mem = str(long(instance.datamodel['memory_kb']) * 1024) - vcpus = str(instance.datamodel['vcpus']) + instance_type = instance_types.INSTANCE_TYPES[instance.instance_type] + mem = str(long(instance_type['memory_mb']) * 1024 * 1024) + vcpus = str(instance_type['vcpus']) rec = { 'name_label': instance.name, 'name_description': '', -- cgit From 48ff601a3ab2d72275061135cac56557042e8e9d Mon Sep 17 00:00:00 2001 From: Vishvananda Ishaya Date: Sat, 2 Oct 2010 12:46:12 -0700 Subject: fix typo in setup_compute_network --- nova/compute/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index f370ede8b..4c6d2f06f 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -71,7 +71,7 @@ class ComputeManager(manager.Manager): raise exception.Error("Instance has already been created") logging.debug("instance %s: starting...", instance_id) project_id = instance_ref['project_id'] - self.network_manager.setup_compute_network(context, project_id) + self.network_manager.setup_compute_network(context, instance_id) self.db.instance_update(context, instance_id, {'host': self.host}) -- cgit From 65e2bbc31a7e4ea5d8f9456c2ea5b54715305d11 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sun, 3 Oct 2010 12:41:07 +0100 Subject: Bug #654023: nova-manage vpn commands broken, resulting in erroneous "Wrong number of arguments supplied" message Add a context of None to the call to db.instance_get_all. This is deprecated, but it's what all the other calls in this file do, and it's better than exploding, so it will do for now. --- bin/nova-manage | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/nova-manage b/bin/nova-manage index bf3c67612..0b5869dfd 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -114,7 +114,7 @@ class VpnCommands(object): def _vpn_for(self, project_id): """Get the VPN instance for a project ID.""" - for instance in db.instance_get_all(): + for instance in db.instance_get_all(None): if (instance['image_id'] == FLAGS.vpn_image_id and not instance['state_description'] in ['shutting_down', 'shutdown'] -- cgit From 4fa2258af9fb130be1650372cf48be39e83451e5 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sun, 3 Oct 2010 13:12:32 +0100 Subject: Bug #654025: nova-manage project zip and nova-manage vpn list broken by change in DB semantics when networks are missing Catch exception.NotFound when getting project VPN data. This is in two places: nova-manage as part of its vpn list command, and auth.manager.AuthManager.get_credentials. Also, document the behaviour of db.api.project_get_network. --- bin/nova-manage | 9 +++++++-- nova/auth/manager.py | 5 ++++- nova/db/api.py | 6 +++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/bin/nova-manage b/bin/nova-manage index bf3c67612..5293fc942 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -88,11 +88,16 @@ class VpnCommands(object): def list(self): """Print a listing of the VPNs for all projects.""" print "%-12s\t" % 'project', - print "%-12s\t" % 'ip:port', + print "%-20s\t" % 'ip:port', print "%s" % 'state' for project in self.manager.get_projects(): print "%-12s\t" % project.name, - print "%s:%s\t" % (project.vpn_ip, project.vpn_port), + + try: + s = "%s:%s" % (project.vpn_ip, project.vpn_port) + except exception.NotFound: + s = "None" + print "%-20s\t" % s, vpn = self._vpn_for(project.id) if vpn: diff --git a/nova/auth/manager.py b/nova/auth/manager.py index 0bc12c80f..c30192e20 100644 --- a/nova/auth/manager.py +++ b/nova/auth/manager.py @@ -653,7 +653,10 @@ class AuthManager(object): zippy.writestr(FLAGS.credential_key_file, private_key) zippy.writestr(FLAGS.credential_cert_file, signed_cert) - (vpn_ip, vpn_port) = self.get_project_vpn_data(project) + try: + (vpn_ip, vpn_port) = self.get_project_vpn_data(project) + except exception.NotFound: + vpn_ip = None if vpn_ip: configfile = open(FLAGS.vpn_client_template, "r") s = string.Template(configfile.read()) diff --git a/nova/db/api.py b/nova/db/api.py index 5c935b561..d34f1b2cb 100644 --- a/nova/db/api.py +++ b/nova/db/api.py @@ -432,7 +432,11 @@ def network_update(context, network_id, values): def project_get_network(context, project_id): - """Return the network associated with the project.""" + """Return the network associated with the project. + + Raises NotFound if no such network can be found. + + """ return IMPL.project_get_network(context, project_id) -- cgit From a0498717e470eb6fd52a4f26101c3513d90a3974 Mon Sep 17 00:00:00 2001 From: Ewan Mellor Date: Sun, 3 Oct 2010 13:17:20 +0100 Subject: Bug #654034: nova-manage doesn't honour --verbose flag Honour the --verbose flag by setting the logging level to DEBUG. --- bin/nova-manage | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bin/nova-manage b/bin/nova-manage index bf3c67612..ce87b9437 100755 --- a/bin/nova-manage +++ b/bin/nova-manage @@ -52,6 +52,7 @@ CLI interface for nova management. """ +import logging import os import sys import time @@ -417,6 +418,10 @@ def main(): """Parse options and call the appropriate class/method.""" utils.default_flagfile('/etc/nova/nova-manage.conf') argv = FLAGS(sys.argv) + + if FLAGS.verbose: + logging.getLogger().setLevel(logging.DEBUG) + script_name = argv.pop(0) if len(argv) < 1: print script_name + " category action []" -- cgit From 4e45f9472a95207153d32c88df8396c633c67a5d Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Sun, 3 Oct 2010 20:22:35 +0200 Subject: s/APIRequestContext/get_admin_context/ <-- sudo for request contexts. --- nova/tests/network_unittest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/tests/network_unittest.py b/nova/tests/network_unittest.py index 5370966d2..59b0a36e4 100644 --- a/nova/tests/network_unittest.py +++ b/nova/tests/network_unittest.py @@ -56,8 +56,8 @@ class NetworkTestCase(test.TrialTestCase): 'netuser', name)) # create the necessary network data for the project - user_context = context.APIRequestContext(project=self.projects[i], - user=self.user) + user_context = context.get_admin_context(user=self.user) + self.network.set_network_host(user_context, self.projects[i].id) instance_ref = self._create_instance(0) self.instance_id = instance_ref['id'] -- cgit From 077c008546123291dbc89ac31b492df6d176e339 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 4 Oct 2010 11:53:27 +0200 Subject: Move manager_class instantiation and db.service_* calls out of nova.service.Service.__init__ into a new nova.service.Service.startService method which gets called by twisted. This delays opening db connections (and thus sqlite file creation) until after privileges have been shed by twisted. --- nova/service.py | 12 +++++++++--- nova/tests/scheduler_unittest.py | 10 ++++++++++ nova/tests/service_unittest.py | 3 +++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/nova/service.py b/nova/service.py index a6c186896..dadef3c48 100644 --- a/nova/service.py +++ b/nova/service.py @@ -52,11 +52,17 @@ class Service(object, service.Service): self.host = host self.binary = binary self.topic = topic - manager_class = utils.import_class(manager) - self.manager = manager_class(host=host, *args, **kwargs) + self.manager_class_name = manager + super(Service, self).__init__(*args, **kwargs) + self.saved_args, self.saved_kwargs = args, kwargs + + + def startService(self): + manager_class = utils.import_class(self.manager_class_name) + self.manager = manager_class(host=self.host, *self.saved_args, + **self.saved_kwargs) self.manager.init_host() self.model_disconnected = False - super(Service, self).__init__(*args, **kwargs) try: service_ref = db.service_get_by_args(None, self.host, diff --git a/nova/tests/scheduler_unittest.py b/nova/tests/scheduler_unittest.py index fde30f81e..53a8be144 100644 --- a/nova/tests/scheduler_unittest.py +++ b/nova/tests/scheduler_unittest.py @@ -117,10 +117,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() hosts = self.scheduler.driver.hosts_up(self.context, 'compute') self.assertEqual(len(hosts), 2) compute1.kill() @@ -132,10 +134,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() instance_id1 = self._create_instance() compute1.run_instance(self.context, instance_id1) instance_id2 = self._create_instance() @@ -153,10 +157,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-compute', 'compute', FLAGS.compute_manager) + compute1.startService() compute2 = service.Service('host2', 'nova-compute', 'compute', FLAGS.compute_manager) + compute2.startService() instance_ids1 = [] instance_ids2 = [] for index in xrange(FLAGS.max_cores): @@ -184,10 +190,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-volume', 'volume', FLAGS.volume_manager) + volume1.startService() volume2 = service.Service('host2', 'nova-volume', 'volume', FLAGS.volume_manager) + volume2.startService() volume_id1 = self._create_volume() volume1.create_volume(self.context, volume_id1) volume_id2 = self._create_volume() @@ -205,10 +213,12 @@ class SimpleDriverTestCase(test.TrialTestCase): 'nova-volume', 'volume', FLAGS.volume_manager) + volume1.startService() volume2 = service.Service('host2', 'nova-volume', 'volume', FLAGS.volume_manager) + volume2.startService() volume_ids1 = [] volume_ids2 = [] for index in xrange(FLAGS.max_gigabytes): diff --git a/nova/tests/service_unittest.py b/nova/tests/service_unittest.py index 06f80e82c..6afeec377 100644 --- a/nova/tests/service_unittest.py +++ b/nova/tests/service_unittest.py @@ -22,6 +22,8 @@ Unit Tests for remote procedure calls using queue import mox +from twisted.application.app import startApplication + from nova import exception from nova import flags from nova import rpc @@ -96,6 +98,7 @@ class ServiceTestCase(test.BaseTestCase): self.mox.ReplayAll() app = service.Service.create(host=host, binary=binary) + startApplication(app, False) self.assert_(app) # We're testing sort of weird behavior in how report_state decides -- cgit From 5c4b1a38b8a82ee0a8f14f813f91d319a9715cc3 Mon Sep 17 00:00:00 2001 From: mdietz Date: Mon, 4 Oct 2010 16:01:44 +0000 Subject: More clean up and conflict resolution --- nova/api/ec2/cloud.py | 2 +- nova/compute/manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 2fec49da8..f43da42bd 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -719,7 +719,7 @@ class CloudController(object): changes[field] = kwargs[field] if changes: db_context = {} - internal_id = ec2_id_to_internal_id(ec2_id) + internal_id = self.ec2_id_to_internal_id(ec2_id) inst = db.instance_get_by_internal_id(db_context, internal_id) db.instance_update(db_context, inst['id'], kwargs) return True diff --git a/nova/compute/manager.py b/nova/compute/manager.py index 131fac406..fb0876578 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -174,7 +174,7 @@ class ComputeManager(manager.Manager): instance_ref = self.db.instance_get(context, instance_id) dev_path = yield self.volume_manager.setup_compute_volume(context, volume_id) - yield self.driver.attach_volume(instance_ref['ec2_id'], + yield self.driver.attach_volume(instance_ref['internal_id'], dev_path, mountpoint) self.db.volume_attached(context, volume_id, instance_id, mountpoint) @@ -189,7 +189,7 @@ class ComputeManager(manager.Manager): volume_id) instance_ref = self.db.instance_get(context, instance_id) volume_ref = self.db.volume_get(context, volume_id) - yield self.driver.detach_volume(instance_ref['ec2_id'], + yield self.driver.detach_volume(instance_ref['internal_id'], volume_ref['mountpoint']) self.db.volume_detached(context, volume_id) defer.returnValue(True) -- cgit From 7e66ee636910763630fcf5e6ff23848389713c81 Mon Sep 17 00:00:00 2001 From: mdietz Date: Mon, 4 Oct 2010 17:52:08 +0000 Subject: Accidentally renamed volume related stuff --- nova/compute/manager.py | 4 ++-- nova/db/sqlalchemy/api.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index fb0876578..131fac406 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -174,7 +174,7 @@ class ComputeManager(manager.Manager): instance_ref = self.db.instance_get(context, instance_id) dev_path = yield self.volume_manager.setup_compute_volume(context, volume_id) - yield self.driver.attach_volume(instance_ref['internal_id'], + yield self.driver.attach_volume(instance_ref['ec2_id'], dev_path, mountpoint) self.db.volume_attached(context, volume_id, instance_id, mountpoint) @@ -189,7 +189,7 @@ class ComputeManager(manager.Manager): volume_id) instance_ref = self.db.instance_get(context, instance_id) volume_ref = self.db.volume_get(context, volume_id) - yield self.driver.detach_volume(instance_ref['internal_id'], + yield self.driver.detach_volume(instance_ref['ec2_id'], volume_ref['mountpoint']) self.db.volume_detached(context, volume_id) defer.returnValue(True) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 6bf020aee..f79473c00 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -463,8 +463,8 @@ def instance_get_by_internal_id(context, internal_id): def instance_internal_id_exists(context, internal_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Instance.id==internal_id) - ).one()[0] + return session.query(exists().where + (models.Instance.internal_id==internal_id)).one()[0] def instance_get_fixed_address(_context, instance_id): -- cgit From 2a8e4a3e818f1d279a886e2e5f5ae49f3de26a4d Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Mon, 4 Oct 2010 14:26:55 -0400 Subject: Revert r312 --- nova/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nova/utils.py b/nova/utils.py index 86ff3d22e..5f64b13c4 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -126,9 +126,8 @@ def runthis(prompt, cmd, check_exit_code = True): def generate_uid(topic, size=8): - #TODO(gundlach): we want internal ids to just be ints now. i just dropped - #off a topic prefix, so what have I broken? - return random.randint(0, 2**64-1) + return '%s-%s' % (topic, ''.join([random.choice('01234567890abcdefghijklmnopqrstuvwxyz') for x in xrange(size)])) + def generate_mac(): -- cgit From 3fe309b6f1e8a592d7b2948f4c1cdc51a62d0ff4 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 4 Oct 2010 21:01:31 +0200 Subject: Add pylint thingamajig for startService (name defined by Twisted). --- nova/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/service.py b/nova/service.py index dadef3c48..115e0ff32 100644 --- a/nova/service.py +++ b/nova/service.py @@ -57,7 +57,7 @@ class Service(object, service.Service): self.saved_args, self.saved_kwargs = args, kwargs - def startService(self): + def startService(self): # pylint: disable-msg C0103 manager_class = utils.import_class(self.manager_class_name) self.manager = manager_class(host=self.host, *self.saved_args, **self.saved_kwargs) -- cgit From 8fb9f78a313a43f333d20c7cc600a5085eb68915 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 4 Oct 2010 21:53:22 +0200 Subject: Replace the embarrasingly crude string based tests for to_xml with some more sensible ElementTree based stuff. --- nova/tests/virt_unittest.py | 63 ++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 2aab16809..998cc07db 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -14,36 +14,49 @@ # License for the specific language governing permissions and limitations # under the License. +from xml.etree.ElementTree import fromstring as parseXml + from nova import flags from nova import test from nova.virt import libvirt_conn FLAGS = flags.FLAGS - class LibvirtConnTestCase(test.TrialTestCase): def test_get_uri_and_template(self): - class MockDataModel(object): - def __init__(self): - self.datamodel = { 'name' : 'i-cafebabe', - 'memory_kb' : '1024000', - 'basepath' : '/some/path', - 'bridge_name' : 'br100', - 'mac_address' : '02:12:34:46:56:67', - 'vcpus' : 2 } + instance = { 'name' : 'i-cafebabe', + 'id' : 'i-cafebabe', + 'memory_kb' : '1024000', + 'basepath' : '/some/path', + 'bridge_name' : 'br100', + 'mac_address' : '02:12:34:46:56:67', + 'vcpus' : 2, + 'project_id' : 'fake', + 'ip_address' : '10.11.12.13', + 'bridge' : 'br101', + 'instance_type' : 'm1.small'} type_uri_map = { 'qemu' : ('qemu:///system', - [lambda s: '' in s, - lambda s: 'type>hvm/usr/bin/kvm' not in s]), + [(lambda t: t.find('.').tag, 'domain'), + (lambda t: t.find('.').get('type'), 'qemu'), + (lambda t: t.find('./os/type').text, 'hvm'), + (lambda t: t.find('./devices/emulator'), None)]), 'kvm' : ('qemu:///system', - [lambda s: '' in s, - lambda s: 'type>hvm/usr/bin/qemu<' not in s]), + [(lambda t: t.find('.').tag, 'domain'), + (lambda t: t.find('.').get('type'), 'kvm'), + (lambda t: t.find('./os/type').text, 'hvm'), + (lambda t: t.find('./devices/emulator'), None)]), 'uml' : ('uml:///system', - [lambda s: '' in s, - lambda s: 'type>uml Date: Mon, 4 Oct 2010 21:58:22 +0200 Subject: Merge security group related changes from lp:~anso/nova/deploy --- nova/api/ec2/cloud.py | 31 ++++++++++--- nova/db/sqlalchemy/api.py | 105 +++++++++++++++++++++++++++++++++----------- nova/tests/virt_unittest.py | 33 +++++++++----- nova/virt/libvirt_conn.py | 39 ++++++++++++++-- 4 files changed, 162 insertions(+), 46 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 839b84b4e..4cd4c78ae 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -327,6 +327,26 @@ class CloudController(object): return values + + def _security_group_rule_exists(self, security_group, values): + """Indicates whether the specified rule values are already + defined in the given security group. + """ + for rule in security_group.rules: + if 'group_id' in values: + if rule['group_id'] == values['group_id']: + return True + else: + is_duplicate = True + for key in ('cidr', 'from_port', 'to_port', 'protocol'): + if rule[key] != values[key]: + is_duplicate = False + break + if is_duplicate: + return True + return False + + def revoke_security_group_ingress(self, context, group_name, **kwargs): self._ensure_default_security_group(context) security_group = db.security_group_get_by_name(context, @@ -348,9 +368,6 @@ class CloudController(object): return True raise exception.ApiError("No rule for the specified parameters.") - # TODO(soren): Dupe detection. Adding the same rule twice actually - # adds the same rule twice to the rule set, which is - # pointless. # TODO(soren): This has only been tested with Boto as the client. # Unfortunately, it seems Boto is using an old API # for these operations, so support for newer API versions @@ -364,6 +381,10 @@ class CloudController(object): values = self._authorize_revoke_rule_args_to_dict(context, **kwargs) values['parent_group_id'] = security_group.id + if self._security_group_rule_exists(security_group, values): + raise exception.ApiError('This rule already exists in group %s' % + group_name) + security_group_rule = db.security_group_rule_create(context, values) self._trigger_refresh_security_group(security_group) @@ -709,7 +730,7 @@ class CloudController(object): 'description' : 'default', 'user_id' : context.user.id, 'project_id' : context.project.id } - group = db.security_group_create({}, values) + group = db.security_group_create(context, values) def run_instances(self, context, **kwargs): instance_type = kwargs.get('instance_type', 'm1.small') @@ -797,7 +818,7 @@ class CloudController(object): inst_id = instance_ref['id'] for security_group_id in security_groups: - db.instance_add_security_group(context, inst_id, + db.instance_add_security_group(context.admin(), inst_id, security_group_id) inst = {} diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index d395b7e2c..bc5ef5a9b 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -572,11 +572,13 @@ def instance_get(context, instance_id, session=None): if is_admin_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(id=instance_id ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(id=instance_id ).filter_by(deleted=False @@ -592,6 +594,7 @@ def instance_get_all(context): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(deleted=can_read_deleted(context) ).all() @@ -601,6 +604,7 @@ def instance_get_all_by_user(context, user_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(deleted=can_read_deleted(context) ).filter_by(user_id=user_id ).all() @@ -613,6 +617,7 @@ def instance_get_all_by_project(context, project_id): session = get_session() return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(project_id=project_id ).filter_by(deleted=can_read_deleted(context) ).all() @@ -625,12 +630,14 @@ def instance_get_all_by_reservation(context, reservation_id): if is_admin_context(context): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(reservation_id=reservation_id ).filter_by(deleted=can_read_deleted(context) ).all() elif is_user_context(context): return session.query(models.Instance ).options(joinedload_all('fixed_ip.floating_ips') + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(reservation_id=reservation_id ).filter_by(deleted=False @@ -643,11 +650,13 @@ def instance_get_by_ec2_id(context, ec2_id): if is_admin_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(ec2_id=ec2_id ).filter_by(deleted=can_read_deleted(context) ).first() elif is_user_context(context): result = session.query(models.Instance + ).options(joinedload('security_groups') ).filter_by(project_id=context.project.id ).filter_by(ec2_id=ec2_id ).filter_by(deleted=False @@ -721,9 +730,10 @@ def instance_add_security_group(context, instance_id, security_group_id): """Associate the given security group with the given instance""" session = get_session() with session.begin(): - instance_ref = models.Instance.find(instance_id, session=session) - security_group_ref = models.SecurityGroup.find(security_group_id, - session=session) + instance_ref = instance_get(context, instance_id, session=session) + security_group_ref = security_group_get(context, + security_group_id, + session=session) instance_ref.security_groups += [security_group_ref] instance_ref.save(session=session) @@ -1202,6 +1212,7 @@ def volume_get(context, volume_id, session=None): @require_admin_context def volume_get_all(context): + session = get_session() return session.query(models.Volume ).filter_by(deleted=can_read_deleted(context) ).all() @@ -1292,27 +1303,39 @@ def volume_update(context, volume_id, values): ################### -def security_group_get_all(_context): +@require_context +def security_group_get_all(context): session = get_session() return session.query(models.SecurityGroup - ).filter_by(deleted=False + ).filter_by(deleted=can_read_deleted(context) ).options(joinedload_all('rules') ).all() -def security_group_get(_context, security_group_id): - session = get_session() - result = session.query(models.SecurityGroup - ).filter_by(deleted=False - ).filter_by(id=security_group_id - ).options(joinedload_all('rules') - ).first() +@require_context +def security_group_get(context, security_group_id, session=None): + if not session: + session = get_session() + if is_admin_context(context): + result = session.query(models.SecurityGroup + ).filter_by(deleted=can_read_deleted(context), + ).filter_by(id=security_group_id + ).options(joinedload_all('rules') + ).first() + else: + result = session.query(models.SecurityGroup + ).filter_by(deleted=False + ).filter_by(id=security_group_id + ).filter_by(project_id=context.project_id + ).options(joinedload_all('rules') + ).first() if not result: raise exception.NotFound("No secuity group with id %s" % security_group_id) return result +@require_context def security_group_get_by_name(context, project_id, group_name): session = get_session() result = session.query(models.SecurityGroup @@ -1329,7 +1352,8 @@ def security_group_get_by_name(context, project_id, group_name): return result -def security_group_get_by_project(_context, project_id): +@require_context +def security_group_get_by_project(context, project_id): session = get_session() return session.query(models.SecurityGroup ).filter_by(project_id=project_id @@ -1338,7 +1362,8 @@ def security_group_get_by_project(_context, project_id): ).all() -def security_group_get_by_instance(_context, instance_id): +@require_context +def security_group_get_by_instance(context, instance_id): session = get_session() return session.query(models.SecurityGroup ).filter_by(deleted=False @@ -1349,15 +1374,17 @@ def security_group_get_by_instance(_context, instance_id): ).all() -def security_group_exists(_context, project_id, group_name): +@require_context +def security_group_exists(context, project_id, group_name): try: - group = security_group_get_by_name(_context, project_id, group_name) + group = security_group_get_by_name(context, project_id, group_name) return group != None except exception.NotFound: return False -def security_group_create(_context, values): +@require_context +def security_group_create(context, values): security_group_ref = models.SecurityGroup() # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception # once save() is called. This will get cleaned up in next orm pass. @@ -1368,7 +1395,8 @@ def security_group_create(_context, values): return security_group_ref -def security_group_destroy(_context, security_group_id): +@require_context +def security_group_destroy(context, security_group_id): session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? @@ -1378,35 +1406,62 @@ def security_group_destroy(_context, security_group_id): 'where group_id=:id', {'id': security_group_id}) -def security_group_destroy_all(_context): - session = get_session() +@require_context +def security_group_destroy_all(context, session=None): + if not session: + session = get_session() with session.begin(): # TODO(vish): do we have to use sql here? session.execute('update security_groups set deleted=1') session.execute('update security_group_rules set deleted=1') + ################### -def security_group_rule_create(_context, values): +@require_context +def security_group_rule_get(context, security_group_rule_id, session=None): + if not session: + session = get_session() + if is_admin_context(context): + result = session.query(models.SecurityGroupIngressRule + ).filter_by(deleted=can_read_deleted(context) + ).filter_by(id=security_group_rule_id + ).first() + else: + # TODO(vish): Join to group and check for project_id + result = session.query(models.SecurityGroupIngressRule + ).filter_by(deleted=False + ).filter_by(id=security_group_rule_id + ).first() + if not result: + raise exception.NotFound("No secuity group rule with id %s" % + security_group_rule_id) + return result + + +@require_context +def security_group_rule_create(context, values): security_group_rule_ref = models.SecurityGroupIngressRule() for (key, value) in values.iteritems(): security_group_rule_ref[key] = value security_group_rule_ref.save() return security_group_rule_ref -def security_group_rule_destroy(_context, security_group_rule_id): +@require_context +def security_group_rule_destroy(context, security_group_rule_id): session = get_session() with session.begin(): - model = models.SecurityGroupIngressRule - security_group_rule = model.find(security_group_rule_id, - session=session) + security_group_rule = security_group_rule_get(context, + security_group_rule_id, + session=session) security_group_rule.delete(session=session) ################### +@require_admin_context def host_get_networks(context, host): session = get_session() with session.begin(): diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 7fa8e52ac..8b0de6c29 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -19,7 +19,9 @@ from xml.dom.minidom import parseString from nova import db from nova import flags from nova import test +from nova.api import context from nova.api.ec2 import cloud +from nova.auth import manager from nova.virt import libvirt_conn FLAGS = flags.FLAGS @@ -83,17 +85,20 @@ class NWFilterTestCase(test.TrialTestCase): class Mock(object): pass - self.context = Mock() - self.context.user = Mock() - self.context.user.id = 'fake' - self.context.user.is_superuser = lambda:True - self.context.project = Mock() - self.context.project.id = 'fake' + self.manager = manager.AuthManager() + self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + self.project = self.manager.create_project('fake', 'fake', 'fake') + self.context = context.APIRequestContext(self.user, self.project) self.fake_libvirt_connection = Mock() self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection) + def tearDown(self): + self.manager.delete_project(self.project) + self.manager.delete_user(self.user) + + def test_cidr_rule_nwfilter_xml(self): cloud_controller = cloud.CloudController() cloud_controller.create_security_group(self.context, @@ -107,7 +112,9 @@ class NWFilterTestCase(test.TrialTestCase): cidr_ip='0.0.0.0/0') - security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') + security_group = db.security_group_get_by_name(self.context, + 'fake', + 'testgroup') xml = self.fw.security_group_to_nwfilter_xml(security_group.id) @@ -126,7 +133,8 @@ class NWFilterTestCase(test.TrialTestCase): ip_conditions = rules[0].getElementsByTagName('tcp') self.assertEqual(len(ip_conditions), 1) - self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0') + self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0') + self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0') self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') @@ -150,7 +158,7 @@ class NWFilterTestCase(test.TrialTestCase): ip_protocol='tcp', cidr_ip='0.0.0.0/0') - return db.security_group_get_by_name({}, 'fake', 'testgroup') + return db.security_group_get_by_name(self.context, 'fake', 'testgroup') def test_creates_base_rule_first(self): # These come pre-defined by libvirt @@ -180,7 +188,8 @@ class NWFilterTestCase(test.TrialTestCase): self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock - instance_ref = db.instance_create({}, {'user_id': 'fake', + instance_ref = db.instance_create(self.context, + {'user_id': 'fake', 'project_id': 'fake'}) inst_id = instance_ref['id'] @@ -195,8 +204,8 @@ class NWFilterTestCase(test.TrialTestCase): self.security_group = self.setup_and_return_security_group() - db.instance_add_security_group({}, inst_id, self.security_group.id) - instance = db.instance_get({}, inst_id) + db.instance_add_security_group(self.context, inst_id, self.security_group.id) + instance = db.instance_get(self.context, inst_id) d = self.fw.setup_nwfilters_for_instance(instance) d.addCallback(_ensure_all_called) diff --git a/nova/virt/libvirt_conn.py b/nova/virt/libvirt_conn.py index 9d889cf29..319f7d2af 100644 --- a/nova/virt/libvirt_conn.py +++ b/nova/virt/libvirt_conn.py @@ -25,6 +25,7 @@ import logging import os import shutil +import IPy from twisted.internet import defer from twisted.internet import task from twisted.internet import threads @@ -34,6 +35,7 @@ from nova import exception from nova import flags from nova import process from nova import utils +#from nova.api import context from nova.auth import manager from nova.compute import disk from nova.compute import instance_types @@ -61,6 +63,9 @@ flags.DEFINE_string('libvirt_uri', '', 'Override the default libvirt URI (which is dependent' ' on libvirt_type)') +flags.DEFINE_bool('allow_project_net_traffic', + True, + 'Whether to allow in project network traffic') def get_connection(read_only): @@ -135,7 +140,7 @@ class LibvirtConnection(object): d.addCallback(lambda _: self._cleanup(instance)) # FIXME: What does this comment mean? # TODO(termie): short-circuit me for tests - # WE'LL save this for when we do shutdown, + # WE'LL save this for when we do shutdown, # instead of destroy - but destroy returns immediately timer = task.LoopingCall(f=None) def _wait_for_shutdown(): @@ -550,6 +555,16 @@ class NWFilterFirewall(object): return retval + def nova_project_filter(self, project, net, mask): + retval = "" % project + for protocol in ['tcp', 'udp', 'icmp']: + retval += """ + <%s srcipaddr='%s' srcipmask='%s' /> + """ % (protocol, net, mask) + retval += '' + return retval + + def _define_filter(self, xml): if callable(xml): xml = xml() @@ -557,6 +572,11 @@ class NWFilterFirewall(object): return d + @staticmethod + def _get_net_and_mask(cidr): + net = IPy.IP(cidr) + return str(net.net()), str(net.netmask()) + @defer.inlineCallbacks def setup_nwfilters_for_instance(self, instance): """ @@ -570,9 +590,19 @@ class NWFilterFirewall(object): yield self._define_filter(self.nova_dhcp_filter) yield self._define_filter(self.nova_base_filter) - nwfilter_xml = ("\n" + + nwfilter_xml = ("\n" + " \n" - ) % instance['name'] + ) % instance['name'] + + if FLAGS.allow_project_net_traffic: + network_ref = db.project_get_network({}, instance['project_id']) + net, mask = self._get_net_and_mask(network_ref['cidr']) + project_filter = self.nova_project_filter(instance['project_id'], + net, mask) + yield self._define_filter(project_filter) + + nwfilter_xml += (" \n" + ) % instance['project_id'] for security_group in instance.security_groups: yield self.ensure_security_group_filter(security_group['id']) @@ -595,7 +625,8 @@ class NWFilterFirewall(object): for rule in security_group.rules: rule_xml += "" if rule.cidr: - rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr) + net, mask = self._get_net_and_mask(rule.cidr) + rule_xml += "<%s srcipaddr='%s' srcipmask='%s' " % (rule.protocol, net, mask) if rule.protocol in ['tcp', 'udp']: rule_xml += "dstportstart='%s' dstportend='%s' " % \ (rule.from_port, rule.to_port) -- cgit From dd0f365c98ae68afff9a0fbc75e7d5b88499b282 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Mon, 4 Oct 2010 16:39:05 -0400 Subject: Fix broken unit tests --- nova/api/ec2/cloud.py | 24 ++++++++++++++---------- nova/db/sqlalchemy/api.py | 3 +++ nova/tests/cloud_unittest.py | 3 ++- nova/utils.py | 11 +++++++++-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 2fec49da8..7f5f4c4e9 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -72,6 +72,20 @@ def _gen_key(context, user_id, key_name): return {'private_key': private_key, 'fingerprint': fingerprint} +def ec2_id_to_internal_id(ec2_id): + """Convert an ec2 ID (i-[base 36 number]) to an internal id (int)""" + return int(ec2_id[2:], 36) + + +def internal_id_to_ec2_id(internal_id): + """Convert an internal ID (int) to an ec2 ID (i-[base 36 number])""" + digits = [] + while internal_id != 0: + internal_id, remainder = divmod(internal_id, 36) + digits.append('0123456789abcdefghijklmnopqrstuvwxyz'[remainder]) + return "i-%s" % ''.join(reversed(digits)) + + class CloudController(object): """ CloudController provides the critical dispatch between inbound API calls through the endpoint and messages @@ -113,16 +127,6 @@ class CloudController(object): result[key] = [line] return result - def ec2_id_to_internal_id(ec2_id): - """Convert an ec2 ID (i-[base 36 number]) to an internal id (int)""" - # TODO(gundlach): Maybe this should actually work? - return ec2_id[2:] - - def internal_id_to_ec2_id(internal_id): - """Convert an internal ID (int) to an ec2 ID (i-[base 36 number])""" - # TODO(gundlach): Yo maybe this should actually convert to base 36 - return "i-%d" % internal_id - def get_metadata(self, address): instance_ref = db.fixed_ip_get_instance(None, address) if instance_ref is None: diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 6dd6b545a..9d43da3ba 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -377,6 +377,9 @@ def fixed_ip_update(_context, address, values): ################### +#TODO(gundlach): instance_create and volume_create are nearly identical +#and should be refactored. I expect there are other copy-and-paste +#functions between the two of them as well. def instance_create(_context, values): instance_ref = models.Instance() for (key, value) in values.iteritems(): diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py index d316db153..615e589cf 100644 --- a/nova/tests/cloud_unittest.py +++ b/nova/tests/cloud_unittest.py @@ -236,7 +236,8 @@ class CloudTestCase(test.TrialTestCase): def test_update_of_instance_display_fields(self): inst = db.instance_create({}, {}) - self.cloud.update_instance(self.context, inst['internal_id'], + ec2_id = cloud.internal_id_to_ec2_id(inst['internal_id']) + self.cloud.update_instance(self.context, ec2_id, display_name='c00l 1m4g3') inst = db.instance_get({}, inst['id']) self.assertEqual('c00l 1m4g3', inst['display_name']) diff --git a/nova/utils.py b/nova/utils.py index 5f64b13c4..b1699bda8 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -126,8 +126,15 @@ def runthis(prompt, cmd, check_exit_code = True): def generate_uid(topic, size=8): - return '%s-%s' % (topic, ''.join([random.choice('01234567890abcdefghijklmnopqrstuvwxyz') for x in xrange(size)])) - + if topic == "i": + # Instances have integer internal ids. + #TODO(gundlach): We should make this more than 32 bits, but we need to + #figure out how to make the DB happy with 64 bit integers. + return random.randint(0, 2**32-1) + else: + characters = '01234567890abcdefghijklmnopqrstuvwxyz' + choices = [random.choice(characters) for x in xrange(size)] + return '%s-%s' % (topic, ''.join(choices)) def generate_mac(): -- cgit From 32bd6c198a4ed96768649f58628e22fb25a95855 Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Mon, 4 Oct 2010 16:47:08 -0400 Subject: Adds ParallaxClient and TellerClient plumbing for GlanceImageService. Adds stubs FakeParallaxClient and unit tests for LocalImageService and GlanceImageService. --- nova/image/service.py | 99 ++++++++++++++++++++++-- nova/tests/api/rackspace/fakes.py | 74 ++++++++++++++++++ nova/tests/api/rackspace/test_images.py | 132 ++++++++++++++++---------------- 3 files changed, 230 insertions(+), 75 deletions(-) diff --git a/nova/image/service.py b/nova/image/service.py index 4bceab6ee..3b6d3b6e3 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -16,9 +16,14 @@ # under the License. import cPickle as pickle +import httplib +import json import os.path import random import string +import urlparse + +import webob.exc from nova import utils from nova import flags @@ -27,11 +32,11 @@ from nova import flags FLAGS = flags.FLAGS -flags.DEFINE_string('glance_teller_address', '127.0.0.1', +flags.DEFINE_string('glance_teller_address', 'http://127.0.0.1', 'IP address or URL where Glance\'s Teller service resides') flags.DEFINE_string('glance_teller_port', '9191', 'Port for Glance\'s Teller service') -flags.DEFINE_string('glance_parallax_address', '127.0.0.1', +flags.DEFINE_string('glance_parallax_address', 'http://127.0.0.1', 'IP address or URL where Glance\'s Parallax service resides') flags.DEFINE_string('glance_parallax_port', '9191', 'Port for Glance\'s Parallax service') @@ -80,21 +85,101 @@ class BaseImageService(object): raise NotImplementedError +class TellerClient(object): + + def __init__(self): + self.address = FLAGS.glance_teller_address + self.port = FLAGS.glance_teller_port + url = urlparse.urlparse(self.address) + self.netloc = url.netloc + self.connection_type = {'http': httplib.HTTPConnection, + 'https': httplib.HTTPSConnection}[url.scheme] + + +class ParallaxClient(object): + + def __init__(self): + self.address = FLAGS.glance_parallax_address + self.port = FLAGS.glance_parallax_port + url = urlparse.urlparse(self.address) + self.netloc = url.netloc + self.connection_type = {'http': httplib.HTTPConnection, + 'https': httplib.HTTPSConnection}[url.scheme] + + def get_images(self): + """ + Returns a list of image data mappings from Parallax + """ + try: + c = self.connection_type(self.netloc, self.port) + c.request("GET", "images") + res = c.getresponse() + if res.status == 200: + data = json.loads(res.read()) + return data + else: + # TODO(jaypipes): return or raise HTTP error? + return [] + finally: + c.close() + + def get_image_metadata(self, image_id): + """ + Returns a mapping of image metadata from Parallax + """ + try: + c = self.connection_type(self.netloc, self.port) + c.request("GET", "images/%s" % image_id) + res = c.getresponse() + if res.status == 200: + data = json.loads(res.read()) + return data + else: + # TODO(jaypipes): return or raise HTTP error? + return [] + finally: + c.close() + + def add_image_metadata(self, image_metadata): + """ + Tells parallax about an image's metadata + """ + pass + + def update_image_metadata(self, image_id, image_metadata): + """ + Updates Parallax's information about an image + """ + pass + + def delete_image_metadata(self, image_id): + """ + Deletes Parallax's information about an image + """ + pass + + class GlanceImageService(BaseImageService): """Provides storage and retrieval of disk image objects within Glance.""" + def __init__(self): + self.teller = TellerClient() + self.parallax = ParallaxClient() + def index(self): """ Calls out to Parallax for a list of images available """ - raise NotImplementedError + images = self.parallax.get_images() + return images def show(self, id): """ Returns a dict containing image data for the given opaque image id. """ - raise NotImplementedError + image = self.parallax.get_image_metadata(id) + return image def create(self, data): """ @@ -103,7 +188,7 @@ class GlanceImageService(BaseImageService): :raises AlreadyExists if the image already exist. """ - raise NotImplementedError + return self.parallax.add_image_metadata(data) def update(self, image_id, data): """Replace the contents of the given image with the new data. @@ -111,7 +196,7 @@ class GlanceImageService(BaseImageService): :raises NotFound if the image does not exist. """ - raise NotImplementedError + self.parallax.update_image_metadata(image_id, data) def delete(self, image_id): """ @@ -120,7 +205,7 @@ class GlanceImageService(BaseImageService): :raises NotFound if the image does not exist. """ - raise NotImplementedError + self.parallax.delete_image_metadata(image_id) def delete_all(self): """ diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py index 2c4447920..3765b859e 100644 --- a/nova/tests/api/rackspace/fakes.py +++ b/nova/tests/api/rackspace/fakes.py @@ -1,5 +1,24 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + import datetime import json +import random +import string import webob import webob.dec @@ -7,6 +26,7 @@ import webob.dec from nova import auth from nova import utils from nova import flags +from nova import exception as exc import nova.api.rackspace.auth import nova.api.rackspace._id_translator from nova.image import service @@ -105,6 +125,60 @@ def stub_out_networking(stubs): FLAGS.FAKE_subdomain = 'rs' +def stub_out_glance(stubs): + + class FakeParallaxClient: + + def __init__(self): + self.fixtures = {} + + def fake_get_images(self): + return self.fixtures + + def fake_get_image_metadata(self, image_id): + for k, f in self.fixtures.iteritems(): + if k == image_id: + return f + raise exc.NotFound + + def fake_add_image_metadata(self, image_data): + id = ''.join(random.choice(string.letters) for _ in range(20)) + image_data['id'] = id + self.fixtures[id] = image_data + return id + + def fake_update_image_metadata(self, image_id, image_data): + + if image_id not in self.fixtures.keys(): + raise exc.NotFound + + self.fixtures[image_id].update(image_data) + + def fake_delete_image_metadata(self, image_id): + + if image_id not in self.fixtures.keys(): + raise exc.NotFound + + del self.fixtures[image_id] + + def fake_delete_all(self): + self.fixtures = {} + + fake_parallax_client = FakeParallaxClient() + stubs.Set(nova.image.service.ParallaxClient, 'get_images', + fake_parallax_client.fake_get_images) + stubs.Set(nova.image.service.ParallaxClient, 'get_image_metadata', + fake_parallax_client.fake_get_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'add_image_metadata', + fake_parallax_client.fake_add_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'update_image_metadata', + fake_parallax_client.fake_update_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'delete_image_metadata', + fake_parallax_client.fake_delete_image_metadata) + stubs.Set(nova.image.service.GlanceImageService, 'delete_all', + fake_parallax_client.fake_delete_all) + + class FakeAuthDatabase(object): data = {} diff --git a/nova/tests/api/rackspace/test_images.py b/nova/tests/api/rackspace/test_images.py index 21dad7648..669346680 100644 --- a/nova/tests/api/rackspace/test_images.py +++ b/nova/tests/api/rackspace/test_images.py @@ -22,79 +22,74 @@ import stubout from nova import utils from nova.api.rackspace import images +from nova.tests.api.rackspace import fakes -#{ Fixtures +class BaseImageServiceTests(): + """Tasks to test for all image services""" -fixture_images = [ - { - 'name': 'image #1', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None}, - { - 'name': 'image #2', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None}, - { - 'name': 'image #3', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None}] + def test_create(self): + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} -#} + num_images = len(self.service.index()) + id = self.service.create(fixture) -class BaseImageServiceTests(): - - """Tasks to test for all image services""" + self.assertNotEquals(None, id) + self.assertEquals(num_images + 1, len(self.service.index())) - def test_create_and_index(self): - for i in fixture_images: - self.service.create(i) + def test_update(self): - self.assertEquals(len(fixture_images), len(self.service.index())) + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} - def test_create_and_update(self): - ids = {} - temp = 0 - for i in fixture_images: - ids[self.service.create(i)] = temp - temp += 1 + id = self.service.create(fixture) - self.assertEquals(len(fixture_images), len(self.service.index())) + fixture['status'] = 'in progress' + + self.service.update(id, fixture) + new_image_data = self.service.show(id) + self.assertEquals('in progress', new_image_data['status']) - for image_id, num in ids.iteritems(): - new_data = fixture_images[num] - new_data['updated'] = 'test' + str(num) - self.service.update(image_id, new_data) + def test_delete(self): - images = self.service.index() + fixtures = [ + {'name': 'test image 1', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}, + {'name': 'test image 2', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}] - for i in images: - self.assertEquals('test' + str(ids[i['id']]), - i['updated']) + ids = [] + for fixture in fixtures: + new_id = self.service.create(fixture) + ids.append(new_id) - def test_create_and_show(self): - ids = {} - temp = 0 - for i in fixture_images: - ids[self.service.create(i)] = temp - temp += 1 + num_images = len(self.service.index()) + self.assertEquals(2, num_images) + + self.service.delete(ids[0]) - for i in fixture_images: - image = self.service.show(i['id']) - index = ids[i['id']] - self.assertEquals(image, fixture_images[index]) + num_images = len(self.service.index()) + self.assertEquals(1, num_images) class LocalImageServiceTest(unittest.TestCase, @@ -111,15 +106,16 @@ class LocalImageServiceTest(unittest.TestCase, self.stubs.UnsetAll() -#class GlanceImageServiceTest(unittest.TestCase, -# BaseImageServiceTests): -# -# """Tests the local image service""" -# -# def setUp(self): -# self.stubs = stubout.StubOutForTesting() -# self.service = utils.import_object('nova.image.service.GlanceImageService') -# -# def tearDown(self): -# self.service.delete_all() -# self.stubs.UnsetAll() +class GlanceImageServiceTest(unittest.TestCase, + BaseImageServiceTests): + + """Tests the local image service""" + + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.stub_out_glance(self.stubs) + self.service = utils.import_object('nova.image.service.GlanceImageService') + + def tearDown(self): + self.service.delete_all() + self.stubs.UnsetAll() -- cgit From 6bdbb567f1a9e0a8b980ff916183d47375fe11bf Mon Sep 17 00:00:00 2001 From: mdietz Date: Mon, 4 Oct 2010 21:20:33 +0000 Subject: One last bad line --- nova/api/ec2/cloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 3f440c85c..7f5f4c4e9 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -723,7 +723,7 @@ class CloudController(object): changes[field] = kwargs[field] if changes: db_context = {} - internal_id = self.ec2_id_to_internal_id(ec2_id) + internal_id = ec2_id_to_internal_id(ec2_id) inst = db.instance_get_by_internal_id(db_context, internal_id) db.instance_update(db_context, inst['id'], kwargs) return True -- cgit From bf727292794026694c37b84201172b933b41ad2d Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Mon, 4 Oct 2010 17:32:01 -0400 Subject: Update Parallax default port number to match Glance --- nova/image/service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/image/service.py b/nova/image/service.py index 3b6d3b6e3..66be669dd 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -38,7 +38,7 @@ flags.DEFINE_string('glance_teller_port', '9191', 'Port for Glance\'s Teller service') flags.DEFINE_string('glance_parallax_address', 'http://127.0.0.1', 'IP address or URL where Glance\'s Parallax service resides') -flags.DEFINE_string('glance_parallax_port', '9191', +flags.DEFINE_string('glance_parallax_port', '9292', 'Port for Glance\'s Parallax service') -- cgit From a374efd4cc3d27c9b5389009818e45efe2f35b12 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 5 Oct 2010 10:06:54 +0200 Subject: Run the virt tests by default. --- run_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run_tests.py b/run_tests.py index 4121f4c06..fa1e6f15b 100644 --- a/run_tests.py +++ b/run_tests.py @@ -63,6 +63,7 @@ from nova.tests.rpc_unittest import * from nova.tests.scheduler_unittest import * from nova.tests.service_unittest import * from nova.tests.validator_unittest import * +from nova.tests.virt_unittest import * from nova.tests.volume_unittest import * -- cgit From e0b255140a2bb7125bde89c6732d440cef37096b Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 5 Oct 2010 10:07:37 +0200 Subject: Create and destroy user appropriately. Remove security group related tests (since they haven't been merged yet). --- nova/tests/virt_unittest.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/nova/tests/virt_unittest.py b/nova/tests/virt_unittest.py index 998cc07db..730928f39 100644 --- a/nova/tests/virt_unittest.py +++ b/nova/tests/virt_unittest.py @@ -18,11 +18,20 @@ from xml.etree.ElementTree import fromstring as parseXml from nova import flags from nova import test +from nova.auth import manager +# Needed to get FLAGS.instances_path defined: +from nova.compute import manager as compute_manager from nova.virt import libvirt_conn FLAGS = flags.FLAGS class LibvirtConnTestCase(test.TrialTestCase): + def setUp(self): + self.manager = manager.AuthManager() + self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True) + self.project = self.manager.create_project('fake', 'fake', 'fake') + FLAGS.instances_path = '' + def test_get_uri_and_template(self): instance = { 'name' : 'i-cafebabe', 'id' : 'i-cafebabe', @@ -51,12 +60,6 @@ class LibvirtConnTestCase(test.TrialTestCase): (lambda t: t.find('.').get('type'), 'uml'), (lambda t: t.find('./os/type').text, 'uml')]), } - common_checks = [(lambda t: \ - t.find('./devices/interface/filterref/parameter') \ - .get('name'), 'IP'), - (lambda t: \ - t.find('./devices/interface/filterref/parameter') \ - .get('value'), '10.11.12.13')] for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems(): FLAGS.libvirt_type = libvirt_type @@ -72,11 +75,6 @@ class LibvirtConnTestCase(test.TrialTestCase): expected_result, '%s failed check %d' % (xml, i)) - for i, (check, expected_result) in enumerate(common_checks): - self.assertEqual(check(tree), - expected_result, - '%s failed common check %d' % (xml, i)) - # Deliberately not just assigning this string to FLAGS.libvirt_uri and # checking against that later on. This way we make sure the # implementation doesn't fiddle around with the FLAGS. @@ -88,3 +86,7 @@ class LibvirtConnTestCase(test.TrialTestCase): uri, template = conn.get_uri_and_template() self.assertEquals(uri, testuri) + + def tearDown(self): + self.manager.delete_project(self.project) + self.manager.delete_user(self.user) -- cgit From b61f4ceff6ea5dbb4d9c63b9f7345c0b31785984 Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Tue, 5 Oct 2010 13:29:27 -0400 Subject: Adds unit test for calling show() on a non-existing image. Changes return from real Parallax service per sirp's recommendation for actual returned dict() values. --- nova/image/service.py | 36 ++++++++++++++++++++++++-------- nova/tests/api/rackspace/fakes.py | 2 +- nova/tests/api/rackspace/test_images.py | 20 ++++++++++++++++++ nova/tests/api/rackspace/test_servers.py | 1 + 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/nova/image/service.py b/nova/image/service.py index 66be669dd..2e570e8a4 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -18,6 +18,7 @@ import cPickle as pickle import httplib import json +import logging import os.path import random import string @@ -27,6 +28,7 @@ import webob.exc from nova import utils from nova import flags +from nova import exception FLAGS = flags.FLAGS @@ -55,6 +57,8 @@ class BaseImageService(object): def show(self, id): """ Returns a dict containing image data for the given opaque image id. + + :raises NotFound if the image does not exist """ raise NotImplementedError @@ -115,10 +119,12 @@ class ParallaxClient(object): c.request("GET", "images") res = c.getresponse() if res.status == 200: - data = json.loads(res.read()) + # Parallax returns a JSONified dict(images=image_list) + data = json.loads(res.read())['images'] return data else: - # TODO(jaypipes): return or raise HTTP error? + logging.warn("Parallax returned HTTP error %d from " + "request for /images", res.status_int) return [] finally: c.close() @@ -132,11 +138,12 @@ class ParallaxClient(object): c.request("GET", "images/%s" % image_id) res = c.getresponse() if res.status == 200: - data = json.loads(res.read()) + # Parallax returns a JSONified dict(image=image_info) + data = json.loads(res.read())['image'] return data else: - # TODO(jaypipes): return or raise HTTP error? - return [] + # TODO(jaypipes): log the error? + return None finally: c.close() @@ -179,7 +186,9 @@ class GlanceImageService(BaseImageService): Returns a dict containing image data for the given opaque image id. """ image = self.parallax.get_image_metadata(id) - return image + if image: + return image + raise exception.NotFound def create(self, data): """ @@ -236,7 +245,10 @@ class LocalImageService(BaseImageService): return [ self.show(id) for id in self._ids() ] def show(self, id): - return pickle.load(open(self._path_to(id))) + try: + return pickle.load(open(self._path_to(id))) + except IOError: + raise exception.NotFound def create(self, data): """ @@ -249,13 +261,19 @@ class LocalImageService(BaseImageService): def update(self, image_id, data): """Replace the contents of the given image with the new data.""" - pickle.dump(data, open(self._path_to(image_id), 'w')) + try: + pickle.dump(data, open(self._path_to(image_id), 'w')) + except IOError: + raise exception.NotFound def delete(self, image_id): """ Delete the given image. Raises OSError if the image does not exist. """ - os.unlink(self._path_to(image_id)) + try: + os.unlink(self._path_to(image_id)) + except IOError: + raise exception.NotFound def delete_all(self): """ diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py index 3765b859e..c7d9216c8 100644 --- a/nova/tests/api/rackspace/fakes.py +++ b/nova/tests/api/rackspace/fakes.py @@ -139,7 +139,7 @@ def stub_out_glance(stubs): for k, f in self.fixtures.iteritems(): if k == image_id: return f - raise exc.NotFound + return None def fake_add_image_metadata(self, image_data): id = ''.join(random.choice(string.letters) for _ in range(20)) diff --git a/nova/tests/api/rackspace/test_images.py b/nova/tests/api/rackspace/test_images.py index 669346680..a7f320b46 100644 --- a/nova/tests/api/rackspace/test_images.py +++ b/nova/tests/api/rackspace/test_images.py @@ -20,6 +20,7 @@ import unittest import stubout +from nova import exception from nova import utils from nova.api.rackspace import images from nova.tests.api.rackspace import fakes @@ -45,6 +46,25 @@ class BaseImageServiceTests(): self.assertNotEquals(None, id) self.assertEquals(num_images + 1, len(self.service.index())) + def test_create_and_show_non_existing_image(self): + + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} + + num_images = len(self.service.index()) + + id = self.service.create(fixture) + + self.assertNotEquals(None, id) + + self.assertRaises(exception.NotFound, + self.service.show, + 'bad image id') + def test_update(self): fixture = {'name': 'test image', diff --git a/nova/tests/api/rackspace/test_servers.py b/nova/tests/api/rackspace/test_servers.py index 9c1860879..b20a8d432 100644 --- a/nova/tests/api/rackspace/test_servers.py +++ b/nova/tests/api/rackspace/test_servers.py @@ -33,6 +33,7 @@ from nova.tests.api.rackspace import fakes FLAGS = flags.FLAGS +FLAGS.verbose = True def return_server(context, id): return stub_instance(id) -- cgit From 8a2d7efa542e168fda81f703fa8e8c19467bf800 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Tue, 5 Oct 2010 13:40:17 -0400 Subject: Fix clause comparing id to internal_id --- nova/db/sqlalchemy/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 61feed9d0..870e7b1a5 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -663,8 +663,9 @@ def instance_get_by_internal_id(context, internal_id): def instance_internal_id_exists(context, internal_id, session=None): if not session: session = get_session() - return session.query(exists().where(models.Instance.id==internal_id) - ).one()[0] + return session.query( + exists().where(models.Instance.internal_id==internal_id) + ).one()[0] @require_context -- cgit From 091cf4ec5851e87bf722ed0bbbbfdf64dd599389 Mon Sep 17 00:00:00 2001 From: mdietz Date: Tue, 5 Oct 2010 19:52:12 +0000 Subject: A little more clean up --- nova/db/sqlalchemy/api.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 7429057ca..6f1ea7c23 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -29,8 +29,7 @@ from nova.db.sqlalchemy import models from nova.db.sqlalchemy.session import get_session from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import joinedload_all -from sqlalchemy.orm.exc import NoResultFound +from sqlalchemy.orm import joinedload, joinedload_all from sqlalchemy.sql import exists, func FLAGS = flags.FLAGS @@ -451,7 +450,6 @@ def fixed_ip_create(_context, values): fixed_ip_ref.save() return fixed_ip_ref['address'] - @require_context def fixed_ip_disassociate(context, address): session = get_session() -- cgit From c86462d11a6709bf9f2130056bf04712fe3db2d9 Mon Sep 17 00:00:00 2001 From: mdietz Date: Tue, 5 Oct 2010 20:07:11 +0000 Subject: merge prop fixes --- nova/api/rackspace/servers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index 868b697e0..5cfb7a431 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -129,7 +129,7 @@ class Controller(wsgi.Controller): def show(self, req, id): """ Returns server details by server id """ user_id = req.environ['nova.context']['user']['id'] - inst = self.db_driver.instance_get_by_internal_id(None, id) + inst = self.db_driver.instance_get_by_internal_id(None, int(id)) if inst: if inst.user_id == user_id: return _entity_detail(inst) @@ -138,7 +138,7 @@ class Controller(wsgi.Controller): def delete(self, req, id): """ Destroys a server """ user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get_by_internal_id(None, id) + instance = self.db_driver.instance_get_by_internal_id(None, int(id)) if instance and instance['user_id'] == user_id: self.db_driver.instance_destroy(None, id) return faults.Fault(exc.HTTPAccepted()) @@ -171,11 +171,11 @@ class Controller(wsgi.Controller): if not inst_dict: return faults.Fault(exc.HTTPUnprocessableEntity()) - instance = self.db_driver.instance_get_by_internal_id(None, id) + instance = self.db_driver.instance_get_by_internal_id(None, int(id)) if not instance or instance.user_id != user_id: return faults.Fault(exc.HTTPNotFound()) - self.db_driver.instance_update(None, id, + self.db_driver.instance_update(None, int(id), _filter_params(inst_dict['server'])) return faults.Fault(exc.HTTPNoContent()) @@ -187,7 +187,7 @@ class Controller(wsgi.Controller): reboot_type = input_dict['reboot']['type'] except Exception: raise faults.Fault(webob.exc.HTTPNotImplemented()) - opaque_id = _instance_id_translator().from_rs_id(id) + opaque_id = _instance_id_translator().from_rs_id(int(id)) cloud.reboot(opaque_id) def _build_server_instance(self, req, env): @@ -257,7 +257,7 @@ class Controller(wsgi.Controller): #TODO(dietz) is this necessary? inst['launch_index'] = 0 - inst['hostname'] = ref.internal_id + inst['hostname'] = str(ref.internal_id) self.db_driver.instance_update(None, inst['id'], inst) network_manager = utils.import_object(FLAGS.network_manager) -- cgit From db620f323c2fc5e65a722a33ae8a42b54817dae1 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Tue, 5 Oct 2010 16:16:42 -0400 Subject: Missed an ec2_id conversion to internal_id --- nova/api/ec2/cloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 7f5f4c4e9..175bb493c 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -262,7 +262,7 @@ class CloudController(object): # ec2_id_list is passed in as a list of instances ec2_id = ec2_id_list[0] internal_id = ec2_id_to_internal_id(ec2_id) - instance_ref = db.instance_get_by_ec2_id(context, internal_id) + instance_ref = db.instance_get_by_internal_id(context, internal_id) return rpc.call('%s.%s' % (FLAGS.compute_topic, instance_ref['host']), {"method": "get_console_output", -- cgit From fbd1bc015bd5615963b9073eefb895ea04c55a3e Mon Sep 17 00:00:00 2001 From: "jaypipes@gmail.com" <> Date: Tue, 5 Oct 2010 16:19:55 -0400 Subject: Merge overwrote import_object() load of image service. --- nova/api/rackspace/servers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index 5cfb7a431..b23867bbf 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -42,7 +42,7 @@ def _instance_id_translator(): def _image_service(): """ Helper method for initializing the image id translator """ - service = nova.image.service.ImageService.load() + service = utils.import_object(FLAGS.image_service) return (service, _id_translator.RackspaceAPIIdTranslator( "image", service.__class__.__name__)) -- cgit From 684c1ed50aebaed07cf89e6f1f7ee189a1b79b9b Mon Sep 17 00:00:00 2001 From: mdietz Date: Tue, 5 Oct 2010 20:43:23 +0000 Subject: Huge sweeping changes --- nova/tests/api/rackspace/test_servers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nova/tests/api/rackspace/test_servers.py b/nova/tests/api/rackspace/test_servers.py index 5a21356eb..92377538b 100644 --- a/nova/tests/api/rackspace/test_servers.py +++ b/nova/tests/api/rackspace/test_servers.py @@ -72,7 +72,7 @@ class ServersTest(unittest.TestCase): req = webob.Request.blank('/v1.0/servers/1') res = req.get_response(nova.api.API()) res_dict = json.loads(res.body) - self.assertEqual(res_dict['server']['id'], '1') + self.assertEqual(res_dict['server']['id'], 1) self.assertEqual(res_dict['server']['name'], 'server1') def test_get_server_list(self): @@ -93,7 +93,7 @@ class ServersTest(unittest.TestCase): def instance_create(context, inst): class Foo(object): - internal_id = 1 + internal_id = '1' return Foo() def fake_method(*args, **kwargs): -- cgit From 5f40379b407301c0907a72cde988197f3d18ea56 Mon Sep 17 00:00:00 2001 From: mdietz Date: Tue, 5 Oct 2010 21:00:05 +0000 Subject: Merge prop suggestions --- nova/tests/api/rackspace/fakes.py | 16 ---------------- nova/tests/api/rackspace/test_servers.py | 4 +--- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py index 2c4447920..f623524ed 100644 --- a/nova/tests/api/rackspace/fakes.py +++ b/nova/tests/api/rackspace/fakes.py @@ -60,22 +60,6 @@ def stub_out_image_service(stubs): stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) - -def stub_out_id_translator(stubs): - class FakeTranslator(object): - def __init__(self, id_type, service_name): - pass - - def to_rs_id(self, id): - return id - - def from_rs_id(self, id): - return id - - stubs.Set(nova.api.rackspace._id_translator, - 'RackspaceAPIIdTranslator', FakeTranslator) - - def stub_out_auth(stubs): def fake_auth_init(self, app): self.application = app diff --git a/nova/tests/api/rackspace/test_servers.py b/nova/tests/api/rackspace/test_servers.py index 92377538b..7fed0cc26 100644 --- a/nova/tests/api/rackspace/test_servers.py +++ b/nova/tests/api/rackspace/test_servers.py @@ -57,7 +57,6 @@ class ServersTest(unittest.TestCase): fakes.stub_out_networking(self.stubs) fakes.stub_out_rate_limiting(self.stubs) fakes.stub_out_auth(self.stubs) - fakes.stub_out_id_translator(self.stubs) fakes.stub_out_key_pair_funcs(self.stubs) fakes.stub_out_image_service(self.stubs) self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) @@ -93,7 +92,7 @@ class ServersTest(unittest.TestCase): def instance_create(context, inst): class Foo(object): - internal_id = '1' + internal_id = 1 return Foo() def fake_method(*args, **kwargs): @@ -115,7 +114,6 @@ class ServersTest(unittest.TestCase): self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip', fake_method) - fakes.stub_out_id_translator(self.stubs) body = dict(server=dict( name='server_test', imageId=2, flavorId=2, metadata={}, personality = {} -- cgit From 8f524607856dbf4cecf7c7503e53e14c42888307 Mon Sep 17 00:00:00 2001 From: Hisaki Ohara Date: Wed, 6 Oct 2010 18:04:18 +0900 Subject: Defined images_path for nova-compute. Without its setting, it fails to launch instances by exception at _fetch_local_image. --- nova/virt/images.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nova/virt/images.py b/nova/virt/images.py index a60bcc4c1..9ba5b7890 100644 --- a/nova/virt/images.py +++ b/nova/virt/images.py @@ -26,6 +26,7 @@ import time import urlparse from nova import flags +from nova import utils from nova import process from nova.auth import manager from nova.auth import signer @@ -34,6 +35,8 @@ from nova.auth import signer FLAGS = flags.FLAGS flags.DEFINE_bool('use_s3', True, 'whether to get images from s3 or use local copy') +flags.DEFINE_string('images_path', utils.abspath('../images'), + 'path to decrypted images') def fetch(image, path, user, project): -- cgit From b7028c0d0262d3d4395077a8bd2d95664c6bf16e Mon Sep 17 00:00:00 2001 From: Hisaki Ohara Date: Thu, 7 Oct 2010 23:03:43 +0900 Subject: Imported images_path from nova.objectstore for nova-compute. Without its setting, it fails to launch instances by exception at _fetch_local_image. --- nova/virt/images.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nova/virt/images.py b/nova/virt/images.py index 9ba5b7890..dc50764d9 100644 --- a/nova/virt/images.py +++ b/nova/virt/images.py @@ -26,17 +26,15 @@ import time import urlparse from nova import flags -from nova import utils from nova import process from nova.auth import manager from nova.auth import signer +from nova.objectstore import image FLAGS = flags.FLAGS flags.DEFINE_bool('use_s3', True, 'whether to get images from s3 or use local copy') -flags.DEFINE_string('images_path', utils.abspath('../images'), - 'path to decrypted images') def fetch(image, path, user, project): -- cgit From db87fd5a8145d045c4767a8d02cde5a0750113f8 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Fri, 8 Oct 2010 12:21:26 -0400 Subject: Remove redis dependency from Images controller. LocalImageService now works with integer ids, so there's no need for the translator. Once Glance exists we'll have to revisit this. --- nova/api/rackspace/backup_schedules.py | 1 - nova/api/rackspace/images.py | 10 +--------- nova/image/service.py | 16 +++++++++------- nova/tests/api/rackspace/fakes.py | 1 - 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/nova/api/rackspace/backup_schedules.py b/nova/api/rackspace/backup_schedules.py index cb83023bc..9c0d41fa0 100644 --- a/nova/api/rackspace/backup_schedules.py +++ b/nova/api/rackspace/backup_schedules.py @@ -19,7 +19,6 @@ import time from webob import exc from nova import wsgi -from nova.api.rackspace import _id_translator from nova.api.rackspace import faults import nova.image.service diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py index d4ab8ce3c..82dcd2049 100644 --- a/nova/api/rackspace/images.py +++ b/nova/api/rackspace/images.py @@ -20,7 +20,6 @@ from webob import exc from nova import flags from nova import utils from nova import wsgi -from nova.api.rackspace import _id_translator import nova.api.rackspace import nova.image.service from nova.api.rackspace import faults @@ -41,8 +40,6 @@ class Controller(wsgi.Controller): def __init__(self): self._service = utils.import_object(FLAGS.image_service) - self._id_translator = _id_translator.RackspaceAPIIdTranslator( - "image", self._service.__class__.__name__) def index(self, req): """Return all public images in brief.""" @@ -53,16 +50,11 @@ class Controller(wsgi.Controller): """Return all public images in detail.""" data = self._service.index() data = nova.api.rackspace.limited(data, req) - for img in data: - img['id'] = self._id_translator.to_rs_id(img['id']) return dict(images=data) def show(self, req, id): """Return data about the given image id.""" - opaque_id = self._id_translator.from_rs_id(id) - img = self._service.show(opaque_id) - img['id'] = id - return dict(image=img) + return dict(image=self._service.show(id)) def delete(self, req, id): # Only public images are supported for now. diff --git a/nova/image/service.py b/nova/image/service.py index 2e570e8a4..5276e1312 100644 --- a/nova/image/service.py +++ b/nova/image/service.py @@ -225,7 +225,9 @@ class GlanceImageService(BaseImageService): class LocalImageService(BaseImageService): - """Image service storing images to local disk.""" + """Image service storing images to local disk. + + It assumes that image_ids are integers.""" def __init__(self): self._path = "/tmp/nova/images" @@ -234,12 +236,12 @@ class LocalImageService(BaseImageService): except OSError: # exists pass - def _path_to(self, image_id=''): - return os.path.join(self._path, image_id) + def _path_to(self, image_id): + return os.path.join(self._path, str(image_id)) def _ids(self): """The list of all image ids.""" - return os.listdir(self._path) + return [int(i) for i in os.listdir(self._path)] def index(self): return [ self.show(id) for id in self._ids() ] @@ -254,7 +256,7 @@ class LocalImageService(BaseImageService): """ Store the image data and return the new image id. """ - id = ''.join(random.choice(string.letters) for _ in range(20)) + id = random.randint(0, 2**32-1) data['id'] = id self.update(id, data) return id @@ -279,5 +281,5 @@ class LocalImageService(BaseImageService): """ Clears out all images in local directory """ - for f in os.listdir(self._path): - os.unlink(self._path_to(f)) + for id in self._ids(): + os.unlink(self._path_to(id)) diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py index b5fba2dfa..6a25720a9 100644 --- a/nova/tests/api/rackspace/fakes.py +++ b/nova/tests/api/rackspace/fakes.py @@ -28,7 +28,6 @@ from nova import utils from nova import flags from nova import exception as exc import nova.api.rackspace.auth -import nova.api.rackspace._id_translator from nova.image import service from nova.wsgi import Router -- cgit From f1a48207dfc1948ba847f262d5a4ff825b02202c Mon Sep 17 00:00:00 2001 From: mdietz Date: Fri, 8 Oct 2010 18:56:32 +0000 Subject: Start stripping out the translators --- nova/api/rackspace/servers.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py index b23867bbf..8c489ed83 100644 --- a/nova/api/rackspace/servers.py +++ b/nova/api/rackspace/servers.py @@ -25,7 +25,6 @@ from nova import rpc from nova import utils from nova import wsgi from nova.api import cloud -from nova.api.rackspace import _id_translator from nova.api.rackspace import context from nova.api.rackspace import faults from nova.compute import instance_types @@ -35,11 +34,6 @@ import nova.image.service FLAGS = flags.FLAGS -def _instance_id_translator(): - """ Helper method for initializing an id translator for Rackspace instance - ids """ - return _id_translator.RackspaceAPIIdTranslator( "instance", 'nova') - def _image_service(): """ Helper method for initializing the image id translator """ service = utils.import_object(FLAGS.image_service) @@ -182,13 +176,16 @@ class Controller(wsgi.Controller): def action(self, req, id): """ multi-purpose method used to reboot, rebuild, and resize a server """ + user_id = req.environ['nova.context']['user']['id'] input_dict = self._deserialize(req.body, req) try: reboot_type = input_dict['reboot']['type'] except Exception: raise faults.Fault(webob.exc.HTTPNotImplemented()) - opaque_id = _instance_id_translator().from_rs_id(int(id)) - cloud.reboot(opaque_id) + inst_ref = self.db.instance_get_by_internal_id(None, int(id)) + if not inst_ref or (inst_ref and not inst_ref.user_id == user_id): + return faults.Fault(exc.HTTPUnprocessableEntity()) + cloud.reboot(id) def _build_server_instance(self, req, env): """Build instance data structure and save it to the data store.""" -- cgit From 90f38090ecd586a39257b3efd2c86c2c60b7fdb9 Mon Sep 17 00:00:00 2001 From: mdietz Date: Fri, 8 Oct 2010 20:39:00 +0000 Subject: Mass renaming --- nova/api/__init__.py | 4 +- nova/api/openstack/__init__.py | 190 ++++++++++++++++ nova/api/openstack/_id_translator.py | 42 ++++ nova/api/openstack/auth.py | 101 +++++++++ nova/api/openstack/backup_schedules.py | 38 ++++ nova/api/openstack/context.py | 33 +++ nova/api/openstack/faults.py | 62 ++++++ nova/api/openstack/flavors.py | 58 +++++ nova/api/openstack/images.py | 71 ++++++ nova/api/openstack/notes.txt | 23 ++ nova/api/openstack/ratelimiting/__init__.py | 122 ++++++++++ nova/api/openstack/servers.py | 276 +++++++++++++++++++++++ nova/api/openstack/sharedipgroups.py | 20 ++ nova/api/rackspace/__init__.py | 190 ---------------- nova/api/rackspace/_id_translator.py | 42 ---- nova/api/rackspace/auth.py | 101 --------- nova/api/rackspace/backup_schedules.py | 38 ---- nova/api/rackspace/context.py | 33 --- nova/api/rackspace/faults.py | 62 ------ nova/api/rackspace/flavors.py | 58 ----- nova/api/rackspace/images.py | 71 ------ nova/api/rackspace/notes.txt | 23 -- nova/api/rackspace/ratelimiting/__init__.py | 122 ---------- nova/api/rackspace/ratelimiting/tests.py | 237 -------------------- nova/api/rackspace/servers.py | 283 ------------------------ nova/api/rackspace/sharedipgroups.py | 20 -- nova/tests/api/__init__.py | 6 +- nova/tests/api/openstack/__init__.py | 108 +++++++++ nova/tests/api/openstack/fakes.py | 205 +++++++++++++++++ nova/tests/api/openstack/test_auth.py | 108 +++++++++ nova/tests/api/openstack/test_faults.py | 40 ++++ nova/tests/api/openstack/test_flavors.py | 48 ++++ nova/tests/api/openstack/test_images.py | 141 ++++++++++++ nova/tests/api/openstack/test_ratelimiting.py | 237 ++++++++++++++++++++ nova/tests/api/openstack/test_servers.py | 249 +++++++++++++++++++++ nova/tests/api/openstack/test_sharedipgroups.py | 39 ++++ nova/tests/api/rackspace/__init__.py | 108 --------- nova/tests/api/rackspace/fakes.py | 205 ----------------- nova/tests/api/rackspace/test_auth.py | 108 --------- nova/tests/api/rackspace/test_faults.py | 40 ---- nova/tests/api/rackspace/test_flavors.py | 48 ---- nova/tests/api/rackspace/test_images.py | 141 ------------ nova/tests/api/rackspace/test_servers.py | 249 --------------------- nova/tests/api/rackspace/test_sharedipgroups.py | 39 ---- 44 files changed, 2216 insertions(+), 2223 deletions(-) create mode 100644 nova/api/openstack/__init__.py create mode 100644 nova/api/openstack/_id_translator.py create mode 100644 nova/api/openstack/auth.py create mode 100644 nova/api/openstack/backup_schedules.py create mode 100644 nova/api/openstack/context.py create mode 100644 nova/api/openstack/faults.py create mode 100644 nova/api/openstack/flavors.py create mode 100644 nova/api/openstack/images.py create mode 100644 nova/api/openstack/notes.txt create mode 100644 nova/api/openstack/ratelimiting/__init__.py create mode 100644 nova/api/openstack/servers.py create mode 100644 nova/api/openstack/sharedipgroups.py delete mode 100644 nova/api/rackspace/__init__.py delete mode 100644 nova/api/rackspace/_id_translator.py delete mode 100644 nova/api/rackspace/auth.py delete mode 100644 nova/api/rackspace/backup_schedules.py delete mode 100644 nova/api/rackspace/context.py delete mode 100644 nova/api/rackspace/faults.py delete mode 100644 nova/api/rackspace/flavors.py delete mode 100644 nova/api/rackspace/images.py delete mode 100644 nova/api/rackspace/notes.txt delete mode 100644 nova/api/rackspace/ratelimiting/__init__.py delete mode 100644 nova/api/rackspace/ratelimiting/tests.py delete mode 100644 nova/api/rackspace/servers.py delete mode 100644 nova/api/rackspace/sharedipgroups.py create mode 100644 nova/tests/api/openstack/__init__.py create mode 100644 nova/tests/api/openstack/fakes.py create mode 100644 nova/tests/api/openstack/test_auth.py create mode 100644 nova/tests/api/openstack/test_faults.py create mode 100644 nova/tests/api/openstack/test_flavors.py create mode 100644 nova/tests/api/openstack/test_images.py create mode 100644 nova/tests/api/openstack/test_ratelimiting.py create mode 100644 nova/tests/api/openstack/test_servers.py create mode 100644 nova/tests/api/openstack/test_sharedipgroups.py delete mode 100644 nova/tests/api/rackspace/__init__.py delete mode 100644 nova/tests/api/rackspace/fakes.py delete mode 100644 nova/tests/api/rackspace/test_auth.py delete mode 100644 nova/tests/api/rackspace/test_faults.py delete mode 100644 nova/tests/api/rackspace/test_flavors.py delete mode 100644 nova/tests/api/rackspace/test_images.py delete mode 100644 nova/tests/api/rackspace/test_servers.py delete mode 100644 nova/tests/api/rackspace/test_sharedipgroups.py diff --git a/nova/api/__init__.py b/nova/api/__init__.py index 744abd621..627883018 100644 --- a/nova/api/__init__.py +++ b/nova/api/__init__.py @@ -27,7 +27,7 @@ from nova import flags from nova import wsgi from nova.api import cloudpipe from nova.api import ec2 -from nova.api import rackspace +from nova.api import openstack from nova.api.ec2 import metadatarequesthandler @@ -57,7 +57,7 @@ class API(wsgi.Router): mapper.sub_domains = True mapper.connect("/", controller=self.rsapi_versions, conditions=rsdomain) - mapper.connect("/v1.0/{path_info:.*}", controller=rackspace.API(), + mapper.connect("/v1.0/{path_info:.*}", controller=openstack.API(), conditions=rsdomain) mapper.connect("/", controller=self.ec2api_versions, diff --git a/nova/api/openstack/__init__.py b/nova/api/openstack/__init__.py new file mode 100644 index 000000000..5e81ba2bd --- /dev/null +++ b/nova/api/openstack/__init__.py @@ -0,0 +1,190 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +WSGI middleware for OpenStack API controllers. +""" + +import json +import time + +import routes +import webob.dec +import webob.exc +import webob + +from nova import flags +from nova import utils +from nova import wsgi +from nova.api.openstack import faults +from nova.api.openstack import backup_schedules +from nova.api.openstack import flavors +from nova.api.openstack import images +from nova.api.openstack import ratelimiting +from nova.api.openstack import servers +from nova.api.openstack import sharedipgroups +from nova.auth import manager + + +FLAGS = flags.FLAGS +flags.DEFINE_string('nova_api_auth', + 'nova.api.openstack.auth.BasicApiAuthManager', + 'The auth mechanism to use for the OpenStack API implemenation') + +class API(wsgi.Middleware): + """WSGI entry point for all OpenStack API requests.""" + + def __init__(self): + app = AuthMiddleware(RateLimitingMiddleware(APIRouter())) + super(API, self).__init__(app) + +class AuthMiddleware(wsgi.Middleware): + """Authorize the openstack API request or return an HTTP Forbidden.""" + + def __init__(self, application): + self.auth_driver = utils.import_class(FLAGS.nova_api_auth)() + super(AuthMiddleware, self).__init__(application) + + @webob.dec.wsgify + def __call__(self, req): + if not req.headers.has_key("X-Auth-Token"): + return self.auth_driver.authenticate(req) + + user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"]) + + if not user: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + if not req.environ.has_key('nova.context'): + req.environ['nova.context'] = {} + req.environ['nova.context']['user'] = user + return self.application + +class RateLimitingMiddleware(wsgi.Middleware): + """Rate limit incoming requests according to the OpenStack rate limits.""" + + def __init__(self, application, service_host=None): + """Create a rate limiting middleware that wraps the given application. + + By default, rate counters are stored in memory. If service_host is + specified, the middleware instead relies on the ratelimiting.WSGIApp + at the given host+port to keep rate counters. + """ + super(RateLimitingMiddleware, self).__init__(application) + if not service_host: + #TODO(gundlach): These limits were based on limitations of Cloud + #Servers. We should revisit them in Nova. + self.limiter = ratelimiting.Limiter(limits={ + 'DELETE': (100, ratelimiting.PER_MINUTE), + 'PUT': (10, ratelimiting.PER_MINUTE), + 'POST': (10, ratelimiting.PER_MINUTE), + 'POST servers': (50, ratelimiting.PER_DAY), + 'GET changes-since': (3, ratelimiting.PER_MINUTE), + }) + else: + self.limiter = ratelimiting.WSGIAppProxy(service_host) + + @webob.dec.wsgify + def __call__(self, req): + """Rate limit the request. + + If the request should be rate limited, return a 413 status with a + Retry-After header giving the time when the request would succeed. + """ + username = req.headers['X-Auth-User'] + action_name = self.get_action_name(req) + if not action_name: # not rate limited + return self.application + delay = self.get_delay(action_name, username) + if delay: + # TODO(gundlach): Get the retry-after format correct. + exc = webob.exc.HTTPRequestEntityTooLarge( + explanation='Too many requests.', + headers={'Retry-After': time.time() + delay}) + raise faults.Fault(exc) + return self.application + + def get_delay(self, action_name, username): + """Return the delay for the given action and username, or None if + the action would not be rate limited. + """ + if action_name == 'POST servers': + # "POST servers" is a POST, so it counts against "POST" too. + # Attempt the "POST" first, lest we are rate limited by "POST" but + # use up a precious "POST servers" call. + delay = self.limiter.perform("POST", username=username) + if delay: + return delay + return self.limiter.perform(action_name, username=username) + + def get_action_name(self, req): + """Return the action name for this request.""" + if req.method == 'GET' and 'changes-since' in req.GET: + return 'GET changes-since' + if req.method == 'POST' and req.path_info.startswith('/servers'): + return 'POST servers' + if req.method in ['PUT', 'POST', 'DELETE']: + return req.method + return None + + +class APIRouter(wsgi.Router): + """ + Routes requests on the OpenStack API to the appropriate controller + and method. + """ + + def __init__(self): + mapper = routes.Mapper() + mapper.resource("server", "servers", controller=servers.Controller(), + collection={ 'detail': 'GET'}, + member={'action':'POST'}) + + mapper.resource("backup_schedule", "backup_schedules", + controller=backup_schedules.Controller(), + parent_resource=dict(member_name='server', + collection_name = 'servers')) + + mapper.resource("image", "images", controller=images.Controller(), + collection={'detail': 'GET'}) + mapper.resource("flavor", "flavors", controller=flavors.Controller(), + collection={'detail': 'GET'}) + mapper.resource("sharedipgroup", "sharedipgroups", + controller=sharedipgroups.Controller()) + + super(APIRouter, self).__init__(mapper) + + +def limited(items, req): + """Return a slice of items according to requested offset and limit. + + items - a sliceable + req - wobob.Request possibly containing offset and limit GET variables. + offset is where to start in the list, and limit is the maximum number + of items to return. + + If limit is not specified, 0, or > 1000, defaults to 1000. + """ + offset = int(req.GET.get('offset', 0)) + limit = int(req.GET.get('limit', 0)) + if not limit: + limit = 1000 + limit = min(1000, limit) + range_end = offset + limit + return items[offset:range_end] + diff --git a/nova/api/openstack/_id_translator.py b/nova/api/openstack/_id_translator.py new file mode 100644 index 000000000..333aa8434 --- /dev/null +++ b/nova/api/openstack/_id_translator.py @@ -0,0 +1,42 @@ +from nova import datastore + +class RackspaceAPIIdTranslator(object): + """ + Converts Rackspace API ids to and from the id format for a given + strategy. + """ + + def __init__(self, id_type, service_name): + """ + Creates a translator for ids of the given type (e.g. 'flavor'), for the + given storage service backend class name (e.g. 'LocalFlavorService'). + """ + + self._store = datastore.Redis.instance() + key_prefix = "rsapi.idtranslator.%s.%s" % (id_type, service_name) + # Forward (strategy format -> RS format) and reverse translation keys + self._fwd_key = "%s.fwd" % key_prefix + self._rev_key = "%s.rev" % key_prefix + + def to_rs_id(self, opaque_id): + """Convert an id from a strategy-specific one to a Rackspace one.""" + result = self._store.hget(self._fwd_key, str(opaque_id)) + if result: # we have a mapping from opaque to RS for this strategy + return int(result) + else: + # Store the mapping. + nextid = self._store.incr("%s.lastid" % self._fwd_key) + if self._store.hsetnx(self._fwd_key, str(opaque_id), nextid): + # If someone else didn't beat us to it, store the reverse + # mapping as well. + self._store.hset(self._rev_key, nextid, str(opaque_id)) + return nextid + else: + # Someone beat us to it; use their number instead, and + # discard nextid (which is OK -- we don't require that + # every int id be used.) + return int(self._store.hget(self._fwd_key, str(opaque_id))) + + def from_rs_id(self, rs_id): + """Convert a Rackspace id to a strategy-specific one.""" + return self._store.hget(self._rev_key, rs_id) diff --git a/nova/api/openstack/auth.py b/nova/api/openstack/auth.py new file mode 100644 index 000000000..4c909293e --- /dev/null +++ b/nova/api/openstack/auth.py @@ -0,0 +1,101 @@ +import datetime +import hashlib +import json +import time + +import webob.exc +import webob.dec + +from nova import auth +from nova import db +from nova import flags +from nova import manager +from nova import utils +from nova.api.openstack import faults + +FLAGS = flags.FLAGS + +class Context(object): + pass + +class BasicApiAuthManager(object): + """ Implements a somewhat rudimentary version of OpenStack Auth""" + + def __init__(self, host=None, db_driver=None): + if not host: + host = FLAGS.host + self.host = host + if not db_driver: + db_driver = FLAGS.db_driver + self.db = utils.import_object(db_driver) + self.auth = auth.manager.AuthManager() + self.context = Context() + super(BasicApiAuthManager, self).__init__() + + def authenticate(self, req): + # Unless the request is explicitly made against // don't + # honor it + path_info = req.path_info + if len(path_info) > 1: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + try: + username, key = req.headers['X-Auth-User'], \ + req.headers['X-Auth-Key'] + except KeyError: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + username, key = req.headers['X-Auth-User'], req.headers['X-Auth-Key'] + token, user = self._authorize_user(username, key) + if user and token: + res = webob.Response() + res.headers['X-Auth-Token'] = token['token_hash'] + res.headers['X-Server-Management-Url'] = \ + token['server_management_url'] + res.headers['X-Storage-Url'] = token['storage_url'] + res.headers['X-CDN-Management-Url'] = token['cdn_management_url'] + res.content_type = 'text/plain' + res.status = '204' + return res + else: + return faults.Fault(webob.exc.HTTPUnauthorized()) + + def authorize_token(self, token_hash): + """ retrieves user information from the datastore given a token + + If the token has expired, returns None + If the token is not found, returns None + Otherwise returns the token + + This method will also remove the token if the timestamp is older than + 2 days ago. + """ + token = self.db.auth_get_token(self.context, token_hash) + if token: + delta = datetime.datetime.now() - token['created_at'] + if delta.days >= 2: + self.db.auth_destroy_token(self.context, token) + else: + user = self.auth.get_user(token['user_id']) + return { 'id':user['uid'] } + return None + + def _authorize_user(self, username, key): + """ Generates a new token and assigns it to a user """ + user = self.auth.get_user_from_access_key(key) + if user and user['name'] == username: + token_hash = hashlib.sha1('%s%s%f' % (username, key, + time.time())).hexdigest() + token = {} + token['token_hash'] = token_hash + token['cdn_management_url'] = '' + token['server_management_url'] = self._get_server_mgmt_url() + token['storage_url'] = '' + token['user_id'] = user['uid'] + self.db.auth_create_token(self.context, token) + return token, user + return None, None + + def _get_server_mgmt_url(self): + return 'https://%s/v1.0/' % self.host + diff --git a/nova/api/openstack/backup_schedules.py b/nova/api/openstack/backup_schedules.py new file mode 100644 index 000000000..76ad6ef87 --- /dev/null +++ b/nova/api/openstack/backup_schedules.py @@ -0,0 +1,38 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import time +from webob import exc + +from nova import wsgi +from nova.api.openstack import faults +import nova.image.service + +class Controller(wsgi.Controller): + def __init__(self): + pass + + def index(self, req, server_id): + return faults.Fault(exc.HTTPNotFound()) + + def create(self, req, server_id): + """ No actual update method required, since the existing API allows + both create and update through a POST """ + return faults.Fault(exc.HTTPNotFound()) + + def delete(self, req, server_id): + return faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/openstack/context.py b/nova/api/openstack/context.py new file mode 100644 index 000000000..77394615b --- /dev/null +++ b/nova/api/openstack/context.py @@ -0,0 +1,33 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +APIRequestContext +""" + +import random + +class Project(object): + def __init__(self, user_id): + self.id = user_id + +class APIRequestContext(object): + """ This is an adapter class to get around all of the assumptions made in + the FlatNetworking """ + def __init__(self, user_id): + self.user_id = user_id + self.project = Project(user_id) diff --git a/nova/api/openstack/faults.py b/nova/api/openstack/faults.py new file mode 100644 index 000000000..32e5c866f --- /dev/null +++ b/nova/api/openstack/faults.py @@ -0,0 +1,62 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import webob.dec +import webob.exc + +from nova import wsgi + + +class Fault(webob.exc.HTTPException): + + """An RS API fault response.""" + + _fault_names = { + 400: "badRequest", + 401: "unauthorized", + 403: "resizeNotAllowed", + 404: "itemNotFound", + 405: "badMethod", + 409: "inProgress", + 413: "overLimit", + 415: "badMediaType", + 501: "notImplemented", + 503: "serviceUnavailable"} + + def __init__(self, exception): + """Create a Fault for the given webob.exc.exception.""" + self.wrapped_exc = exception + + @webob.dec.wsgify + def __call__(self, req): + """Generate a WSGI response based on the exception passed to ctor.""" + # Replace the body with fault details. + code = self.wrapped_exc.status_int + fault_name = self._fault_names.get(code, "cloudServersFault") + fault_data = { + fault_name: { + 'code': code, + 'message': self.wrapped_exc.explanation}} + if code == 413: + retry = self.wrapped_exc.headers['Retry-After'] + fault_data[fault_name]['retryAfter'] = retry + # 'code' is an attribute on the fault tag itself + metadata = {'application/xml': {'attributes': {fault_name: 'code'}}} + serializer = wsgi.Serializer(req.environ, metadata) + self.wrapped_exc.body = serializer.to_content_type(fault_data) + return self.wrapped_exc diff --git a/nova/api/openstack/flavors.py b/nova/api/openstack/flavors.py new file mode 100644 index 000000000..793984a5d --- /dev/null +++ b/nova/api/openstack/flavors.py @@ -0,0 +1,58 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from webob import exc + +from nova.api.openstack import faults +from nova.compute import instance_types +from nova import wsgi +import nova.api.openstack + +class Controller(wsgi.Controller): + """Flavor controller for the OpenStack API.""" + + _serialization_metadata = { + 'application/xml': { + "attributes": { + "flavor": [ "id", "name", "ram", "disk" ] + } + } + } + + def index(self, req): + """Return all flavors in brief.""" + return dict(flavors=[dict(id=flavor['id'], name=flavor['name']) + for flavor in self.detail(req)['flavors']]) + + def detail(self, req): + """Return all flavors in detail.""" + items = [self.show(req, id)['flavor'] for id in self._all_ids()] + items = nova.api.openstack.limited(items, req) + return dict(flavors=items) + + def show(self, req, id): + """Return data about the given flavor id.""" + for name, val in instance_types.INSTANCE_TYPES.iteritems(): + if val['flavorid'] == int(id): + item = dict(ram=val['memory_mb'], disk=val['local_gb'], + id=val['flavorid'], name=name) + return dict(flavor=item) + raise faults.Fault(exc.HTTPNotFound()) + + def _all_ids(self): + """Return the list of all flavorids.""" + return [i['flavorid'] for i in instance_types.INSTANCE_TYPES.values()] diff --git a/nova/api/openstack/images.py b/nova/api/openstack/images.py new file mode 100644 index 000000000..aa438739c --- /dev/null +++ b/nova/api/openstack/images.py @@ -0,0 +1,71 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from webob import exc + +from nova import flags +from nova import utils +from nova import wsgi +import nova.api.openstack +import nova.image.service +from nova.api.openstack import faults + + +FLAGS = flags.FLAGS + +class Controller(wsgi.Controller): + + _serialization_metadata = { + 'application/xml': { + "attributes": { + "image": [ "id", "name", "updated", "created", "status", + "serverId", "progress" ] + } + } + } + + def __init__(self): + self._service = utils.import_object(FLAGS.image_service) + + def index(self, req): + """Return all public images in brief.""" + return dict(images=[dict(id=img['id'], name=img['name']) + for img in self.detail(req)['images']]) + + def detail(self, req): + """Return all public images in detail.""" + data = self._service.index() + data = nova.api.openstack.limited(data, req) + return dict(images=data) + + def show(self, req, id): + """Return data about the given image id.""" + return dict(image=self._service.show(id)) + + def delete(self, req, id): + # Only public images are supported for now. + raise faults.Fault(exc.HTTPNotFound()) + + def create(self, req): + # Only public images are supported for now, so a request to + # make a backup of a server cannot be supproted. + raise faults.Fault(exc.HTTPNotFound()) + + def update(self, req, id): + # Users may not modify public images, and that's all that + # we support for now. + raise faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/openstack/notes.txt b/nova/api/openstack/notes.txt new file mode 100644 index 000000000..2330f1002 --- /dev/null +++ b/nova/api/openstack/notes.txt @@ -0,0 +1,23 @@ +We will need: + +ImageService +a service that can do crud on image information. not user-specific. opaque +image ids. + +GlanceImageService(ImageService): +image ids are URIs. + +LocalImageService(ImageService): +image ids are random strings. + +OpenstackAPITranslationStore: +translates RS server/images/flavor/etc ids into formats required +by a given ImageService strategy. + +api.openstack.images.Controller: +uses an ImageService strategy behind the scenes to do its fetching; it just +converts int image id into a strategy-specific image id. + +who maintains the mapping from user to [images he owns]? nobody, because +we have no way of enforcing access to his images, without kryptex which +won't be in Austin. diff --git a/nova/api/openstack/ratelimiting/__init__.py b/nova/api/openstack/ratelimiting/__init__.py new file mode 100644 index 000000000..f843bac0f --- /dev/null +++ b/nova/api/openstack/ratelimiting/__init__.py @@ -0,0 +1,122 @@ +"""Rate limiting of arbitrary actions.""" + +import httplib +import time +import urllib +import webob.dec +import webob.exc + + +# Convenience constants for the limits dictionary passed to Limiter(). +PER_SECOND = 1 +PER_MINUTE = 60 +PER_HOUR = 60 * 60 +PER_DAY = 60 * 60 * 24 + +class Limiter(object): + + """Class providing rate limiting of arbitrary actions.""" + + def __init__(self, limits): + """Create a rate limiter. + + limits: a dict mapping from action name to a tuple. The tuple contains + the number of times the action may be performed, and the time period + (in seconds) during which the number must not be exceeded for this + action. Example: dict(reboot=(10, ratelimiting.PER_MINUTE)) would + allow 10 'reboot' actions per minute. + """ + self.limits = limits + self._levels = {} + + def perform(self, action_name, username='nobody'): + """Attempt to perform an action by the given username. + + action_name: the string name of the action to perform. This must + be a key in the limits dict passed to the ctor. + + username: an optional string name of the user performing the action. + Each user has her own set of rate limiting counters. Defaults to + 'nobody' (so that if you never specify a username when calling + perform(), a single set of counters will be used.) + + Return None if the action may proceed. If the action may not proceed + because it has been rate limited, return the float number of seconds + until the action would succeed. + """ + # Think of rate limiting as a bucket leaking water at 1cc/second. The + # bucket can hold as many ccs as there are seconds in the rate + # limiting period (e.g. 3600 for per-hour ratelimits), and if you can + # perform N actions in that time, each action fills the bucket by + # 1/Nth of its volume. You may only perform an action if the bucket + # would not overflow. + now = time.time() + key = '%s:%s' % (username, action_name) + last_time_performed, water_level = self._levels.get(key, (now, 0)) + # The bucket leaks 1cc/second. + water_level -= (now - last_time_performed) + if water_level < 0: + water_level = 0 + num_allowed_per_period, period_in_secs = self.limits[action_name] + # Fill the bucket by 1/Nth its capacity, and hope it doesn't overflow. + capacity = period_in_secs + new_level = water_level + (capacity * 1.0 / num_allowed_per_period) + if new_level > capacity: + # Delay this many seconds. + return new_level - capacity + self._levels[key] = (now, new_level) + return None + + +# If one instance of this WSGIApps is unable to handle your load, put a +# sharding app in front that shards by username to one of many backends. + +class WSGIApp(object): + + """Application that tracks rate limits in memory. Send requests to it of + this form: + + POST /limiter// + + and receive a 200 OK, or a 403 Forbidden with an X-Wait-Seconds header + containing the number of seconds to wait before the action would succeed. + """ + + def __init__(self, limiter): + """Create the WSGI application using the given Limiter instance.""" + self.limiter = limiter + + @webob.dec.wsgify + def __call__(self, req): + parts = req.path_info.split('/') + # format: /limiter// + if req.method != 'POST': + raise webob.exc.HTTPMethodNotAllowed() + if len(parts) != 4 or parts[1] != 'limiter': + raise webob.exc.HTTPNotFound() + username = parts[2] + action_name = urllib.unquote(parts[3]) + delay = self.limiter.perform(action_name, username) + if delay: + return webob.exc.HTTPForbidden( + headers={'X-Wait-Seconds': "%.2f" % delay}) + else: + return '' # 200 OK + + +class WSGIAppProxy(object): + + """Limiter lookalike that proxies to a ratelimiting.WSGIApp.""" + + def __init__(self, service_host): + """Creates a proxy pointing to a ratelimiting.WSGIApp at the given + host.""" + self.service_host = service_host + + def perform(self, action, username='nobody'): + conn = httplib.HTTPConnection(self.service_host) + conn.request('POST', '/limiter/%s/%s' % (username, action)) + resp = conn.getresponse() + if resp.status == 200: + return None # no delay + return float(resp.getheader('X-Wait-Seconds')) diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py new file mode 100644 index 000000000..f234af7de --- /dev/null +++ b/nova/api/openstack/servers.py @@ -0,0 +1,276 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import time + +import webob +from webob import exc + +from nova import flags +from nova import rpc +from nova import utils +from nova import wsgi +from nova.api import cloud +from nova.api.openstack import context +from nova.api.openstack import faults +from nova.compute import instance_types +from nova.compute import power_state +import nova.api.openstack +import nova.image.service + +FLAGS = flags.FLAGS + +def _filter_params(inst_dict): + """ Extracts all updatable parameters for a server update request """ + keys = dict(name='name', admin_pass='adminPass') + new_attrs = {} + for k, v in keys.items(): + if inst_dict.has_key(v): + new_attrs[k] = inst_dict[v] + return new_attrs + +def _entity_list(entities): + """ Coerces a list of servers into proper dictionary format """ + return dict(servers=entities) + +def _entity_detail(inst): + """ Maps everything to valid attributes for return""" + power_mapping = { + power_state.NOSTATE: 'build', + power_state.RUNNING: 'active', + power_state.BLOCKED: 'active', + power_state.PAUSED: 'suspended', + power_state.SHUTDOWN: 'active', + power_state.SHUTOFF: 'active', + power_state.CRASHED: 'error' + } + inst_dict = {} + + mapped_keys = dict(status='state', imageId='image_id', + flavorId='instance_type', name='server_name', id='id') + + for k, v in mapped_keys.iteritems(): + inst_dict[k] = inst[v] + + inst_dict['status'] = power_mapping[inst_dict['status']] + inst_dict['addresses'] = dict(public=[], private=[]) + inst_dict['metadata'] = {} + inst_dict['hostId'] = '' + + return dict(server=inst_dict) + +def _entity_inst(inst): + """ Filters all model attributes save for id and name """ + return dict(server=dict(id=inst['id'], name=inst['server_name'])) + +class Controller(wsgi.Controller): + """ The Server API controller for the OpenStack API """ + + _serialization_metadata = { + 'application/xml': { + "attributes": { + "server": [ "id", "imageId", "name", "flavorId", "hostId", + "status", "progress", "progress" ] + } + } + } + + def __init__(self, db_driver=None): + if not db_driver: + db_driver = FLAGS.db_driver + self.db_driver = utils.import_object(db_driver) + super(Controller, self).__init__() + + def index(self, req): + """ Returns a list of server names and ids for a given user """ + return self._items(req, entity_maker=_entity_inst) + + def detail(self, req): + """ Returns a list of server details for a given user """ + return self._items(req, entity_maker=_entity_detail) + + def _items(self, req, entity_maker): + """Returns a list of servers for a given user. + + entity_maker - either _entity_detail or _entity_inst + """ + user_id = req.environ['nova.context']['user']['id'] + instance_list = self.db_driver.instance_get_all_by_user(None, user_id) + limited_list = nova.api.openstack.limited(instance_list, req) + res = [entity_maker(inst)['server'] for inst in limited_list] + return _entity_list(res) + + def show(self, req, id): + """ Returns server details by server id """ + user_id = req.environ['nova.context']['user']['id'] + inst = self.db_driver.instance_get_by_internal_id(None, int(id)) + if inst: + if inst.user_id == user_id: + return _entity_detail(inst) + raise faults.Fault(exc.HTTPNotFound()) + + def delete(self, req, id): + """ Destroys a server """ + user_id = req.environ['nova.context']['user']['id'] + instance = self.db_driver.instance_get_by_internal_id(None, int(id)) + if instance and instance['user_id'] == user_id: + self.db_driver.instance_destroy(None, id) + return faults.Fault(exc.HTTPAccepted()) + return faults.Fault(exc.HTTPNotFound()) + + def create(self, req): + """ Creates a new server for a given user """ + + env = self._deserialize(req.body, req) + if not env: + return faults.Fault(exc.HTTPUnprocessableEntity()) + + #try: + inst = self._build_server_instance(req, env) + #except Exception, e: + # return faults.Fault(exc.HTTPUnprocessableEntity()) + + rpc.cast( + FLAGS.compute_topic, { + "method": "run_instance", + "args": {"instance_id": inst['id']}}) + return _entity_inst(inst) + + def update(self, req, id): + """ Updates the server name or password """ + user_id = req.environ['nova.context']['user']['id'] + + inst_dict = self._deserialize(req.body, req) + + if not inst_dict: + return faults.Fault(exc.HTTPUnprocessableEntity()) + + instance = self.db_driver.instance_get_by_internal_id(None, int(id)) + if not instance or instance.user_id != user_id: + return faults.Fault(exc.HTTPNotFound()) + + self.db_driver.instance_update(None, int(id), + _filter_params(inst_dict['server'])) + return faults.Fault(exc.HTTPNoContent()) + + def action(self, req, id): + """ multi-purpose method used to reboot, rebuild, and + resize a server """ + user_id = req.environ['nova.context']['user']['id'] + input_dict = self._deserialize(req.body, req) + try: + reboot_type = input_dict['reboot']['type'] + except Exception: + raise faults.Fault(webob.exc.HTTPNotImplemented()) + inst_ref = self.db.instance_get_by_internal_id(None, int(id)) + if not inst_ref or (inst_ref and not inst_ref.user_id == user_id): + return faults.Fault(exc.HTTPUnprocessableEntity()) + cloud.reboot(id) + + def _build_server_instance(self, req, env): + """Build instance data structure and save it to the data store.""" + ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) + inst = {} + + user_id = req.environ['nova.context']['user']['id'] + + flavor_id = env['server']['flavorId'] + + instance_type, flavor = [(k, v) for k, v in + instance_types.INSTANCE_TYPES.iteritems() + if v['flavorid'] == flavor_id][0] + + image_id = env['server']['imageId'] + + img_service = utils.import_object(FLAGS.image_service) + + image = img_service.show(image_id) + + if not image: + raise Exception, "Image not found" + + inst['server_name'] = env['server']['name'] + inst['image_id'] = image_id + inst['user_id'] = user_id + inst['launch_time'] = ltime + inst['mac_address'] = utils.generate_mac() + inst['project_id'] = user_id + + inst['state_description'] = 'scheduling' + inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel) + inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk) + inst['reservation_id'] = utils.generate_uid('r') + + inst['display_name'] = env['server']['name'] + inst['display_description'] = env['server']['name'] + + #TODO(dietz) this may be ill advised + key_pair_ref = self.db_driver.key_pair_get_all_by_user( + None, user_id)[0] + + inst['key_data'] = key_pair_ref['public_key'] + inst['key_name'] = key_pair_ref['name'] + + #TODO(dietz) stolen from ec2 api, see TODO there + inst['security_group'] = 'default' + + # Flavor related attributes + inst['instance_type'] = instance_type + inst['memory_mb'] = flavor['memory_mb'] + inst['vcpus'] = flavor['vcpus'] + inst['local_gb'] = flavor['local_gb'] + + ref = self.db_driver.instance_create(None, inst) + inst['id'] = ref.internal_id + + # TODO(dietz): this isn't explicitly necessary, but the networking + # calls depend on an object with a project_id property, and therefore + # should be cleaned up later + api_context = context.APIRequestContext(user_id) + + inst['mac_address'] = utils.generate_mac() + + #TODO(dietz) is this necessary? + inst['launch_index'] = 0 + + inst['hostname'] = str(ref.internal_id) + self.db_driver.instance_update(None, inst['id'], inst) + + network_manager = utils.import_object(FLAGS.network_manager) + address = network_manager.allocate_fixed_ip(api_context, + inst['id']) + + # TODO(vish): This probably should be done in the scheduler + # network is setup when host is assigned + network_topic = self._get_network_topic(user_id) + rpc.call(network_topic, + {"method": "setup_fixed_ip", + "args": {"context": None, + "address": address}}) + return inst + + def _get_network_topic(self, user_id): + """Retrieves the network host for a project""" + network_ref = self.db_driver.project_get_network(None, + user_id) + host = network_ref['host'] + if not host: + host = rpc.call(FLAGS.network_topic, + {"method": "set_network_host", + "args": {"context": None, + "project_id": user_id}}) + return self.db_driver.queue_get_for(None, FLAGS.network_topic, host) diff --git a/nova/api/openstack/sharedipgroups.py b/nova/api/openstack/sharedipgroups.py new file mode 100644 index 000000000..4d2d0ede1 --- /dev/null +++ b/nova/api/openstack/sharedipgroups.py @@ -0,0 +1,20 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from nova import wsgi + +class Controller(wsgi.Controller): pass diff --git a/nova/api/rackspace/__init__.py b/nova/api/rackspace/__init__.py deleted file mode 100644 index 89a4693ad..000000000 --- a/nova/api/rackspace/__init__.py +++ /dev/null @@ -1,190 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 United States Government as represented by the -# Administrator of the National Aeronautics and Space Administration. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -WSGI middleware for Rackspace API controllers. -""" - -import json -import time - -import routes -import webob.dec -import webob.exc -import webob - -from nova import flags -from nova import utils -from nova import wsgi -from nova.api.rackspace import faults -from nova.api.rackspace import backup_schedules -from nova.api.rackspace import flavors -from nova.api.rackspace import images -from nova.api.rackspace import ratelimiting -from nova.api.rackspace import servers -from nova.api.rackspace import sharedipgroups -from nova.auth import manager - - -FLAGS = flags.FLAGS -flags.DEFINE_string('nova_api_auth', - 'nova.api.rackspace.auth.BasicApiAuthManager', - 'The auth mechanism to use for the Rackspace API implemenation') - -class API(wsgi.Middleware): - """WSGI entry point for all Rackspace API requests.""" - - def __init__(self): - app = AuthMiddleware(RateLimitingMiddleware(APIRouter())) - super(API, self).__init__(app) - -class AuthMiddleware(wsgi.Middleware): - """Authorize the rackspace API request or return an HTTP Forbidden.""" - - def __init__(self, application): - self.auth_driver = utils.import_class(FLAGS.nova_api_auth)() - super(AuthMiddleware, self).__init__(application) - - @webob.dec.wsgify - def __call__(self, req): - if not req.headers.has_key("X-Auth-Token"): - return self.auth_driver.authenticate(req) - - user = self.auth_driver.authorize_token(req.headers["X-Auth-Token"]) - - if not user: - return faults.Fault(webob.exc.HTTPUnauthorized()) - - if not req.environ.has_key('nova.context'): - req.environ['nova.context'] = {} - req.environ['nova.context']['user'] = user - return self.application - -class RateLimitingMiddleware(wsgi.Middleware): - """Rate limit incoming requests according to the OpenStack rate limits.""" - - def __init__(self, application, service_host=None): - """Create a rate limiting middleware that wraps the given application. - - By default, rate counters are stored in memory. If service_host is - specified, the middleware instead relies on the ratelimiting.WSGIApp - at the given host+port to keep rate counters. - """ - super(RateLimitingMiddleware, self).__init__(application) - if not service_host: - #TODO(gundlach): These limits were based on limitations of Cloud - #Servers. We should revisit them in Nova. - self.limiter = ratelimiting.Limiter(limits={ - 'DELETE': (100, ratelimiting.PER_MINUTE), - 'PUT': (10, ratelimiting.PER_MINUTE), - 'POST': (10, ratelimiting.PER_MINUTE), - 'POST servers': (50, ratelimiting.PER_DAY), - 'GET changes-since': (3, ratelimiting.PER_MINUTE), - }) - else: - self.limiter = ratelimiting.WSGIAppProxy(service_host) - - @webob.dec.wsgify - def __call__(self, req): - """Rate limit the request. - - If the request should be rate limited, return a 413 status with a - Retry-After header giving the time when the request would succeed. - """ - username = req.headers['X-Auth-User'] - action_name = self.get_action_name(req) - if not action_name: # not rate limited - return self.application - delay = self.get_delay(action_name, username) - if delay: - # TODO(gundlach): Get the retry-after format correct. - exc = webob.exc.HTTPRequestEntityTooLarge( - explanation='Too many requests.', - headers={'Retry-After': time.time() + delay}) - raise faults.Fault(exc) - return self.application - - def get_delay(self, action_name, username): - """Return the delay for the given action and username, or None if - the action would not be rate limited. - """ - if action_name == 'POST servers': - # "POST servers" is a POST, so it counts against "POST" too. - # Attempt the "POST" first, lest we are rate limited by "POST" but - # use up a precious "POST servers" call. - delay = self.limiter.perform("POST", username=username) - if delay: - return delay - return self.limiter.perform(action_name, username=username) - - def get_action_name(self, req): - """Return the action name for this request.""" - if req.method == 'GET' and 'changes-since' in req.GET: - return 'GET changes-since' - if req.method == 'POST' and req.path_info.startswith('/servers'): - return 'POST servers' - if req.method in ['PUT', 'POST', 'DELETE']: - return req.method - return None - - -class APIRouter(wsgi.Router): - """ - Routes requests on the Rackspace API to the appropriate controller - and method. - """ - - def __init__(self): - mapper = routes.Mapper() - mapper.resource("server", "servers", controller=servers.Controller(), - collection={ 'detail': 'GET'}, - member={'action':'POST'}) - - mapper.resource("backup_schedule", "backup_schedules", - controller=backup_schedules.Controller(), - parent_resource=dict(member_name='server', - collection_name = 'servers')) - - mapper.resource("image", "images", controller=images.Controller(), - collection={'detail': 'GET'}) - mapper.resource("flavor", "flavors", controller=flavors.Controller(), - collection={'detail': 'GET'}) - mapper.resource("sharedipgroup", "sharedipgroups", - controller=sharedipgroups.Controller()) - - super(APIRouter, self).__init__(mapper) - - -def limited(items, req): - """Return a slice of items according to requested offset and limit. - - items - a sliceable - req - wobob.Request possibly containing offset and limit GET variables. - offset is where to start in the list, and limit is the maximum number - of items to return. - - If limit is not specified, 0, or > 1000, defaults to 1000. - """ - offset = int(req.GET.get('offset', 0)) - limit = int(req.GET.get('limit', 0)) - if not limit: - limit = 1000 - limit = min(1000, limit) - range_end = offset + limit - return items[offset:range_end] - diff --git a/nova/api/rackspace/_id_translator.py b/nova/api/rackspace/_id_translator.py deleted file mode 100644 index 333aa8434..000000000 --- a/nova/api/rackspace/_id_translator.py +++ /dev/null @@ -1,42 +0,0 @@ -from nova import datastore - -class RackspaceAPIIdTranslator(object): - """ - Converts Rackspace API ids to and from the id format for a given - strategy. - """ - - def __init__(self, id_type, service_name): - """ - Creates a translator for ids of the given type (e.g. 'flavor'), for the - given storage service backend class name (e.g. 'LocalFlavorService'). - """ - - self._store = datastore.Redis.instance() - key_prefix = "rsapi.idtranslator.%s.%s" % (id_type, service_name) - # Forward (strategy format -> RS format) and reverse translation keys - self._fwd_key = "%s.fwd" % key_prefix - self._rev_key = "%s.rev" % key_prefix - - def to_rs_id(self, opaque_id): - """Convert an id from a strategy-specific one to a Rackspace one.""" - result = self._store.hget(self._fwd_key, str(opaque_id)) - if result: # we have a mapping from opaque to RS for this strategy - return int(result) - else: - # Store the mapping. - nextid = self._store.incr("%s.lastid" % self._fwd_key) - if self._store.hsetnx(self._fwd_key, str(opaque_id), nextid): - # If someone else didn't beat us to it, store the reverse - # mapping as well. - self._store.hset(self._rev_key, nextid, str(opaque_id)) - return nextid - else: - # Someone beat us to it; use their number instead, and - # discard nextid (which is OK -- we don't require that - # every int id be used.) - return int(self._store.hget(self._fwd_key, str(opaque_id))) - - def from_rs_id(self, rs_id): - """Convert a Rackspace id to a strategy-specific one.""" - return self._store.hget(self._rev_key, rs_id) diff --git a/nova/api/rackspace/auth.py b/nova/api/rackspace/auth.py deleted file mode 100644 index c45156ebd..000000000 --- a/nova/api/rackspace/auth.py +++ /dev/null @@ -1,101 +0,0 @@ -import datetime -import hashlib -import json -import time - -import webob.exc -import webob.dec - -from nova import auth -from nova import db -from nova import flags -from nova import manager -from nova import utils -from nova.api.rackspace import faults - -FLAGS = flags.FLAGS - -class Context(object): - pass - -class BasicApiAuthManager(object): - """ Implements a somewhat rudimentary version of Rackspace Auth""" - - def __init__(self, host=None, db_driver=None): - if not host: - host = FLAGS.host - self.host = host - if not db_driver: - db_driver = FLAGS.db_driver - self.db = utils.import_object(db_driver) - self.auth = auth.manager.AuthManager() - self.context = Context() - super(BasicApiAuthManager, self).__init__() - - def authenticate(self, req): - # Unless the request is explicitly made against // don't - # honor it - path_info = req.path_info - if len(path_info) > 1: - return faults.Fault(webob.exc.HTTPUnauthorized()) - - try: - username, key = req.headers['X-Auth-User'], \ - req.headers['X-Auth-Key'] - except KeyError: - return faults.Fault(webob.exc.HTTPUnauthorized()) - - username, key = req.headers['X-Auth-User'], req.headers['X-Auth-Key'] - token, user = self._authorize_user(username, key) - if user and token: - res = webob.Response() - res.headers['X-Auth-Token'] = token['token_hash'] - res.headers['X-Server-Management-Url'] = \ - token['server_management_url'] - res.headers['X-Storage-Url'] = token['storage_url'] - res.headers['X-CDN-Management-Url'] = token['cdn_management_url'] - res.content_type = 'text/plain' - res.status = '204' - return res - else: - return faults.Fault(webob.exc.HTTPUnauthorized()) - - def authorize_token(self, token_hash): - """ retrieves user information from the datastore given a token - - If the token has expired, returns None - If the token is not found, returns None - Otherwise returns the token - - This method will also remove the token if the timestamp is older than - 2 days ago. - """ - token = self.db.auth_get_token(self.context, token_hash) - if token: - delta = datetime.datetime.now() - token['created_at'] - if delta.days >= 2: - self.db.auth_destroy_token(self.context, token) - else: - user = self.auth.get_user(token['user_id']) - return { 'id':user['uid'] } - return None - - def _authorize_user(self, username, key): - """ Generates a new token and assigns it to a user """ - user = self.auth.get_user_from_access_key(key) - if user and user['name'] == username: - token_hash = hashlib.sha1('%s%s%f' % (username, key, - time.time())).hexdigest() - token = {} - token['token_hash'] = token_hash - token['cdn_management_url'] = '' - token['server_management_url'] = self._get_server_mgmt_url() - token['storage_url'] = '' - token['user_id'] = user['uid'] - self.db.auth_create_token(self.context, token) - return token, user - return None, None - - def _get_server_mgmt_url(self): - return 'https://%s/v1.0/' % self.host - diff --git a/nova/api/rackspace/backup_schedules.py b/nova/api/rackspace/backup_schedules.py deleted file mode 100644 index 9c0d41fa0..000000000 --- a/nova/api/rackspace/backup_schedules.py +++ /dev/null @@ -1,38 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import time -from webob import exc - -from nova import wsgi -from nova.api.rackspace import faults -import nova.image.service - -class Controller(wsgi.Controller): - def __init__(self): - pass - - def index(self, req, server_id): - return faults.Fault(exc.HTTPNotFound()) - - def create(self, req, server_id): - """ No actual update method required, since the existing API allows - both create and update through a POST """ - return faults.Fault(exc.HTTPNotFound()) - - def delete(self, req, server_id): - return faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/context.py b/nova/api/rackspace/context.py deleted file mode 100644 index 77394615b..000000000 --- a/nova/api/rackspace/context.py +++ /dev/null @@ -1,33 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -APIRequestContext -""" - -import random - -class Project(object): - def __init__(self, user_id): - self.id = user_id - -class APIRequestContext(object): - """ This is an adapter class to get around all of the assumptions made in - the FlatNetworking """ - def __init__(self, user_id): - self.user_id = user_id - self.project = Project(user_id) diff --git a/nova/api/rackspace/faults.py b/nova/api/rackspace/faults.py deleted file mode 100644 index 32e5c866f..000000000 --- a/nova/api/rackspace/faults.py +++ /dev/null @@ -1,62 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - - -import webob.dec -import webob.exc - -from nova import wsgi - - -class Fault(webob.exc.HTTPException): - - """An RS API fault response.""" - - _fault_names = { - 400: "badRequest", - 401: "unauthorized", - 403: "resizeNotAllowed", - 404: "itemNotFound", - 405: "badMethod", - 409: "inProgress", - 413: "overLimit", - 415: "badMediaType", - 501: "notImplemented", - 503: "serviceUnavailable"} - - def __init__(self, exception): - """Create a Fault for the given webob.exc.exception.""" - self.wrapped_exc = exception - - @webob.dec.wsgify - def __call__(self, req): - """Generate a WSGI response based on the exception passed to ctor.""" - # Replace the body with fault details. - code = self.wrapped_exc.status_int - fault_name = self._fault_names.get(code, "cloudServersFault") - fault_data = { - fault_name: { - 'code': code, - 'message': self.wrapped_exc.explanation}} - if code == 413: - retry = self.wrapped_exc.headers['Retry-After'] - fault_data[fault_name]['retryAfter'] = retry - # 'code' is an attribute on the fault tag itself - metadata = {'application/xml': {'attributes': {fault_name: 'code'}}} - serializer = wsgi.Serializer(req.environ, metadata) - self.wrapped_exc.body = serializer.to_content_type(fault_data) - return self.wrapped_exc diff --git a/nova/api/rackspace/flavors.py b/nova/api/rackspace/flavors.py deleted file mode 100644 index 916449854..000000000 --- a/nova/api/rackspace/flavors.py +++ /dev/null @@ -1,58 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from webob import exc - -from nova.api.rackspace import faults -from nova.compute import instance_types -from nova import wsgi -import nova.api.rackspace - -class Controller(wsgi.Controller): - """Flavor controller for the Rackspace API.""" - - _serialization_metadata = { - 'application/xml': { - "attributes": { - "flavor": [ "id", "name", "ram", "disk" ] - } - } - } - - def index(self, req): - """Return all flavors in brief.""" - return dict(flavors=[dict(id=flavor['id'], name=flavor['name']) - for flavor in self.detail(req)['flavors']]) - - def detail(self, req): - """Return all flavors in detail.""" - items = [self.show(req, id)['flavor'] for id in self._all_ids()] - items = nova.api.rackspace.limited(items, req) - return dict(flavors=items) - - def show(self, req, id): - """Return data about the given flavor id.""" - for name, val in instance_types.INSTANCE_TYPES.iteritems(): - if val['flavorid'] == int(id): - item = dict(ram=val['memory_mb'], disk=val['local_gb'], - id=val['flavorid'], name=name) - return dict(flavor=item) - raise faults.Fault(exc.HTTPNotFound()) - - def _all_ids(self): - """Return the list of all flavorids.""" - return [i['flavorid'] for i in instance_types.INSTANCE_TYPES.values()] diff --git a/nova/api/rackspace/images.py b/nova/api/rackspace/images.py deleted file mode 100644 index 82dcd2049..000000000 --- a/nova/api/rackspace/images.py +++ /dev/null @@ -1,71 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from webob import exc - -from nova import flags -from nova import utils -from nova import wsgi -import nova.api.rackspace -import nova.image.service -from nova.api.rackspace import faults - - -FLAGS = flags.FLAGS - -class Controller(wsgi.Controller): - - _serialization_metadata = { - 'application/xml': { - "attributes": { - "image": [ "id", "name", "updated", "created", "status", - "serverId", "progress" ] - } - } - } - - def __init__(self): - self._service = utils.import_object(FLAGS.image_service) - - def index(self, req): - """Return all public images in brief.""" - return dict(images=[dict(id=img['id'], name=img['name']) - for img in self.detail(req)['images']]) - - def detail(self, req): - """Return all public images in detail.""" - data = self._service.index() - data = nova.api.rackspace.limited(data, req) - return dict(images=data) - - def show(self, req, id): - """Return data about the given image id.""" - return dict(image=self._service.show(id)) - - def delete(self, req, id): - # Only public images are supported for now. - raise faults.Fault(exc.HTTPNotFound()) - - def create(self, req): - # Only public images are supported for now, so a request to - # make a backup of a server cannot be supproted. - raise faults.Fault(exc.HTTPNotFound()) - - def update(self, req, id): - # Users may not modify public images, and that's all that - # we support for now. - raise faults.Fault(exc.HTTPNotFound()) diff --git a/nova/api/rackspace/notes.txt b/nova/api/rackspace/notes.txt deleted file mode 100644 index e133bf5ea..000000000 --- a/nova/api/rackspace/notes.txt +++ /dev/null @@ -1,23 +0,0 @@ -We will need: - -ImageService -a service that can do crud on image information. not user-specific. opaque -image ids. - -GlanceImageService(ImageService): -image ids are URIs. - -LocalImageService(ImageService): -image ids are random strings. - -RackspaceAPITranslationStore: -translates RS server/images/flavor/etc ids into formats required -by a given ImageService strategy. - -api.rackspace.images.Controller: -uses an ImageService strategy behind the scenes to do its fetching; it just -converts int image id into a strategy-specific image id. - -who maintains the mapping from user to [images he owns]? nobody, because -we have no way of enforcing access to his images, without kryptex which -won't be in Austin. diff --git a/nova/api/rackspace/ratelimiting/__init__.py b/nova/api/rackspace/ratelimiting/__init__.py deleted file mode 100644 index f843bac0f..000000000 --- a/nova/api/rackspace/ratelimiting/__init__.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Rate limiting of arbitrary actions.""" - -import httplib -import time -import urllib -import webob.dec -import webob.exc - - -# Convenience constants for the limits dictionary passed to Limiter(). -PER_SECOND = 1 -PER_MINUTE = 60 -PER_HOUR = 60 * 60 -PER_DAY = 60 * 60 * 24 - -class Limiter(object): - - """Class providing rate limiting of arbitrary actions.""" - - def __init__(self, limits): - """Create a rate limiter. - - limits: a dict mapping from action name to a tuple. The tuple contains - the number of times the action may be performed, and the time period - (in seconds) during which the number must not be exceeded for this - action. Example: dict(reboot=(10, ratelimiting.PER_MINUTE)) would - allow 10 'reboot' actions per minute. - """ - self.limits = limits - self._levels = {} - - def perform(self, action_name, username='nobody'): - """Attempt to perform an action by the given username. - - action_name: the string name of the action to perform. This must - be a key in the limits dict passed to the ctor. - - username: an optional string name of the user performing the action. - Each user has her own set of rate limiting counters. Defaults to - 'nobody' (so that if you never specify a username when calling - perform(), a single set of counters will be used.) - - Return None if the action may proceed. If the action may not proceed - because it has been rate limited, return the float number of seconds - until the action would succeed. - """ - # Think of rate limiting as a bucket leaking water at 1cc/second. The - # bucket can hold as many ccs as there are seconds in the rate - # limiting period (e.g. 3600 for per-hour ratelimits), and if you can - # perform N actions in that time, each action fills the bucket by - # 1/Nth of its volume. You may only perform an action if the bucket - # would not overflow. - now = time.time() - key = '%s:%s' % (username, action_name) - last_time_performed, water_level = self._levels.get(key, (now, 0)) - # The bucket leaks 1cc/second. - water_level -= (now - last_time_performed) - if water_level < 0: - water_level = 0 - num_allowed_per_period, period_in_secs = self.limits[action_name] - # Fill the bucket by 1/Nth its capacity, and hope it doesn't overflow. - capacity = period_in_secs - new_level = water_level + (capacity * 1.0 / num_allowed_per_period) - if new_level > capacity: - # Delay this many seconds. - return new_level - capacity - self._levels[key] = (now, new_level) - return None - - -# If one instance of this WSGIApps is unable to handle your load, put a -# sharding app in front that shards by username to one of many backends. - -class WSGIApp(object): - - """Application that tracks rate limits in memory. Send requests to it of - this form: - - POST /limiter// - - and receive a 200 OK, or a 403 Forbidden with an X-Wait-Seconds header - containing the number of seconds to wait before the action would succeed. - """ - - def __init__(self, limiter): - """Create the WSGI application using the given Limiter instance.""" - self.limiter = limiter - - @webob.dec.wsgify - def __call__(self, req): - parts = req.path_info.split('/') - # format: /limiter// - if req.method != 'POST': - raise webob.exc.HTTPMethodNotAllowed() - if len(parts) != 4 or parts[1] != 'limiter': - raise webob.exc.HTTPNotFound() - username = parts[2] - action_name = urllib.unquote(parts[3]) - delay = self.limiter.perform(action_name, username) - if delay: - return webob.exc.HTTPForbidden( - headers={'X-Wait-Seconds': "%.2f" % delay}) - else: - return '' # 200 OK - - -class WSGIAppProxy(object): - - """Limiter lookalike that proxies to a ratelimiting.WSGIApp.""" - - def __init__(self, service_host): - """Creates a proxy pointing to a ratelimiting.WSGIApp at the given - host.""" - self.service_host = service_host - - def perform(self, action, username='nobody'): - conn = httplib.HTTPConnection(self.service_host) - conn.request('POST', '/limiter/%s/%s' % (username, action)) - resp = conn.getresponse() - if resp.status == 200: - return None # no delay - return float(resp.getheader('X-Wait-Seconds')) diff --git a/nova/api/rackspace/ratelimiting/tests.py b/nova/api/rackspace/ratelimiting/tests.py deleted file mode 100644 index 4c9510917..000000000 --- a/nova/api/rackspace/ratelimiting/tests.py +++ /dev/null @@ -1,237 +0,0 @@ -import httplib -import StringIO -import time -import unittest -import webob - -import nova.api.rackspace.ratelimiting as ratelimiting - -class LimiterTest(unittest.TestCase): - - def setUp(self): - self.limits = { - 'a': (5, ratelimiting.PER_SECOND), - 'b': (5, ratelimiting.PER_MINUTE), - 'c': (5, ratelimiting.PER_HOUR), - 'd': (1, ratelimiting.PER_SECOND), - 'e': (100, ratelimiting.PER_SECOND)} - self.rl = ratelimiting.Limiter(self.limits) - - def exhaust(self, action, times_until_exhausted, **kwargs): - for i in range(times_until_exhausted): - when = self.rl.perform(action, **kwargs) - self.assertEqual(when, None) - num, period = self.limits[action] - delay = period * 1.0 / num - # Verify that we are now thoroughly delayed - for i in range(10): - when = self.rl.perform(action, **kwargs) - self.assertAlmostEqual(when, delay, 2) - - def test_second(self): - self.exhaust('a', 5) - time.sleep(0.2) - self.exhaust('a', 1) - time.sleep(1) - self.exhaust('a', 5) - - def test_minute(self): - self.exhaust('b', 5) - - def test_one_per_period(self): - def allow_once_and_deny_once(): - when = self.rl.perform('d') - self.assertEqual(when, None) - when = self.rl.perform('d') - self.assertAlmostEqual(when, 1, 2) - return when - time.sleep(allow_once_and_deny_once()) - time.sleep(allow_once_and_deny_once()) - allow_once_and_deny_once() - - def test_we_can_go_indefinitely_if_we_spread_out_requests(self): - for i in range(200): - when = self.rl.perform('e') - self.assertEqual(when, None) - time.sleep(0.01) - - def test_users_get_separate_buckets(self): - self.exhaust('c', 5, username='alice') - self.exhaust('c', 5, username='bob') - self.exhaust('c', 5, username='chuck') - self.exhaust('c', 0, username='chuck') - self.exhaust('c', 0, username='bob') - self.exhaust('c', 0, username='alice') - - -class FakeLimiter(object): - """Fake Limiter class that you can tell how to behave.""" - def __init__(self, test): - self._action = self._username = self._delay = None - self.test = test - def mock(self, action, username, delay): - self._action = action - self._username = username - self._delay = delay - def perform(self, action, username): - self.test.assertEqual(action, self._action) - self.test.assertEqual(username, self._username) - return self._delay - - -class WSGIAppTest(unittest.TestCase): - - def setUp(self): - self.limiter = FakeLimiter(self) - self.app = ratelimiting.WSGIApp(self.limiter) - - def test_invalid_methods(self): - requests = [] - for method in ['GET', 'PUT', 'DELETE']: - req = webob.Request.blank('/limits/michael/breakdance', - dict(REQUEST_METHOD=method)) - requests.append(req) - for req in requests: - self.assertEqual(req.get_response(self.app).status_int, 405) - - def test_invalid_urls(self): - requests = [] - for prefix in ['limit', '', 'limiter2', 'limiter/limits', 'limiter/1']: - req = webob.Request.blank('/%s/michael/breakdance' % prefix, - dict(REQUEST_METHOD='POST')) - requests.append(req) - for req in requests: - self.assertEqual(req.get_response(self.app).status_int, 404) - - def verify(self, url, username, action, delay=None): - """Make sure that POSTing to the given url causes the given username - to perform the given action. Make the internal rate limiter return - delay and make sure that the WSGI app returns the correct response. - """ - req = webob.Request.blank(url, dict(REQUEST_METHOD='POST')) - self.limiter.mock(action, username, delay) - resp = req.get_response(self.app) - if not delay: - self.assertEqual(resp.status_int, 200) - else: - self.assertEqual(resp.status_int, 403) - self.assertEqual(resp.headers['X-Wait-Seconds'], "%.2f" % delay) - - def test_good_urls(self): - self.verify('/limiter/michael/hoot', 'michael', 'hoot') - - def test_escaping(self): - self.verify('/limiter/michael/jump%20up', 'michael', 'jump up') - - def test_response_to_delays(self): - self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1) - self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1.56) - self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1000) - - -class FakeHttplibSocket(object): - """a fake socket implementation for httplib.HTTPResponse, trivial""" - - def __init__(self, response_string): - self._buffer = StringIO.StringIO(response_string) - - def makefile(self, _mode, _other): - """Returns the socket's internal buffer""" - return self._buffer - - -class FakeHttplibConnection(object): - """A fake httplib.HTTPConnection - - Requests made via this connection actually get translated and routed into - our WSGI app, we then wait for the response and turn it back into - an httplib.HTTPResponse. - """ - def __init__(self, app, host, is_secure=False): - self.app = app - self.host = host - - def request(self, method, path, data='', headers={}): - req = webob.Request.blank(path) - req.method = method - req.body = data - req.headers = headers - req.host = self.host - # Call the WSGI app, get the HTTP response - resp = str(req.get_response(self.app)) - # For some reason, the response doesn't have "HTTP/1.0 " prepended; I - # guess that's a function the web server usually provides. - resp = "HTTP/1.0 %s" % resp - sock = FakeHttplibSocket(resp) - self.http_response = httplib.HTTPResponse(sock) - self.http_response.begin() - - def getresponse(self): - return self.http_response - - -def wire_HTTPConnection_to_WSGI(host, app): - """Monkeypatches HTTPConnection so that if you try to connect to host, you - are instead routed straight to the given WSGI app. - - After calling this method, when any code calls - - httplib.HTTPConnection(host) - - the connection object will be a fake. Its requests will be sent directly - to the given WSGI app rather than through a socket. - - Code connecting to hosts other than host will not be affected. - - This method may be called multiple times to map different hosts to - different apps. - """ - class HTTPConnectionDecorator(object): - """Wraps the real HTTPConnection class so that when you instantiate - the class you might instead get a fake instance.""" - def __init__(self, wrapped): - self.wrapped = wrapped - def __call__(self, connection_host, *args, **kwargs): - if connection_host == host: - return FakeHttplibConnection(app, host) - else: - return self.wrapped(connection_host, *args, **kwargs) - httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) - - -class WSGIAppProxyTest(unittest.TestCase): - - def setUp(self): - """Our WSGIAppProxy is going to call across an HTTPConnection to a - WSGIApp running a limiter. The proxy will send input, and the proxy - should receive that same input, pass it to the limiter who gives a - result, and send the expected result back. - - The HTTPConnection isn't real -- it's monkeypatched to point straight - at the WSGIApp. And the limiter isn't real -- it's a fake that - behaves the way we tell it to. - """ - self.limiter = FakeLimiter(self) - app = ratelimiting.WSGIApp(self.limiter) - wire_HTTPConnection_to_WSGI('100.100.100.100:80', app) - self.proxy = ratelimiting.WSGIAppProxy('100.100.100.100:80') - - def test_200(self): - self.limiter.mock('conquer', 'caesar', None) - when = self.proxy.perform('conquer', 'caesar') - self.assertEqual(when, None) - - def test_403(self): - self.limiter.mock('grumble', 'proletariat', 1.5) - when = self.proxy.perform('grumble', 'proletariat') - self.assertEqual(when, 1.5) - - def test_failure(self): - def shouldRaise(): - self.limiter.mock('murder', 'brutus', None) - self.proxy.perform('stab', 'brutus') - self.assertRaises(AssertionError, shouldRaise) - - -if __name__ == '__main__': - unittest.main() diff --git a/nova/api/rackspace/servers.py b/nova/api/rackspace/servers.py deleted file mode 100644 index 8c489ed83..000000000 --- a/nova/api/rackspace/servers.py +++ /dev/null @@ -1,283 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import time - -import webob -from webob import exc - -from nova import flags -from nova import rpc -from nova import utils -from nova import wsgi -from nova.api import cloud -from nova.api.rackspace import context -from nova.api.rackspace import faults -from nova.compute import instance_types -from nova.compute import power_state -import nova.api.rackspace -import nova.image.service - -FLAGS = flags.FLAGS - -def _image_service(): - """ Helper method for initializing the image id translator """ - service = utils.import_object(FLAGS.image_service) - return (service, _id_translator.RackspaceAPIIdTranslator( - "image", service.__class__.__name__)) - -def _filter_params(inst_dict): - """ Extracts all updatable parameters for a server update request """ - keys = dict(name='name', admin_pass='adminPass') - new_attrs = {} - for k, v in keys.items(): - if inst_dict.has_key(v): - new_attrs[k] = inst_dict[v] - return new_attrs - -def _entity_list(entities): - """ Coerces a list of servers into proper dictionary format """ - return dict(servers=entities) - -def _entity_detail(inst): - """ Maps everything to Rackspace-like attributes for return""" - power_mapping = { - power_state.NOSTATE: 'build', - power_state.RUNNING: 'active', - power_state.BLOCKED: 'active', - power_state.PAUSED: 'suspended', - power_state.SHUTDOWN: 'active', - power_state.SHUTOFF: 'active', - power_state.CRASHED: 'error' - } - inst_dict = {} - - mapped_keys = dict(status='state', imageId='image_id', - flavorId='instance_type', name='server_name', id='id') - - for k, v in mapped_keys.iteritems(): - inst_dict[k] = inst[v] - - inst_dict['status'] = power_mapping[inst_dict['status']] - inst_dict['addresses'] = dict(public=[], private=[]) - inst_dict['metadata'] = {} - inst_dict['hostId'] = '' - - return dict(server=inst_dict) - -def _entity_inst(inst): - """ Filters all model attributes save for id and name """ - return dict(server=dict(id=inst['id'], name=inst['server_name'])) - -class Controller(wsgi.Controller): - """ The Server API controller for the Openstack API """ - - _serialization_metadata = { - 'application/xml': { - "attributes": { - "server": [ "id", "imageId", "name", "flavorId", "hostId", - "status", "progress", "progress" ] - } - } - } - - def __init__(self, db_driver=None): - if not db_driver: - db_driver = FLAGS.db_driver - self.db_driver = utils.import_object(db_driver) - super(Controller, self).__init__() - - def index(self, req): - """ Returns a list of server names and ids for a given user """ - return self._items(req, entity_maker=_entity_inst) - - def detail(self, req): - """ Returns a list of server details for a given user """ - return self._items(req, entity_maker=_entity_detail) - - def _items(self, req, entity_maker): - """Returns a list of servers for a given user. - - entity_maker - either _entity_detail or _entity_inst - """ - user_id = req.environ['nova.context']['user']['id'] - instance_list = self.db_driver.instance_get_all_by_user(None, user_id) - limited_list = nova.api.rackspace.limited(instance_list, req) - res = [entity_maker(inst)['server'] for inst in limited_list] - return _entity_list(res) - - def show(self, req, id): - """ Returns server details by server id """ - user_id = req.environ['nova.context']['user']['id'] - inst = self.db_driver.instance_get_by_internal_id(None, int(id)) - if inst: - if inst.user_id == user_id: - return _entity_detail(inst) - raise faults.Fault(exc.HTTPNotFound()) - - def delete(self, req, id): - """ Destroys a server """ - user_id = req.environ['nova.context']['user']['id'] - instance = self.db_driver.instance_get_by_internal_id(None, int(id)) - if instance and instance['user_id'] == user_id: - self.db_driver.instance_destroy(None, id) - return faults.Fault(exc.HTTPAccepted()) - return faults.Fault(exc.HTTPNotFound()) - - def create(self, req): - """ Creates a new server for a given user """ - - env = self._deserialize(req.body, req) - if not env: - return faults.Fault(exc.HTTPUnprocessableEntity()) - - #try: - inst = self._build_server_instance(req, env) - #except Exception, e: - # return faults.Fault(exc.HTTPUnprocessableEntity()) - - rpc.cast( - FLAGS.compute_topic, { - "method": "run_instance", - "args": {"instance_id": inst['id']}}) - return _entity_inst(inst) - - def update(self, req, id): - """ Updates the server name or password """ - user_id = req.environ['nova.context']['user']['id'] - - inst_dict = self._deserialize(req.body, req) - - if not inst_dict: - return faults.Fault(exc.HTTPUnprocessableEntity()) - - instance = self.db_driver.instance_get_by_internal_id(None, int(id)) - if not instance or instance.user_id != user_id: - return faults.Fault(exc.HTTPNotFound()) - - self.db_driver.instance_update(None, int(id), - _filter_params(inst_dict['server'])) - return faults.Fault(exc.HTTPNoContent()) - - def action(self, req, id): - """ multi-purpose method used to reboot, rebuild, and - resize a server """ - user_id = req.environ['nova.context']['user']['id'] - input_dict = self._deserialize(req.body, req) - try: - reboot_type = input_dict['reboot']['type'] - except Exception: - raise faults.Fault(webob.exc.HTTPNotImplemented()) - inst_ref = self.db.instance_get_by_internal_id(None, int(id)) - if not inst_ref or (inst_ref and not inst_ref.user_id == user_id): - return faults.Fault(exc.HTTPUnprocessableEntity()) - cloud.reboot(id) - - def _build_server_instance(self, req, env): - """Build instance data structure and save it to the data store.""" - ltime = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()) - inst = {} - - user_id = req.environ['nova.context']['user']['id'] - - flavor_id = env['server']['flavorId'] - - instance_type, flavor = [(k, v) for k, v in - instance_types.INSTANCE_TYPES.iteritems() - if v['flavorid'] == flavor_id][0] - - image_id = env['server']['imageId'] - - img_service, image_id_trans = _image_service() - - opaque_image_id = image_id_trans.to_rs_id(image_id) - image = img_service.show(opaque_image_id) - - if not image: - raise Exception, "Image not found" - - inst['server_name'] = env['server']['name'] - inst['image_id'] = opaque_image_id - inst['user_id'] = user_id - inst['launch_time'] = ltime - inst['mac_address'] = utils.generate_mac() - inst['project_id'] = user_id - - inst['state_description'] = 'scheduling' - inst['kernel_id'] = image.get('kernelId', FLAGS.default_kernel) - inst['ramdisk_id'] = image.get('ramdiskId', FLAGS.default_ramdisk) - inst['reservation_id'] = utils.generate_uid('r') - - inst['display_name'] = env['server']['name'] - inst['display_description'] = env['server']['name'] - - #TODO(dietz) this may be ill advised - key_pair_ref = self.db_driver.key_pair_get_all_by_user( - None, user_id)[0] - - inst['key_data'] = key_pair_ref['public_key'] - inst['key_name'] = key_pair_ref['name'] - - #TODO(dietz) stolen from ec2 api, see TODO there - inst['security_group'] = 'default' - - # Flavor related attributes - inst['instance_type'] = instance_type - inst['memory_mb'] = flavor['memory_mb'] - inst['vcpus'] = flavor['vcpus'] - inst['local_gb'] = flavor['local_gb'] - - ref = self.db_driver.instance_create(None, inst) - inst['id'] = ref.internal_id - - # TODO(dietz): this isn't explicitly necessary, but the networking - # calls depend on an object with a project_id property, and therefore - # should be cleaned up later - api_context = context.APIRequestContext(user_id) - - inst['mac_address'] = utils.generate_mac() - - #TODO(dietz) is this necessary? - inst['launch_index'] = 0 - - inst['hostname'] = str(ref.internal_id) - self.db_driver.instance_update(None, inst['id'], inst) - - network_manager = utils.import_object(FLAGS.network_manager) - address = network_manager.allocate_fixed_ip(api_context, - inst['id']) - - # TODO(vish): This probably should be done in the scheduler - # network is setup when host is assigned - network_topic = self._get_network_topic(user_id) - rpc.call(network_topic, - {"method": "setup_fixed_ip", - "args": {"context": None, - "address": address}}) - return inst - - def _get_network_topic(self, user_id): - """Retrieves the network host for a project""" - network_ref = self.db_driver.project_get_network(None, - user_id) - host = network_ref['host'] - if not host: - host = rpc.call(FLAGS.network_topic, - {"method": "set_network_host", - "args": {"context": None, - "project_id": user_id}}) - return self.db_driver.queue_get_for(None, FLAGS.network_topic, host) diff --git a/nova/api/rackspace/sharedipgroups.py b/nova/api/rackspace/sharedipgroups.py deleted file mode 100644 index 4d2d0ede1..000000000 --- a/nova/api/rackspace/sharedipgroups.py +++ /dev/null @@ -1,20 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -from nova import wsgi - -class Controller(wsgi.Controller): pass diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py index ec76aa827..2c7f7fd3e 100644 --- a/nova/tests/api/__init__.py +++ b/nova/tests/api/__init__.py @@ -44,8 +44,8 @@ class Test(unittest.TestCase): req = webob.Request.blank(url, environ_keys) return req.get_response(api.API()) - def test_rackspace(self): - self.stubs.Set(api.rackspace, 'API', APIStub) + def test_openstack(self): + self.stubs.Set(api.openstack, 'API', APIStub) result = self._request('/v1.0/cloud', 'rs') self.assertEqual(result.body, "/cloud") @@ -56,7 +56,7 @@ class Test(unittest.TestCase): def test_not_found(self): self.stubs.Set(api.ec2, 'API', APIStub) - self.stubs.Set(api.rackspace, 'API', APIStub) + self.stubs.Set(api.openstack, 'API', APIStub) result = self._request('/test/cloud', 'ec2') self.assertNotEqual(result.body, "/cloud") diff --git a/nova/tests/api/openstack/__init__.py b/nova/tests/api/openstack/__init__.py new file mode 100644 index 000000000..b534897f5 --- /dev/null +++ b/nova/tests/api/openstack/__init__.py @@ -0,0 +1,108 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from nova.api.openstack import limited +from nova.api.openstack import RateLimitingMiddleware +from nova.tests.api.fakes import APIStub +from webob import Request + + +class RateLimitingMiddlewareTest(unittest.TestCase): + + def test_get_action_name(self): + middleware = RateLimitingMiddleware(APIStub()) + def verify(method, url, action_name): + req = Request.blank(url) + req.method = method + action = middleware.get_action_name(req) + self.assertEqual(action, action_name) + verify('PUT', '/servers/4', 'PUT') + verify('DELETE', '/servers/4', 'DELETE') + verify('POST', '/images/4', 'POST') + verify('POST', '/servers/4', 'POST servers') + verify('GET', '/foo?a=4&changes-since=never&b=5', 'GET changes-since') + verify('GET', '/foo?a=4&monkeys-since=never&b=5', None) + verify('GET', '/servers/4', None) + verify('HEAD', '/servers/4', None) + + def exhaust(self, middleware, method, url, username, times): + req = Request.blank(url, dict(REQUEST_METHOD=method), + headers={'X-Auth-User': username}) + for i in range(times): + resp = req.get_response(middleware) + self.assertEqual(resp.status_int, 200) + resp = req.get_response(middleware) + self.assertEqual(resp.status_int, 413) + self.assertTrue('Retry-After' in resp.headers) + + def test_single_action(self): + middleware = RateLimitingMiddleware(APIStub()) + self.exhaust(middleware, 'DELETE', '/servers/4', 'usr1', 100) + self.exhaust(middleware, 'DELETE', '/servers/4', 'usr2', 100) + + def test_POST_servers_action_implies_POST_action(self): + middleware = RateLimitingMiddleware(APIStub()) + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) + self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10) + self.assertTrue(set(middleware.limiter._levels) == + set(['usr1:POST', 'usr1:POST servers', 'usr2:POST'])) + + def test_POST_servers_action_correctly_ratelimited(self): + middleware = RateLimitingMiddleware(APIStub()) + # Use up all of our "POST" allowance for the minute, 5 times + for i in range(5): + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) + # Reset the 'POST' action counter. + del middleware.limiter._levels['usr1:POST'] + # All 50 daily "POST servers" actions should be all used up + self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 0) + + def test_proxy_ctor_works(self): + middleware = RateLimitingMiddleware(APIStub()) + self.assertEqual(middleware.limiter.__class__.__name__, "Limiter") + middleware = RateLimitingMiddleware(APIStub(), service_host='foobar') + self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy") + + +class LimiterTest(unittest.TestCase): + + def test_limiter(self): + items = range(2000) + req = Request.blank('/') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=3') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=2005') + self.assertEqual(limited(items, req), []) + req = Request.blank('/?limit=10') + self.assertEqual(limited(items, req), items[ :10]) + req = Request.blank('/?limit=0') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?limit=3000') + self.assertEqual(limited(items, req), items[ :1000]) + req = Request.blank('/?offset=1&limit=3') + self.assertEqual(limited(items, req), items[1:4]) + req = Request.blank('/?offset=3&limit=0') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3&limit=1500') + self.assertEqual(limited(items, req), items[3:1003]) + req = Request.blank('/?offset=3000&limit=10') + self.assertEqual(limited(items, req), []) diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py new file mode 100644 index 000000000..1119fa714 --- /dev/null +++ b/nova/tests/api/openstack/fakes.py @@ -0,0 +1,205 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime +import json +import random +import string + +import webob +import webob.dec + +from nova import auth +from nova import utils +from nova import flags +from nova import exception as exc +import nova.api.openstack.auth +from nova.image import service +from nova.wsgi import Router + + +FLAGS = flags.FLAGS + + +class Context(object): + pass + + +class FakeRouter(Router): + def __init__(self): + pass + + @webob.dec.wsgify + def __call__(self, req): + res = webob.Response() + res.status = '200' + res.headers['X-Test-Success'] = 'True' + return res + + +def fake_auth_init(self): + self.db = FakeAuthDatabase() + self.context = Context() + self.auth = FakeAuthManager() + self.host = 'foo' + + +@webob.dec.wsgify +def fake_wsgi(self, req): + req.environ['nova.context'] = dict(user=dict(id=1)) + if req.body: + req.environ['inst_dict'] = json.loads(req.body) + return self.application + + +def stub_out_key_pair_funcs(stubs): + def key_pair(context, user_id): + return [dict(name='key', public_key='public_key')] + stubs.Set(nova.db.api, 'key_pair_get_all_by_user', + key_pair) + + +def stub_out_image_service(stubs): + def fake_image_show(meh, id): + return dict(kernelId=1, ramdiskId=1) + + stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) + +def stub_out_auth(stubs): + def fake_auth_init(self, app): + self.application = app + + stubs.Set(nova.api.openstack.AuthMiddleware, + '__init__', fake_auth_init) + stubs.Set(nova.api.openstack.AuthMiddleware, + '__call__', fake_wsgi) + + +def stub_out_rate_limiting(stubs): + def fake_rate_init(self, app): + super(nova.api.openstack.RateLimitingMiddleware, self).__init__(app) + self.application = app + + stubs.Set(nova.api.openstack.RateLimitingMiddleware, + '__init__', fake_rate_init) + + stubs.Set(nova.api.openstack.RateLimitingMiddleware, + '__call__', fake_wsgi) + + +def stub_out_networking(stubs): + def get_my_ip(): + return '127.0.0.1' + stubs.Set(nova.utils, 'get_my_ip', get_my_ip) + FLAGS.FAKE_subdomain = 'rs' + + +def stub_out_glance(stubs): + + class FakeParallaxClient: + + def __init__(self): + self.fixtures = {} + + def fake_get_images(self): + return self.fixtures + + def fake_get_image_metadata(self, image_id): + for k, f in self.fixtures.iteritems(): + if k == image_id: + return f + return None + + def fake_add_image_metadata(self, image_data): + id = ''.join(random.choice(string.letters) for _ in range(20)) + image_data['id'] = id + self.fixtures[id] = image_data + return id + + def fake_update_image_metadata(self, image_id, image_data): + + if image_id not in self.fixtures.keys(): + raise exc.NotFound + + self.fixtures[image_id].update(image_data) + + def fake_delete_image_metadata(self, image_id): + + if image_id not in self.fixtures.keys(): + raise exc.NotFound + + del self.fixtures[image_id] + + def fake_delete_all(self): + self.fixtures = {} + + fake_parallax_client = FakeParallaxClient() + stubs.Set(nova.image.service.ParallaxClient, 'get_images', + fake_parallax_client.fake_get_images) + stubs.Set(nova.image.service.ParallaxClient, 'get_image_metadata', + fake_parallax_client.fake_get_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'add_image_metadata', + fake_parallax_client.fake_add_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'update_image_metadata', + fake_parallax_client.fake_update_image_metadata) + stubs.Set(nova.image.service.ParallaxClient, 'delete_image_metadata', + fake_parallax_client.fake_delete_image_metadata) + stubs.Set(nova.image.service.GlanceImageService, 'delete_all', + fake_parallax_client.fake_delete_all) + + +class FakeAuthDatabase(object): + data = {} + + @staticmethod + def auth_get_token(context, token_hash): + return FakeAuthDatabase.data.get(token_hash, None) + + @staticmethod + def auth_create_token(context, token): + token['created_at'] = datetime.datetime.now() + FakeAuthDatabase.data[token['token_hash']] = token + + @staticmethod + def auth_destroy_token(context, token): + if FakeAuthDatabase.data.has_key(token['token_hash']): + del FakeAuthDatabase.data['token_hash'] + + +class FakeAuthManager(object): + auth_data = {} + + def add_user(self, key, user): + FakeAuthManager.auth_data[key] = user + + def get_user(self, uid): + for k, v in FakeAuthManager.auth_data.iteritems(): + if v['uid'] == uid: + return v + return None + + def get_user_from_access_key(self, key): + return FakeAuthManager.auth_data.get(key, None) + + +class FakeRateLimiter(object): + def __init__(self, application): + self.application = application + + @webob.dec.wsgify + def __call__(self, req): + return self.application diff --git a/nova/tests/api/openstack/test_auth.py b/nova/tests/api/openstack/test_auth.py new file mode 100644 index 000000000..d2ba80243 --- /dev/null +++ b/nova/tests/api/openstack/test_auth.py @@ -0,0 +1,108 @@ +import datetime +import unittest + +import stubout +import webob +import webob.dec + +import nova.api +import nova.api.openstack.auth +from nova import auth +from nova.tests.api.openstack import fakes + +class Test(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + self.stubs.Set(nova.api.openstack.auth.BasicApiAuthManager, + '__init__', fakes.fake_auth_init) + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_networking(self.stubs) + + def tearDown(self): + self.stubs.UnsetAll() + fakes.fake_data_store = {} + + def test_authorize_user(self): + f = fakes.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + def test_authorize_token(self): + f = fakes.FakeAuthManager() + f.add_user('derp', { 'uid': 1, 'name':'herp' } ) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '204 No Content') + self.assertEqual(len(result.headers['X-Auth-Token']), 40) + self.assertEqual(result.headers['X-Server-Management-Url'], + "https://foo/v1.0/") + self.assertEqual(result.headers['X-CDN-Management-Url'], + "") + self.assertEqual(result.headers['X-Storage-Url'], "") + + token = result.headers['X-Auth-Token'] + self.stubs.Set(nova.api.openstack, 'APIRouter', + fakes.FakeRouter) + req = webob.Request.blank('/v1.0/fake') + req.headers['X-Auth-Token'] = token + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '200 OK') + self.assertEqual(result.headers['X-Test-Success'], 'True') + + def test_token_expiry(self): + self.destroy_called = False + token_hash = 'bacon' + + def destroy_token_mock(meh, context, token): + self.destroy_called = True + + def bad_token(meh, context, token_hash): + return { 'token_hash':token_hash, + 'created_at':datetime.datetime(1990, 1, 1) } + + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token', + destroy_token_mock) + + self.stubs.Set(fakes.FakeAuthDatabase, 'auth_get_token', + bad_token) + + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'bacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + self.assertEqual(self.destroy_called, True) + + def test_bad_user(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-User'] = 'herp' + req.headers['X-Auth-Key'] = 'derp' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_no_user(self): + req = webob.Request.blank('/v1.0/') + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + + def test_bad_token(self): + req = webob.Request.blank('/v1.0/') + req.headers['X-Auth-Token'] = 'baconbaconbacon' + result = req.get_response(nova.api.API()) + self.assertEqual(result.status, '401 Unauthorized') + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/openstack/test_faults.py b/nova/tests/api/openstack/test_faults.py new file mode 100644 index 000000000..70a811469 --- /dev/null +++ b/nova/tests/api/openstack/test_faults.py @@ -0,0 +1,40 @@ +import unittest +import webob +import webob.dec +import webob.exc + +from nova.api.openstack import faults + +class TestFaults(unittest.TestCase): + + def test_fault_parts(self): + req = webob.Request.blank('/.xml') + f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) + resp = req.get_response(f) + + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['']) + body_without_spaces = ''.join(resp.body.split()) + self.assertTrue('scram' in body_without_spaces) + + def test_retry_header(self): + req = webob.Request.blank('/.xml') + exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', + headers={'Retry-After': 4}) + f = faults.Fault(exc) + resp = req.get_response(f) + first_two_words = resp.body.strip().split()[:2] + self.assertEqual(first_two_words, ['']) + body_sans_spaces = ''.join(resp.body.split()) + self.assertTrue('sorry' in body_sans_spaces) + self.assertTrue('4' in body_sans_spaces) + self.assertEqual(resp.headers['Retry-After'], 4) + + def test_raise(self): + @webob.dec.wsgify + def raiser(req): + raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) + req = webob.Request.blank('/.xml') + resp = req.get_response(raiser) + self.assertEqual(resp.status_int, 404) + self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/api/openstack/test_flavors.py b/nova/tests/api/openstack/test_flavors.py new file mode 100644 index 000000000..8dd4d1f29 --- /dev/null +++ b/nova/tests/api/openstack/test_flavors.py @@ -0,0 +1,48 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import stubout +import webob + +import nova.api +from nova.api.openstack import flavors +from nova.tests.api.openstack import fakes + + +class FlavorsTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_networking(self.stubs) + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_auth(self.stubs) + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_flavor_list(self): + req = webob.Request.blank('/v1.0/flavors') + res = req.get_response(nova.api.API()) + + def test_get_flavor_by_id(self): + pass + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/openstack/test_images.py b/nova/tests/api/openstack/test_images.py new file mode 100644 index 000000000..505fea3e2 --- /dev/null +++ b/nova/tests/api/openstack/test_images.py @@ -0,0 +1,141 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import logging +import unittest + +import stubout + +from nova import exception +from nova import utils +from nova.api.openstack import images +from nova.tests.api.openstack import fakes + + +class BaseImageServiceTests(): + + """Tasks to test for all image services""" + + def test_create(self): + + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} + + num_images = len(self.service.index()) + + id = self.service.create(fixture) + + self.assertNotEquals(None, id) + self.assertEquals(num_images + 1, len(self.service.index())) + + def test_create_and_show_non_existing_image(self): + + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} + + num_images = len(self.service.index()) + + id = self.service.create(fixture) + + self.assertNotEquals(None, id) + + self.assertRaises(exception.NotFound, + self.service.show, + 'bad image id') + + def test_update(self): + + fixture = {'name': 'test image', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None} + + id = self.service.create(fixture) + + fixture['status'] = 'in progress' + + self.service.update(id, fixture) + new_image_data = self.service.show(id) + self.assertEquals('in progress', new_image_data['status']) + + def test_delete(self): + + fixtures = [ + {'name': 'test image 1', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}, + {'name': 'test image 2', + 'updated': None, + 'created': None, + 'status': None, + 'serverId': None, + 'progress': None}] + + ids = [] + for fixture in fixtures: + new_id = self.service.create(fixture) + ids.append(new_id) + + num_images = len(self.service.index()) + self.assertEquals(2, num_images) + + self.service.delete(ids[0]) + + num_images = len(self.service.index()) + self.assertEquals(1, num_images) + + +class LocalImageServiceTest(unittest.TestCase, + BaseImageServiceTests): + + """Tests the local image service""" + + def setUp(self): + self.stubs = stubout.StubOutForTesting() + self.service = utils.import_object('nova.image.service.LocalImageService') + + def tearDown(self): + self.service.delete_all() + self.stubs.UnsetAll() + + +class GlanceImageServiceTest(unittest.TestCase, + BaseImageServiceTests): + + """Tests the local image service""" + + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.stub_out_glance(self.stubs) + self.service = utils.import_object('nova.image.service.GlanceImageService') + + def tearDown(self): + self.service.delete_all() + self.stubs.UnsetAll() diff --git a/nova/tests/api/openstack/test_ratelimiting.py b/nova/tests/api/openstack/test_ratelimiting.py new file mode 100644 index 000000000..ad9e67454 --- /dev/null +++ b/nova/tests/api/openstack/test_ratelimiting.py @@ -0,0 +1,237 @@ +import httplib +import StringIO +import time +import unittest +import webob + +import nova.api.openstack.ratelimiting as ratelimiting + +class LimiterTest(unittest.TestCase): + + def setUp(self): + self.limits = { + 'a': (5, ratelimiting.PER_SECOND), + 'b': (5, ratelimiting.PER_MINUTE), + 'c': (5, ratelimiting.PER_HOUR), + 'd': (1, ratelimiting.PER_SECOND), + 'e': (100, ratelimiting.PER_SECOND)} + self.rl = ratelimiting.Limiter(self.limits) + + def exhaust(self, action, times_until_exhausted, **kwargs): + for i in range(times_until_exhausted): + when = self.rl.perform(action, **kwargs) + self.assertEqual(when, None) + num, period = self.limits[action] + delay = period * 1.0 / num + # Verify that we are now thoroughly delayed + for i in range(10): + when = self.rl.perform(action, **kwargs) + self.assertAlmostEqual(when, delay, 2) + + def test_second(self): + self.exhaust('a', 5) + time.sleep(0.2) + self.exhaust('a', 1) + time.sleep(1) + self.exhaust('a', 5) + + def test_minute(self): + self.exhaust('b', 5) + + def test_one_per_period(self): + def allow_once_and_deny_once(): + when = self.rl.perform('d') + self.assertEqual(when, None) + when = self.rl.perform('d') + self.assertAlmostEqual(when, 1, 2) + return when + time.sleep(allow_once_and_deny_once()) + time.sleep(allow_once_and_deny_once()) + allow_once_and_deny_once() + + def test_we_can_go_indefinitely_if_we_spread_out_requests(self): + for i in range(200): + when = self.rl.perform('e') + self.assertEqual(when, None) + time.sleep(0.01) + + def test_users_get_separate_buckets(self): + self.exhaust('c', 5, username='alice') + self.exhaust('c', 5, username='bob') + self.exhaust('c', 5, username='chuck') + self.exhaust('c', 0, username='chuck') + self.exhaust('c', 0, username='bob') + self.exhaust('c', 0, username='alice') + + +class FakeLimiter(object): + """Fake Limiter class that you can tell how to behave.""" + def __init__(self, test): + self._action = self._username = self._delay = None + self.test = test + def mock(self, action, username, delay): + self._action = action + self._username = username + self._delay = delay + def perform(self, action, username): + self.test.assertEqual(action, self._action) + self.test.assertEqual(username, self._username) + return self._delay + + +class WSGIAppTest(unittest.TestCase): + + def setUp(self): + self.limiter = FakeLimiter(self) + self.app = ratelimiting.WSGIApp(self.limiter) + + def test_invalid_methods(self): + requests = [] + for method in ['GET', 'PUT', 'DELETE']: + req = webob.Request.blank('/limits/michael/breakdance', + dict(REQUEST_METHOD=method)) + requests.append(req) + for req in requests: + self.assertEqual(req.get_response(self.app).status_int, 405) + + def test_invalid_urls(self): + requests = [] + for prefix in ['limit', '', 'limiter2', 'limiter/limits', 'limiter/1']: + req = webob.Request.blank('/%s/michael/breakdance' % prefix, + dict(REQUEST_METHOD='POST')) + requests.append(req) + for req in requests: + self.assertEqual(req.get_response(self.app).status_int, 404) + + def verify(self, url, username, action, delay=None): + """Make sure that POSTing to the given url causes the given username + to perform the given action. Make the internal rate limiter return + delay and make sure that the WSGI app returns the correct response. + """ + req = webob.Request.blank(url, dict(REQUEST_METHOD='POST')) + self.limiter.mock(action, username, delay) + resp = req.get_response(self.app) + if not delay: + self.assertEqual(resp.status_int, 200) + else: + self.assertEqual(resp.status_int, 403) + self.assertEqual(resp.headers['X-Wait-Seconds'], "%.2f" % delay) + + def test_good_urls(self): + self.verify('/limiter/michael/hoot', 'michael', 'hoot') + + def test_escaping(self): + self.verify('/limiter/michael/jump%20up', 'michael', 'jump up') + + def test_response_to_delays(self): + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1) + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1.56) + self.verify('/limiter/michael/hoot', 'michael', 'hoot', 1000) + + +class FakeHttplibSocket(object): + """a fake socket implementation for httplib.HTTPResponse, trivial""" + + def __init__(self, response_string): + self._buffer = StringIO.StringIO(response_string) + + def makefile(self, _mode, _other): + """Returns the socket's internal buffer""" + return self._buffer + + +class FakeHttplibConnection(object): + """A fake httplib.HTTPConnection + + Requests made via this connection actually get translated and routed into + our WSGI app, we then wait for the response and turn it back into + an httplib.HTTPResponse. + """ + def __init__(self, app, host, is_secure=False): + self.app = app + self.host = host + + def request(self, method, path, data='', headers={}): + req = webob.Request.blank(path) + req.method = method + req.body = data + req.headers = headers + req.host = self.host + # Call the WSGI app, get the HTTP response + resp = str(req.get_response(self.app)) + # For some reason, the response doesn't have "HTTP/1.0 " prepended; I + # guess that's a function the web server usually provides. + resp = "HTTP/1.0 %s" % resp + sock = FakeHttplibSocket(resp) + self.http_response = httplib.HTTPResponse(sock) + self.http_response.begin() + + def getresponse(self): + return self.http_response + + +def wire_HTTPConnection_to_WSGI(host, app): + """Monkeypatches HTTPConnection so that if you try to connect to host, you + are instead routed straight to the given WSGI app. + + After calling this method, when any code calls + + httplib.HTTPConnection(host) + + the connection object will be a fake. Its requests will be sent directly + to the given WSGI app rather than through a socket. + + Code connecting to hosts other than host will not be affected. + + This method may be called multiple times to map different hosts to + different apps. + """ + class HTTPConnectionDecorator(object): + """Wraps the real HTTPConnection class so that when you instantiate + the class you might instead get a fake instance.""" + def __init__(self, wrapped): + self.wrapped = wrapped + def __call__(self, connection_host, *args, **kwargs): + if connection_host == host: + return FakeHttplibConnection(app, host) + else: + return self.wrapped(connection_host, *args, **kwargs) + httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) + + +class WSGIAppProxyTest(unittest.TestCase): + + def setUp(self): + """Our WSGIAppProxy is going to call across an HTTPConnection to a + WSGIApp running a limiter. The proxy will send input, and the proxy + should receive that same input, pass it to the limiter who gives a + result, and send the expected result back. + + The HTTPConnection isn't real -- it's monkeypatched to point straight + at the WSGIApp. And the limiter isn't real -- it's a fake that + behaves the way we tell it to. + """ + self.limiter = FakeLimiter(self) + app = ratelimiting.WSGIApp(self.limiter) + wire_HTTPConnection_to_WSGI('100.100.100.100:80', app) + self.proxy = ratelimiting.WSGIAppProxy('100.100.100.100:80') + + def test_200(self): + self.limiter.mock('conquer', 'caesar', None) + when = self.proxy.perform('conquer', 'caesar') + self.assertEqual(when, None) + + def test_403(self): + self.limiter.mock('grumble', 'proletariat', 1.5) + when = self.proxy.perform('grumble', 'proletariat') + self.assertEqual(when, 1.5) + + def test_failure(self): + def shouldRaise(): + self.limiter.mock('murder', 'brutus', None) + self.proxy.perform('stab', 'brutus') + self.assertRaises(AssertionError, shouldRaise) + + +if __name__ == '__main__': + unittest.main() diff --git a/nova/tests/api/openstack/test_servers.py b/nova/tests/api/openstack/test_servers.py new file mode 100644 index 000000000..d1ee533b6 --- /dev/null +++ b/nova/tests/api/openstack/test_servers.py @@ -0,0 +1,249 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import json +import unittest + +import stubout +import webob + +from nova import db +from nova import flags +import nova.api.openstack +from nova.api.openstack import servers +import nova.db.api +from nova.db.sqlalchemy.models import Instance +import nova.rpc +from nova.tests.api.openstack import fakes + + +FLAGS = flags.FLAGS + +FLAGS.verbose = True + +def return_server(context, id): + return stub_instance(id) + + +def return_servers(context, user_id=1): + return [stub_instance(i, user_id) for i in xrange(5)] + + +def stub_instance(id, user_id=1): + return Instance( + id=id, state=0, image_id=10, server_name='server%s'%id, + user_id=user_id + ) + + +class ServersTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + fakes.FakeAuthManager.auth_data = {} + fakes.FakeAuthDatabase.data = {} + fakes.stub_out_networking(self.stubs) + fakes.stub_out_rate_limiting(self.stubs) + fakes.stub_out_auth(self.stubs) + fakes.stub_out_key_pair_funcs(self.stubs) + fakes.stub_out_image_service(self.stubs) + self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) + self.stubs.Set(nova.db.api, 'instance_get_by_internal_id', return_server) + self.stubs.Set(nova.db.api, 'instance_get_all_by_user', + return_servers) + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_server_by_id(self): + req = webob.Request.blank('/v1.0/servers/1') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + self.assertEqual(res_dict['server']['id'], 1) + self.assertEqual(res_dict['server']['name'], 'server1') + + def test_get_server_list(self): + req = webob.Request.blank('/v1.0/servers') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s.get('imageId', None), None) + i += 1 + + def test_create_instance(self): + def server_update(context, id, params): + pass + + def instance_create(context, inst): + class Foo(object): + internal_id = 1 + return Foo() + + def fake_method(*args, **kwargs): + pass + + def project_get_network(context, user_id): + return dict(id='1', host='localhost') + + def queue_get_for(context, *args): + return 'network_topic' + + self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) + self.stubs.Set(nova.db.api, 'instance_create', instance_create) + self.stubs.Set(nova.rpc, 'cast', fake_method) + self.stubs.Set(nova.rpc, 'call', fake_method) + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) + self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip', + fake_method) + + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers') + req.method = 'POST' + req.body = json.dumps(body) + + res = req.get_response(nova.api.API()) + + self.assertEqual(res.status_int, 200) + + def test_update_no_body(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status_int, 422) + + def test_update_bad_params(self): + """ Confirm that update is filtering params """ + inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + self.update_called = True + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_update_server(self): + inst_dict = dict(name='server_test', adminPass='bacon') + self.body = json.dumps(dict(server=inst_dict)) + + def server_update(context, id, params): + filtered_dict = dict(name='server_test', admin_pass='bacon') + self.assertEqual(params, filtered_dict) + + self.stubs.Set(nova.db.api, 'instance_update', + server_update) + + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'PUT' + req.body = self.body + req.get_response(nova.api.API()) + + def test_create_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'POST' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_delete_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + req.method = 'DELETE' + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_server_backup_schedules(self): + req = webob.Request.blank('/v1.0/servers/1/backup_schedules') + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '404 Not Found') + + def test_get_all_server_details(self): + req = webob.Request.blank('/v1.0/servers/detail') + res = req.get_response(nova.api.API()) + res_dict = json.loads(res.body) + + i = 0 + for s in res_dict['servers']: + self.assertEqual(s['id'], i) + self.assertEqual(s['name'], 'server%d'%i) + self.assertEqual(s['imageId'], 10) + i += 1 + + def test_server_reboot(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_server_rebuild(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_server_resize(self): + body = dict(server=dict( + name='server_test', imageId=2, flavorId=2, metadata={}, + personality = {} + )) + req = webob.Request.blank('/v1.0/servers/1/action') + req.method = 'POST' + req.content_type= 'application/json' + req.body = json.dumps(body) + res = req.get_response(nova.api.API()) + + def test_delete_server_instance(self): + req = webob.Request.blank('/v1.0/servers/1') + req.method = 'DELETE' + + self.server_delete_called = False + def instance_destroy_mock(context, id): + self.server_delete_called = True + + self.stubs.Set(nova.db.api, 'instance_destroy', + instance_destroy_mock) + + res = req.get_response(nova.api.API()) + self.assertEqual(res.status, '202 Accepted') + self.assertEqual(self.server_delete_called, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/nova/tests/api/openstack/test_sharedipgroups.py b/nova/tests/api/openstack/test_sharedipgroups.py new file mode 100644 index 000000000..d199951d8 --- /dev/null +++ b/nova/tests/api/openstack/test_sharedipgroups.py @@ -0,0 +1,39 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import stubout + +from nova.api.openstack import sharedipgroups + + +class SharedIpGroupsTest(unittest.TestCase): + def setUp(self): + self.stubs = stubout.StubOutForTesting() + + def tearDown(self): + self.stubs.UnsetAll() + + def test_get_shared_ip_groups(self): + pass + + def test_create_shared_ip_group(self): + pass + + def test_delete_shared_ip_group(self): + pass diff --git a/nova/tests/api/rackspace/__init__.py b/nova/tests/api/rackspace/__init__.py deleted file mode 100644 index 1834f91b1..000000000 --- a/nova/tests/api/rackspace/__init__.py +++ /dev/null @@ -1,108 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import unittest - -from nova.api.rackspace import limited -from nova.api.rackspace import RateLimitingMiddleware -from nova.tests.api.fakes import APIStub -from webob import Request - - -class RateLimitingMiddlewareTest(unittest.TestCase): - - def test_get_action_name(self): - middleware = RateLimitingMiddleware(APIStub()) - def verify(method, url, action_name): - req = Request.blank(url) - req.method = method - action = middleware.get_action_name(req) - self.assertEqual(action, action_name) - verify('PUT', '/servers/4', 'PUT') - verify('DELETE', '/servers/4', 'DELETE') - verify('POST', '/images/4', 'POST') - verify('POST', '/servers/4', 'POST servers') - verify('GET', '/foo?a=4&changes-since=never&b=5', 'GET changes-since') - verify('GET', '/foo?a=4&monkeys-since=never&b=5', None) - verify('GET', '/servers/4', None) - verify('HEAD', '/servers/4', None) - - def exhaust(self, middleware, method, url, username, times): - req = Request.blank(url, dict(REQUEST_METHOD=method), - headers={'X-Auth-User': username}) - for i in range(times): - resp = req.get_response(middleware) - self.assertEqual(resp.status_int, 200) - resp = req.get_response(middleware) - self.assertEqual(resp.status_int, 413) - self.assertTrue('Retry-After' in resp.headers) - - def test_single_action(self): - middleware = RateLimitingMiddleware(APIStub()) - self.exhaust(middleware, 'DELETE', '/servers/4', 'usr1', 100) - self.exhaust(middleware, 'DELETE', '/servers/4', 'usr2', 100) - - def test_POST_servers_action_implies_POST_action(self): - middleware = RateLimitingMiddleware(APIStub()) - self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) - self.exhaust(middleware, 'POST', '/images/4', 'usr2', 10) - self.assertTrue(set(middleware.limiter._levels) == - set(['usr1:POST', 'usr1:POST servers', 'usr2:POST'])) - - def test_POST_servers_action_correctly_ratelimited(self): - middleware = RateLimitingMiddleware(APIStub()) - # Use up all of our "POST" allowance for the minute, 5 times - for i in range(5): - self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 10) - # Reset the 'POST' action counter. - del middleware.limiter._levels['usr1:POST'] - # All 50 daily "POST servers" actions should be all used up - self.exhaust(middleware, 'POST', '/servers/4', 'usr1', 0) - - def test_proxy_ctor_works(self): - middleware = RateLimitingMiddleware(APIStub()) - self.assertEqual(middleware.limiter.__class__.__name__, "Limiter") - middleware = RateLimitingMiddleware(APIStub(), service_host='foobar') - self.assertEqual(middleware.limiter.__class__.__name__, "WSGIAppProxy") - - -class LimiterTest(unittest.TestCase): - - def test_limiter(self): - items = range(2000) - req = Request.blank('/') - self.assertEqual(limited(items, req), items[ :1000]) - req = Request.blank('/?offset=0') - self.assertEqual(limited(items, req), items[ :1000]) - req = Request.blank('/?offset=3') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=2005') - self.assertEqual(limited(items, req), []) - req = Request.blank('/?limit=10') - self.assertEqual(limited(items, req), items[ :10]) - req = Request.blank('/?limit=0') - self.assertEqual(limited(items, req), items[ :1000]) - req = Request.blank('/?limit=3000') - self.assertEqual(limited(items, req), items[ :1000]) - req = Request.blank('/?offset=1&limit=3') - self.assertEqual(limited(items, req), items[1:4]) - req = Request.blank('/?offset=3&limit=0') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=3&limit=1500') - self.assertEqual(limited(items, req), items[3:1003]) - req = Request.blank('/?offset=3000&limit=10') - self.assertEqual(limited(items, req), []) diff --git a/nova/tests/api/rackspace/fakes.py b/nova/tests/api/rackspace/fakes.py deleted file mode 100644 index 6a25720a9..000000000 --- a/nova/tests/api/rackspace/fakes.py +++ /dev/null @@ -1,205 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import datetime -import json -import random -import string - -import webob -import webob.dec - -from nova import auth -from nova import utils -from nova import flags -from nova import exception as exc -import nova.api.rackspace.auth -from nova.image import service -from nova.wsgi import Router - - -FLAGS = flags.FLAGS - - -class Context(object): - pass - - -class FakeRouter(Router): - def __init__(self): - pass - - @webob.dec.wsgify - def __call__(self, req): - res = webob.Response() - res.status = '200' - res.headers['X-Test-Success'] = 'True' - return res - - -def fake_auth_init(self): - self.db = FakeAuthDatabase() - self.context = Context() - self.auth = FakeAuthManager() - self.host = 'foo' - - -@webob.dec.wsgify -def fake_wsgi(self, req): - req.environ['nova.context'] = dict(user=dict(id=1)) - if req.body: - req.environ['inst_dict'] = json.loads(req.body) - return self.application - - -def stub_out_key_pair_funcs(stubs): - def key_pair(context, user_id): - return [dict(name='key', public_key='public_key')] - stubs.Set(nova.db.api, 'key_pair_get_all_by_user', - key_pair) - - -def stub_out_image_service(stubs): - def fake_image_show(meh, id): - return dict(kernelId=1, ramdiskId=1) - - stubs.Set(nova.image.service.LocalImageService, 'show', fake_image_show) - -def stub_out_auth(stubs): - def fake_auth_init(self, app): - self.application = app - - stubs.Set(nova.api.rackspace.AuthMiddleware, - '__init__', fake_auth_init) - stubs.Set(nova.api.rackspace.AuthMiddleware, - '__call__', fake_wsgi) - - -def stub_out_rate_limiting(stubs): - def fake_rate_init(self, app): - super(nova.api.rackspace.RateLimitingMiddleware, self).__init__(app) - self.application = app - - stubs.Set(nova.api.rackspace.RateLimitingMiddleware, - '__init__', fake_rate_init) - - stubs.Set(nova.api.rackspace.RateLimitingMiddleware, - '__call__', fake_wsgi) - - -def stub_out_networking(stubs): - def get_my_ip(): - return '127.0.0.1' - stubs.Set(nova.utils, 'get_my_ip', get_my_ip) - FLAGS.FAKE_subdomain = 'rs' - - -def stub_out_glance(stubs): - - class FakeParallaxClient: - - def __init__(self): - self.fixtures = {} - - def fake_get_images(self): - return self.fixtures - - def fake_get_image_metadata(self, image_id): - for k, f in self.fixtures.iteritems(): - if k == image_id: - return f - return None - - def fake_add_image_metadata(self, image_data): - id = ''.join(random.choice(string.letters) for _ in range(20)) - image_data['id'] = id - self.fixtures[id] = image_data - return id - - def fake_update_image_metadata(self, image_id, image_data): - - if image_id not in self.fixtures.keys(): - raise exc.NotFound - - self.fixtures[image_id].update(image_data) - - def fake_delete_image_metadata(self, image_id): - - if image_id not in self.fixtures.keys(): - raise exc.NotFound - - del self.fixtures[image_id] - - def fake_delete_all(self): - self.fixtures = {} - - fake_parallax_client = FakeParallaxClient() - stubs.Set(nova.image.service.ParallaxClient, 'get_images', - fake_parallax_client.fake_get_images) - stubs.Set(nova.image.service.ParallaxClient, 'get_image_metadata', - fake_parallax_client.fake_get_image_metadata) - stubs.Set(nova.image.service.ParallaxClient, 'add_image_metadata', - fake_parallax_client.fake_add_image_metadata) - stubs.Set(nova.image.service.ParallaxClient, 'update_image_metadata', - fake_parallax_client.fake_update_image_metadata) - stubs.Set(nova.image.service.ParallaxClient, 'delete_image_metadata', - fake_parallax_client.fake_delete_image_metadata) - stubs.Set(nova.image.service.GlanceImageService, 'delete_all', - fake_parallax_client.fake_delete_all) - - -class FakeAuthDatabase(object): - data = {} - - @staticmethod - def auth_get_token(context, token_hash): - return FakeAuthDatabase.data.get(token_hash, None) - - @staticmethod - def auth_create_token(context, token): - token['created_at'] = datetime.datetime.now() - FakeAuthDatabase.data[token['token_hash']] = token - - @staticmethod - def auth_destroy_token(context, token): - if FakeAuthDatabase.data.has_key(token['token_hash']): - del FakeAuthDatabase.data['token_hash'] - - -class FakeAuthManager(object): - auth_data = {} - - def add_user(self, key, user): - FakeAuthManager.auth_data[key] = user - - def get_user(self, uid): - for k, v in FakeAuthManager.auth_data.iteritems(): - if v['uid'] == uid: - return v - return None - - def get_user_from_access_key(self, key): - return FakeAuthManager.auth_data.get(key, None) - - -class FakeRateLimiter(object): - def __init__(self, application): - self.application = application - - @webob.dec.wsgify - def __call__(self, req): - return self.application diff --git a/nova/tests/api/rackspace/test_auth.py b/nova/tests/api/rackspace/test_auth.py deleted file mode 100644 index 374cfe42b..000000000 --- a/nova/tests/api/rackspace/test_auth.py +++ /dev/null @@ -1,108 +0,0 @@ -import datetime -import unittest - -import stubout -import webob -import webob.dec - -import nova.api -import nova.api.rackspace.auth -from nova import auth -from nova.tests.api.rackspace import fakes - -class Test(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - self.stubs.Set(nova.api.rackspace.auth.BasicApiAuthManager, - '__init__', fakes.fake_auth_init) - fakes.FakeAuthManager.auth_data = {} - fakes.FakeAuthDatabase.data = {} - fakes.stub_out_rate_limiting(self.stubs) - fakes.stub_out_networking(self.stubs) - - def tearDown(self): - self.stubs.UnsetAll() - fakes.fake_data_store = {} - - def test_authorize_user(self): - f = fakes.FakeAuthManager() - f.add_user('derp', { 'uid': 1, 'name':'herp' } ) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '204 No Content') - self.assertEqual(len(result.headers['X-Auth-Token']), 40) - self.assertEqual(result.headers['X-CDN-Management-Url'], - "") - self.assertEqual(result.headers['X-Storage-Url'], "") - - def test_authorize_token(self): - f = fakes.FakeAuthManager() - f.add_user('derp', { 'uid': 1, 'name':'herp' } ) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '204 No Content') - self.assertEqual(len(result.headers['X-Auth-Token']), 40) - self.assertEqual(result.headers['X-Server-Management-Url'], - "https://foo/v1.0/") - self.assertEqual(result.headers['X-CDN-Management-Url'], - "") - self.assertEqual(result.headers['X-Storage-Url'], "") - - token = result.headers['X-Auth-Token'] - self.stubs.Set(nova.api.rackspace, 'APIRouter', - fakes.FakeRouter) - req = webob.Request.blank('/v1.0/fake') - req.headers['X-Auth-Token'] = token - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '200 OK') - self.assertEqual(result.headers['X-Test-Success'], 'True') - - def test_token_expiry(self): - self.destroy_called = False - token_hash = 'bacon' - - def destroy_token_mock(meh, context, token): - self.destroy_called = True - - def bad_token(meh, context, token_hash): - return { 'token_hash':token_hash, - 'created_at':datetime.datetime(1990, 1, 1) } - - self.stubs.Set(fakes.FakeAuthDatabase, 'auth_destroy_token', - destroy_token_mock) - - self.stubs.Set(fakes.FakeAuthDatabase, 'auth_get_token', - bad_token) - - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-Token'] = 'bacon' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - self.assertEqual(self.destroy_called, True) - - def test_bad_user(self): - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-User'] = 'herp' - req.headers['X-Auth-Key'] = 'derp' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - - def test_no_user(self): - req = webob.Request.blank('/v1.0/') - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - - def test_bad_token(self): - req = webob.Request.blank('/v1.0/') - req.headers['X-Auth-Token'] = 'baconbaconbacon' - result = req.get_response(nova.api.API()) - self.assertEqual(result.status, '401 Unauthorized') - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/rackspace/test_faults.py b/nova/tests/api/rackspace/test_faults.py deleted file mode 100644 index b2931bc98..000000000 --- a/nova/tests/api/rackspace/test_faults.py +++ /dev/null @@ -1,40 +0,0 @@ -import unittest -import webob -import webob.dec -import webob.exc - -from nova.api.rackspace import faults - -class TestFaults(unittest.TestCase): - - def test_fault_parts(self): - req = webob.Request.blank('/.xml') - f = faults.Fault(webob.exc.HTTPBadRequest(explanation='scram')) - resp = req.get_response(f) - - first_two_words = resp.body.strip().split()[:2] - self.assertEqual(first_two_words, ['']) - body_without_spaces = ''.join(resp.body.split()) - self.assertTrue('scram' in body_without_spaces) - - def test_retry_header(self): - req = webob.Request.blank('/.xml') - exc = webob.exc.HTTPRequestEntityTooLarge(explanation='sorry', - headers={'Retry-After': 4}) - f = faults.Fault(exc) - resp = req.get_response(f) - first_two_words = resp.body.strip().split()[:2] - self.assertEqual(first_two_words, ['']) - body_sans_spaces = ''.join(resp.body.split()) - self.assertTrue('sorry' in body_sans_spaces) - self.assertTrue('4' in body_sans_spaces) - self.assertEqual(resp.headers['Retry-After'], 4) - - def test_raise(self): - @webob.dec.wsgify - def raiser(req): - raise faults.Fault(webob.exc.HTTPNotFound(explanation='whut?')) - req = webob.Request.blank('/.xml') - resp = req.get_response(raiser) - self.assertEqual(resp.status_int, 404) - self.assertTrue('whut?' in resp.body) diff --git a/nova/tests/api/rackspace/test_flavors.py b/nova/tests/api/rackspace/test_flavors.py deleted file mode 100644 index affdd2406..000000000 --- a/nova/tests/api/rackspace/test_flavors.py +++ /dev/null @@ -1,48 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import unittest - -import stubout -import webob - -import nova.api -from nova.api.rackspace import flavors -from nova.tests.api.rackspace import fakes - - -class FlavorsTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - fakes.FakeAuthManager.auth_data = {} - fakes.FakeAuthDatabase.data = {} - fakes.stub_out_networking(self.stubs) - fakes.stub_out_rate_limiting(self.stubs) - fakes.stub_out_auth(self.stubs) - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_flavor_list(self): - req = webob.Request.blank('/v1.0/flavors') - res = req.get_response(nova.api.API()) - - def test_get_flavor_by_id(self): - pass - -if __name__ == '__main__': - unittest.main() diff --git a/nova/tests/api/rackspace/test_images.py b/nova/tests/api/rackspace/test_images.py deleted file mode 100644 index a7f320b46..000000000 --- a/nova/tests/api/rackspace/test_images.py +++ /dev/null @@ -1,141 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import logging -import unittest - -import stubout - -from nova import exception -from nova import utils -from nova.api.rackspace import images -from nova.tests.api.rackspace import fakes - - -class BaseImageServiceTests(): - - """Tasks to test for all image services""" - - def test_create(self): - - fixture = {'name': 'test image', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None} - - num_images = len(self.service.index()) - - id = self.service.create(fixture) - - self.assertNotEquals(None, id) - self.assertEquals(num_images + 1, len(self.service.index())) - - def test_create_and_show_non_existing_image(self): - - fixture = {'name': 'test image', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None} - - num_images = len(self.service.index()) - - id = self.service.create(fixture) - - self.assertNotEquals(None, id) - - self.assertRaises(exception.NotFound, - self.service.show, - 'bad image id') - - def test_update(self): - - fixture = {'name': 'test image', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None} - - id = self.service.create(fixture) - - fixture['status'] = 'in progress' - - self.service.update(id, fixture) - new_image_data = self.service.show(id) - self.assertEquals('in progress', new_image_data['status']) - - def test_delete(self): - - fixtures = [ - {'name': 'test image 1', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None}, - {'name': 'test image 2', - 'updated': None, - 'created': None, - 'status': None, - 'serverId': None, - 'progress': None}] - - ids = [] - for fixture in fixtures: - new_id = self.service.create(fixture) - ids.append(new_id) - - num_images = len(self.service.index()) - self.assertEquals(2, num_images) - - self.service.delete(ids[0]) - - num_images = len(self.service.index()) - self.assertEquals(1, num_images) - - -class LocalImageServiceTest(unittest.TestCase, - BaseImageServiceTests): - - """Tests the local image service""" - - def setUp(self): - self.stubs = stubout.StubOutForTesting() - self.service = utils.import_object('nova.image.service.LocalImageService') - - def tearDown(self): - self.service.delete_all() - self.stubs.UnsetAll() - - -class GlanceImageServiceTest(unittest.TestCase, - BaseImageServiceTests): - - """Tests the local image service""" - - def setUp(self): - self.stubs = stubout.StubOutForTesting() - fakes.stub_out_glance(self.stubs) - self.service = utils.import_object('nova.image.service.GlanceImageService') - - def tearDown(self): - self.service.delete_all() - self.stubs.UnsetAll() diff --git a/nova/tests/api/rackspace/test_servers.py b/nova/tests/api/rackspace/test_servers.py deleted file mode 100644 index 57040621b..000000000 --- a/nova/tests/api/rackspace/test_servers.py +++ /dev/null @@ -1,249 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import json -import unittest - -import stubout -import webob - -from nova import db -from nova import flags -import nova.api.rackspace -from nova.api.rackspace import servers -import nova.db.api -from nova.db.sqlalchemy.models import Instance -import nova.rpc -from nova.tests.api.rackspace import fakes - - -FLAGS = flags.FLAGS - -FLAGS.verbose = True - -def return_server(context, id): - return stub_instance(id) - - -def return_servers(context, user_id=1): - return [stub_instance(i, user_id) for i in xrange(5)] - - -def stub_instance(id, user_id=1): - return Instance( - id=id, state=0, image_id=10, server_name='server%s'%id, - user_id=user_id - ) - - -class ServersTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - fakes.FakeAuthManager.auth_data = {} - fakes.FakeAuthDatabase.data = {} - fakes.stub_out_networking(self.stubs) - fakes.stub_out_rate_limiting(self.stubs) - fakes.stub_out_auth(self.stubs) - fakes.stub_out_key_pair_funcs(self.stubs) - fakes.stub_out_image_service(self.stubs) - self.stubs.Set(nova.db.api, 'instance_get_all', return_servers) - self.stubs.Set(nova.db.api, 'instance_get_by_internal_id', return_server) - self.stubs.Set(nova.db.api, 'instance_get_all_by_user', - return_servers) - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_server_by_id(self): - req = webob.Request.blank('/v1.0/servers/1') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - self.assertEqual(res_dict['server']['id'], 1) - self.assertEqual(res_dict['server']['name'], 'server1') - - def test_get_server_list(self): - req = webob.Request.blank('/v1.0/servers') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - - i = 0 - for s in res_dict['servers']: - self.assertEqual(s['id'], i) - self.assertEqual(s['name'], 'server%d'%i) - self.assertEqual(s.get('imageId', None), None) - i += 1 - - def test_create_instance(self): - def server_update(context, id, params): - pass - - def instance_create(context, inst): - class Foo(object): - internal_id = 1 - return Foo() - - def fake_method(*args, **kwargs): - pass - - def project_get_network(context, user_id): - return dict(id='1', host='localhost') - - def queue_get_for(context, *args): - return 'network_topic' - - self.stubs.Set(nova.db.api, 'project_get_network', project_get_network) - self.stubs.Set(nova.db.api, 'instance_create', instance_create) - self.stubs.Set(nova.rpc, 'cast', fake_method) - self.stubs.Set(nova.rpc, 'call', fake_method) - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - self.stubs.Set(nova.db.api, 'queue_get_for', queue_get_for) - self.stubs.Set(nova.network.manager.VlanManager, 'allocate_fixed_ip', - fake_method) - - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers') - req.method = 'POST' - req.body = json.dumps(body) - - res = req.get_response(nova.api.API()) - - self.assertEqual(res.status_int, 200) - - def test_update_no_body(self): - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status_int, 422) - - def test_update_bad_params(self): - """ Confirm that update is filtering params """ - inst_dict = dict(cat='leopard', name='server_test', adminPass='bacon') - self.body = json.dumps(dict(server=inst_dict)) - - def server_update(context, id, params): - self.update_called = True - filtered_dict = dict(name='server_test', admin_pass='bacon') - self.assertEqual(params, filtered_dict) - - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - req.body = self.body - req.get_response(nova.api.API()) - - def test_update_server(self): - inst_dict = dict(name='server_test', adminPass='bacon') - self.body = json.dumps(dict(server=inst_dict)) - - def server_update(context, id, params): - filtered_dict = dict(name='server_test', admin_pass='bacon') - self.assertEqual(params, filtered_dict) - - self.stubs.Set(nova.db.api, 'instance_update', - server_update) - - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'PUT' - req.body = self.body - req.get_response(nova.api.API()) - - def test_create_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - req.method = 'POST' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_delete_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - req.method = 'DELETE' - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_get_server_backup_schedules(self): - req = webob.Request.blank('/v1.0/servers/1/backup_schedules') - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '404 Not Found') - - def test_get_all_server_details(self): - req = webob.Request.blank('/v1.0/servers/detail') - res = req.get_response(nova.api.API()) - res_dict = json.loads(res.body) - - i = 0 - for s in res_dict['servers']: - self.assertEqual(s['id'], i) - self.assertEqual(s['name'], 'server%d'%i) - self.assertEqual(s['imageId'], 10) - i += 1 - - def test_server_reboot(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_server_rebuild(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_server_resize(self): - body = dict(server=dict( - name='server_test', imageId=2, flavorId=2, metadata={}, - personality = {} - )) - req = webob.Request.blank('/v1.0/servers/1/action') - req.method = 'POST' - req.content_type= 'application/json' - req.body = json.dumps(body) - res = req.get_response(nova.api.API()) - - def test_delete_server_instance(self): - req = webob.Request.blank('/v1.0/servers/1') - req.method = 'DELETE' - - self.server_delete_called = False - def instance_destroy_mock(context, id): - self.server_delete_called = True - - self.stubs.Set(nova.db.api, 'instance_destroy', - instance_destroy_mock) - - res = req.get_response(nova.api.API()) - self.assertEqual(res.status, '202 Accepted') - self.assertEqual(self.server_delete_called, True) - - -if __name__ == "__main__": - unittest.main() diff --git a/nova/tests/api/rackspace/test_sharedipgroups.py b/nova/tests/api/rackspace/test_sharedipgroups.py deleted file mode 100644 index 31ce967d0..000000000 --- a/nova/tests/api/rackspace/test_sharedipgroups.py +++ /dev/null @@ -1,39 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2010 OpenStack LLC. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import unittest - -import stubout - -from nova.api.rackspace import sharedipgroups - - -class SharedIpGroupsTest(unittest.TestCase): - def setUp(self): - self.stubs = stubout.StubOutForTesting() - - def tearDown(self): - self.stubs.UnsetAll() - - def test_get_shared_ip_groups(self): - pass - - def test_create_shared_ip_group(self): - pass - - def test_delete_shared_ip_group(self): - pass -- cgit From c1190d55e130a80ac831ce15e6e30c28c5621aff Mon Sep 17 00:00:00 2001 From: mdietz Date: Fri, 8 Oct 2010 21:08:48 +0000 Subject: That's what I get for not using a good vimrc --- nova/api/openstack/servers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nova/api/openstack/servers.py b/nova/api/openstack/servers.py index f234af7de..5d1ed9822 100644 --- a/nova/api/openstack/servers.py +++ b/nova/api/openstack/servers.py @@ -170,14 +170,14 @@ class Controller(wsgi.Controller): def action(self, req, id): """ multi-purpose method used to reboot, rebuild, and resize a server """ - user_id = req.environ['nova.context']['user']['id'] + user_id = req.environ['nova.context']['user']['id'] input_dict = self._deserialize(req.body, req) try: reboot_type = input_dict['reboot']['type'] except Exception: raise faults.Fault(webob.exc.HTTPNotImplemented()) - inst_ref = self.db.instance_get_by_internal_id(None, int(id)) - if not inst_ref or (inst_ref and not inst_ref.user_id == user_id): + inst_ref = self.db.instance_get_by_internal_id(None, int(id)) + if not inst_ref or (inst_ref and not inst_ref.user_id == user_id): return faults.Fault(exc.HTTPUnprocessableEntity()) cloud.reboot(id) -- cgit From f447e1a3a2234e0ab3a5e281442659626f8d99bd Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 11 Oct 2010 13:39:33 +0200 Subject: Rename ec2 get_console_output's instance ID argument to 'instance_id'. It's passed as a kwarg, based on key in the http query, so it must be named this way. --- nova/api/ec2/cloud.py | 6 +++--- nova/api/ec2/images.py | 3 +++ nova/fakerabbit.py | 14 ++++++++++++++ nova/rpc.py | 9 +++++++++ nova/tests/cloud_unittest.py | 31 ++++++++++++++++++++----------- 5 files changed, 49 insertions(+), 14 deletions(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 175bb493c..11e54d2b5 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -258,9 +258,9 @@ class CloudController(object): def delete_security_group(self, context, group_name, **kwargs): return True - def get_console_output(self, context, ec2_id_list, **kwargs): - # ec2_id_list is passed in as a list of instances - ec2_id = ec2_id_list[0] + def get_console_output(self, context, instance_id, **kwargs): + # instance_id is passed in as a list of instances + ec2_id = instance_id[0] internal_id = ec2_id_to_internal_id(ec2_id) instance_ref = db.instance_get_by_internal_id(context, internal_id) return rpc.call('%s.%s' % (FLAGS.compute_topic, diff --git a/nova/api/ec2/images.py b/nova/api/ec2/images.py index cb54cdda2..f0a43dad6 100644 --- a/nova/api/ec2/images.py +++ b/nova/api/ec2/images.py @@ -69,6 +69,9 @@ def list(context, filter_list=[]): optionally filtered by a list of image_id """ + if FLAGS.connection_type == 'fake': + return [{ 'imageId' : 'bar'}] + # FIXME: send along the list of only_images to check for response = conn(context).make_request( method='GET', diff --git a/nova/fakerabbit.py b/nova/fakerabbit.py index 068025249..835973810 100644 --- a/nova/fakerabbit.py +++ b/nova/fakerabbit.py @@ -22,6 +22,7 @@ import logging import Queue as queue from carrot.backends import base +from eventlet import greenthread class Message(base.BaseMessage): @@ -38,6 +39,7 @@ class Exchange(object): def publish(self, message, routing_key=None): logging.debug('(%s) publish (key: %s) %s', self.name, routing_key, message) + routing_key = routing_key.split('.')[0] if routing_key in self._routes: for f in self._routes[routing_key]: logging.debug('Publishing to route %s', f) @@ -94,6 +96,18 @@ class Backend(object): self._exchanges[exchange].bind(self._queues[queue].push, routing_key) + def declare_consumer(self, queue, callback, *args, **kwargs): + self.current_queue = queue + self.current_callback = callback + + def consume(self, *args, **kwargs): + while True: + item = self.get(self.current_queue) + if item: + self.current_callback(item) + raise StopIteration() + greenthread.sleep(0) + def get(self, queue, no_ack=False): if not queue in self._queues or not self._queues[queue].size(): return None diff --git a/nova/rpc.py b/nova/rpc.py index fe52ad35f..447ad3b93 100644 --- a/nova/rpc.py +++ b/nova/rpc.py @@ -28,6 +28,7 @@ import uuid from carrot import connection as carrot_connection from carrot import messaging +from eventlet import greenthread from twisted.internet import defer from twisted.internet import task @@ -107,6 +108,14 @@ class Consumer(messaging.Consumer): logging.exception("Failed to fetch message from queue") self.failed_connection = True + def attach_to_eventlet(self): + """Only needed for unit tests!""" + def fetch_repeatedly(): + while True: + self.fetch(enable_callbacks=True) + greenthread.sleep(0.1) + greenthread.spawn(fetch_repeatedly) + def attach_to_twisted(self): """Attach a callback to twisted that fires 10 times a second""" loop = task.LoopingCall(self.fetch, enable_callbacks=True) diff --git a/nova/tests/cloud_unittest.py b/nova/tests/cloud_unittest.py index 615e589cf..8e5881edb 100644 --- a/nova/tests/cloud_unittest.py +++ b/nova/tests/cloud_unittest.py @@ -16,6 +16,7 @@ # License for the specific language governing permissions and limitations # under the License. +from base64 import b64decode import json import logging from M2Crypto import BIO @@ -63,11 +64,17 @@ class CloudTestCase(test.TrialTestCase): self.cloud = cloud.CloudController() # set up a service - self.compute = utils.import_class(FLAGS.compute_manager) + self.compute = utils.import_class(FLAGS.compute_manager)() self.compute_consumer = rpc.AdapterConsumer(connection=self.conn, topic=FLAGS.compute_topic, proxy=self.compute) - self.compute_consumer.attach_to_twisted() + self.compute_consumer.attach_to_eventlet() + self.network = utils.import_class(FLAGS.network_manager)() + self.network_consumer = rpc.AdapterConsumer(connection=self.conn, + topic=FLAGS.network_topic, + proxy=self.network) + self.network_consumer.attach_to_eventlet() + self.manager = manager.AuthManager() self.user = self.manager.create_user('admin', 'admin', 'admin', True) @@ -85,15 +92,17 @@ class CloudTestCase(test.TrialTestCase): return cloud._gen_key(self.context, self.context.user.id, name) def test_console_output(self): - if FLAGS.connection_type == 'fake': - logging.debug("Can't test instances without a real virtual env.") - return - instance_id = 'foo' - inst = yield self.compute.run_instance(instance_id) - output = yield self.cloud.get_console_output(self.context, [instance_id]) - logging.debug(output) - self.assert_(output) - rv = yield self.compute.terminate_instance(instance_id) + image_id = FLAGS.default_image + instance_type = FLAGS.default_instance_type + max_count = 1 + kwargs = {'image_id': image_id, + 'instance_type': instance_type, + 'max_count': max_count } + rv = yield self.cloud.run_instances(self.context, **kwargs) + instance_id = rv['instancesSet'][0]['instanceId'] + output = yield self.cloud.get_console_output(context=self.context, instance_id=[instance_id]) + self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT') + rv = yield self.cloud.terminate_instances(self.context, [instance_id]) def test_key_generation(self): -- cgit From 3ca549942e96e4ff769e914f227919f3a4d98686 Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Mon, 11 Oct 2010 14:09:24 +0200 Subject: If machine manifest includes a kernel and/or ramdisk id, include it in the image's metadata. --- nova/objectstore/image.py | 10 ++++++-- nova/tests/bundle/1mb.manifest.xml | 2 +- .../bundle/1mb.no_kernel_or_ramdisk.manifest.xml | 1 + nova/tests/objectstore_unittest.py | 29 ++++++++++++++++++---- 4 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml diff --git a/nova/objectstore/image.py b/nova/objectstore/image.py index def1b8167..c01b041bb 100644 --- a/nova/objectstore/image.py +++ b/nova/objectstore/image.py @@ -191,14 +191,14 @@ class Image(object): if kernel_id == 'true': image_type = 'kernel' except: - pass + kernel_id = None try: ramdisk_id = manifest.find("machine_configuration/ramdisk_id").text if ramdisk_id == 'true': image_type = 'ramdisk' except: - pass + ramdisk_id = None info = { 'imageId': image_id, @@ -209,6 +209,12 @@ class Image(object): 'imageType' : image_type } + if kernel_id: + info['kernelId'] = kernel_id + + if ramdisk_id: + info['ramdiskId'] = ramdisk_id + def write_state(state): info['imageState'] = state with open(os.path.join(image_path, 'info.json'), "w") as f: diff --git a/nova/tests/bundle/1mb.manifest.xml b/nova/tests/bundle/1mb.manifest.xml index dc3315957..01648a544 100644 --- a/nova/tests/bundle/1mb.manifest.xml +++ b/nova/tests/bundle/1mb.manifest.xml @@ -1 +1 @@ -2007-10-10euca-tools1.231337x86_641mb42machineda39a3ee5e6b4b0d3255bfef95601890afd807091048576113633a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d99154c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa302314bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb38f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b1mb.part.0c4413423cf7a57e71187e19bfd5cd4b514a642831mb.part.19d4262e6589393d09a11a0332af169887bc2e57d4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e \ No newline at end of file +2007-10-10euca-tools1.231337x86_64aki-testari-test1mb42machineda39a3ee5e6b4b0d3255bfef95601890afd807091048576113633a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d99154c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa302314bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb38f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b1mb.part.0c4413423cf7a57e71187e19bfd5cd4b514a642831mb.part.19d4262e6589393d09a11a0332af169887bc2e57d4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e diff --git a/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml b/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml new file mode 100644 index 000000000..73d7ace00 --- /dev/null +++ b/nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml @@ -0,0 +1 @@ +2007-10-10euca-tools1.231337x86_641mb42machineda39a3ee5e6b4b0d3255bfef95601890afd807091048576113633a2ea00dc64083dd9a10eb5e233635b42a7beb1670ab75452087d9de74c60aba1cd27c136fda56f62beb581de128fb1f10d072b9e556fd25e903107a57827c21f6ee8a93a4ff55b11311fcef217e3eefb07e81f71e88216f43b4b54029c1f2549f2925a839a73947d2d5aeecec4a62ece4af9156d557ae907978298296d99154c11147fd8caf92447e90ce339928933d7579244c2f8ffb07cc0ea35f8738da8b90eff6c7a49671a84500e993e9462e4c36d5c19c0b3a2b397d035b4c0cce742b58e12552175d81d129b0425e9f71ebacb9aeb539fa9dd2ac36749fb82876f6902e5fb24b6ec19f35ec4c20acd50437fd30966e99c4d9a0647577970a8fa302314bd082c9715f071160c69bbfb070f51d2ba1076775f1d988ccde150e515088156b248e4b5a64e46c4fe064feeeedfe14511f7fde478a51acb89f9b2f6c84b60593e5c3f792ba6b01fed9bf2158fdac03086374883b39d13a3ca74497eeaaf579fc3f26effc73bfd9446a2a8c4061f0874bfaca058905180e22d3d8881551cb38f7606f19f00e4e19535dd234b66b31b77e9c7bad3885d9c9efa75c863631fd4f82a009e17d789066d9cc6032a436f05384832f6d9a3283d3e63eab04fa0da5c8c87db9b17e854e842c3fb416507d067a266b44538125ce732e486098e8ebd1ca91fa3079f007fce7d14957a9b7e57282407ead3c6eb68fe975df3d83190021b1mb.part.0c4413423cf7a57e71187e19bfd5cd4b514a642831mb.part.19d4262e6589393d09a11a0332af169887bc2e57d4e00b5ba28114dda4a9df7eeae94be847ec46117a09a1cbe41e578660642f0660dda1776b39fb3bf826b6cfec019e2a5e9c566728d186b7400ebc989a30670eb1db26ce01e68bd9d3f31290370077a85b81c66b63c1e0d5499bac115c06c17a21a81b6d3a67ebbce6c17019095af7ab07f3796c708cc843e58efc12ddc788c5e diff --git a/nova/tests/objectstore_unittest.py b/nova/tests/objectstore_unittest.py index 5a599ff3a..eb2ee0406 100644 --- a/nova/tests/objectstore_unittest.py +++ b/nova/tests/objectstore_unittest.py @@ -133,13 +133,22 @@ class ObjectStoreTestCase(test.TrialTestCase): self.assertRaises(NotFound, objectstore.bucket.Bucket, 'new_bucket') def test_images(self): + self.do_test_images('1mb.manifest.xml', True, + 'image_bucket1', 'i-testing1') + + def test_images_no_kernel_or_ramdisk(self): + self.do_test_images('1mb.no_kernel_or_ramdisk.manifest.xml', + False, 'image_bucket2', 'i-testing2') + + def do_test_images(self, manifest_file, expect_kernel_and_ramdisk, + image_bucket, image_name): "Test the image API." self.context.user = self.auth_manager.get_user('user1') self.context.project = self.auth_manager.get_project('proj1') # create a bucket for our bundle - objectstore.bucket.Bucket.create('image_bucket', self.context) - bucket = objectstore.bucket.Bucket('image_bucket') + objectstore.bucket.Bucket.create(image_bucket, self.context) + bucket = objectstore.bucket.Bucket(image_bucket) # upload an image manifest/parts bundle_path = os.path.join(os.path.dirname(__file__), 'bundle') @@ -147,18 +156,28 @@ class ObjectStoreTestCase(test.TrialTestCase): bucket[os.path.basename(path)] = open(path, 'rb').read() # register an image - image.Image.register_aws_image('i-testing', - 'image_bucket/1mb.manifest.xml', + image.Image.register_aws_image(image_name, + '%s/%s' % (image_bucket, manifest_file), self.context) # verify image - my_img = image.Image('i-testing') + my_img = image.Image(image_name) result_image_file = os.path.join(my_img.path, 'image') self.assertEqual(os.stat(result_image_file).st_size, 1048576) sha = hashlib.sha1(open(result_image_file).read()).hexdigest() self.assertEqual(sha, '3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3') + if expect_kernel_and_ramdisk: + # Verify the default kernel and ramdisk are set + self.assertEqual(my_img.metadata['kernelId'], 'aki-test') + self.assertEqual(my_img.metadata['ramdiskId'], 'ari-test') + else: + # Verify that the default kernel and ramdisk (the one from FLAGS) + # doesn't get embedded in the metadata + self.assertFalse('kernelId' in my_img.metadata) + self.assertFalse('ramdiskId' in my_img.metadata) + # verify image permissions self.context.user = self.auth_manager.get_user('user2') self.context.project = self.auth_manager.get_project('proj2') -- cgit From 76a76244ccee2502903a67f3f17dda97664e6687 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Mon, 11 Oct 2010 11:21:26 -0400 Subject: Fix bug 658444 --- nova/api/ec2/cloud.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 175bb493c..ca0f64e99 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -659,7 +659,12 @@ class CloudController(object): return self._format_run_instances(context, reservation_id) - def terminate_instances(self, context, ec2_id_list, **kwargs): + def terminate_instances(self, context, instance_id, **kwargs): + """Terminate each instance in instance_id, which is a list of ec2 ids. + + instance_id is a kwarg so its name cannot be modified. + """ + ec2_id_list = instance_id logging.debug("Going to start terminating instances") for id_str in ec2_id_list: internal_id = ec2_id_to_internal_id(id_str) -- cgit From f9b2f70f22bdc8a9cf08ada5f7ec45eea6060866 Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Mon, 11 Oct 2010 11:43:58 -0400 Subject: Rename rsapi to osapi, and make the default subdomain for OpenStack API calls be 'api' instead of 'rs'. --- nova/api/__init__.py | 26 +++++++++++++------------- nova/tests/api/__init__.py | 4 ++-- nova/tests/api/openstack/fakes.py | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/nova/api/__init__.py b/nova/api/__init__.py index 627883018..8ec7094d7 100644 --- a/nova/api/__init__.py +++ b/nova/api/__init__.py @@ -31,12 +31,12 @@ from nova.api import openstack from nova.api.ec2 import metadatarequesthandler -flags.DEFINE_string('rsapi_subdomain', 'rs', - 'subdomain running the RS API') +flags.DEFINE_string('osapi_subdomain', 'api', + 'subdomain running the OpenStack API') flags.DEFINE_string('ec2api_subdomain', 'ec2', 'subdomain running the EC2 API') flags.DEFINE_string('FAKE_subdomain', None, - 'set to rs or ec2 to fake the subdomain of the host for testing') + 'set to api or ec2 to fake the subdomain of the host for testing') FLAGS = flags.FLAGS @@ -44,21 +44,21 @@ class API(wsgi.Router): """Routes top-level requests to the appropriate controller.""" def __init__(self): - rsdomain = {'sub_domain': [FLAGS.rsapi_subdomain]} + osapidomain = {'sub_domain': [FLAGS.osapi_subdomain]} ec2domain = {'sub_domain': [FLAGS.ec2api_subdomain]} - # If someone wants to pretend they're hitting the RS subdomain - # on their local box, they can set FAKE_subdomain to 'rs', which - # removes subdomain restrictions from the RS routes below. - if FLAGS.FAKE_subdomain == 'rs': - rsdomain = {} + # If someone wants to pretend they're hitting the OSAPI subdomain + # on their local box, they can set FAKE_subdomain to 'api', which + # removes subdomain restrictions from the OpenStack API routes below. + if FLAGS.FAKE_subdomain == 'api': + osapidomain = {} elif FLAGS.FAKE_subdomain == 'ec2': ec2domain = {} mapper = routes.Mapper() mapper.sub_domains = True - mapper.connect("/", controller=self.rsapi_versions, - conditions=rsdomain) + mapper.connect("/", controller=self.osapi_versions, + conditions=osapidomain) mapper.connect("/v1.0/{path_info:.*}", controller=openstack.API(), - conditions=rsdomain) + conditions=osapidomain) mapper.connect("/", controller=self.ec2api_versions, conditions=ec2domain) @@ -81,7 +81,7 @@ class API(wsgi.Router): super(API, self).__init__(mapper) @webob.dec.wsgify - def rsapi_versions(self, req): + def osapi_versions(self, req): """Respond to a request for all OpenStack API versions.""" response = { "versions": [ diff --git a/nova/tests/api/__init__.py b/nova/tests/api/__init__.py index 2c7f7fd3e..f051e2390 100644 --- a/nova/tests/api/__init__.py +++ b/nova/tests/api/__init__.py @@ -46,7 +46,7 @@ class Test(unittest.TestCase): def test_openstack(self): self.stubs.Set(api.openstack, 'API', APIStub) - result = self._request('/v1.0/cloud', 'rs') + result = self._request('/v1.0/cloud', 'api') self.assertEqual(result.body, "/cloud") def test_ec2(self): @@ -61,7 +61,7 @@ class Test(unittest.TestCase): self.assertNotEqual(result.body, "/cloud") def test_query_api_versions(self): - result = self._request('/', 'rs') + result = self._request('/', 'api') self.assertTrue('CURRENT' in result.body) def test_metadata(self): diff --git a/nova/tests/api/openstack/fakes.py b/nova/tests/api/openstack/fakes.py index 1119fa714..34bc1f2a9 100644 --- a/nova/tests/api/openstack/fakes.py +++ b/nova/tests/api/openstack/fakes.py @@ -105,7 +105,7 @@ def stub_out_networking(stubs): def get_my_ip(): return '127.0.0.1' stubs.Set(nova.utils, 'get_my_ip', get_my_ip) - FLAGS.FAKE_subdomain = 'rs' + FLAGS.FAKE_subdomain = 'api' def stub_out_glance(stubs): -- cgit From da7fa3f388a45b3afca16dba6a59b68ea8804f7a Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 12 Oct 2010 09:24:33 +0200 Subject: APIRequestContext.admin is no more.. --- nova/api/ec2/cloud.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/api/ec2/cloud.py b/nova/api/ec2/cloud.py index 555518448..7839dc92c 100644 --- a/nova/api/ec2/cloud.py +++ b/nova/api/ec2/cloud.py @@ -841,7 +841,7 @@ class CloudController(object): inst_id = instance_ref['id'] for security_group_id in security_groups: - db.instance_add_security_group(context.admin(), inst_id, + db.instance_add_security_group(context, inst_id, security_group_id) inst = {} -- cgit From ac1dfd25c4b356c1725339709e535d4147feda3c Mon Sep 17 00:00:00 2001 From: Soren Hansen Date: Tue, 12 Oct 2010 14:29:57 +0200 Subject: Remove spurious project_id addition to KeyPair model. --- nova/db/sqlalchemy/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 584214deb..85b7c0aae 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -351,7 +351,6 @@ class KeyPair(BASE, NovaBase): name = Column(String(255)) user_id = Column(String(255)) - project_id = Column(String(255)) fingerprint = Column(String(255)) public_key = Column(Text) -- cgit From 3894e22d517447fb3d5e9c367ffd2e67162f4b0f Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Tue, 12 Oct 2010 13:09:35 -0400 Subject: Fix bug 659330 --- nova/db/sqlalchemy/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index ebcb73413..fc99a535d 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -169,7 +169,7 @@ class Instance(BASE, NovaBase): @property def name(self): - return self.internal_id + return "instance-%d" % self.internal_id image_id = Column(String(255)) kernel_id = Column(String(255)) -- cgit From 32ea289d13a7ec9d273a57d2bf30484b80bfebec Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Tue, 12 Oct 2010 13:42:43 -0400 Subject: Now that the ec2 id is not the same as the name of the instance, don't compare internal_id [nee ec2_id] to instance names provided by the virtualization driver. Compare names directly instead. --- nova/compute/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nova/compute/manager.py b/nova/compute/manager.py index 131fac406..99705d3a9 100644 --- a/nova/compute/manager.py +++ b/nova/compute/manager.py @@ -67,7 +67,7 @@ class ComputeManager(manager.Manager): def run_instance(self, context, instance_id, **_kwargs): """Launch a new instance with specified options.""" instance_ref = self.db.instance_get(context, instance_id) - if instance_ref['internal_id'] in self.driver.list_instances(): + if instance_ref['name'] in self.driver.list_instances(): raise exception.Error("Instance has already been created") logging.debug("instance %s: starting...", instance_id) project_id = instance_ref['project_id'] -- cgit From aa92c017ab91d7fb0ec9c2cd5fd420e625ce2dbd Mon Sep 17 00:00:00 2001 From: Michael Gundlach Date: Tue, 12 Oct 2010 18:27:59 -0400 Subject: Revert 64 bit storage and use 32 bit again. I didn't notice that we verify that randomly created uids don't already exist in the DB, so the chance of collision isn't really an issue until we get to tens of thousands of machines. Even then we should only expect a few retries before finding a free ID. --- nova/db/sqlalchemy/models.py | 4 ++-- nova/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 9809eb7a7..fc99a535d 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -25,7 +25,7 @@ import datetime # TODO(vish): clean up these imports from sqlalchemy.orm import relationship, backref, exc, object_mapper -from sqlalchemy import Column, PickleType, Integer, String +from sqlalchemy import Column, Integer, String from sqlalchemy import ForeignKey, DateTime, Boolean, Text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.declarative import declarative_base @@ -152,7 +152,7 @@ class Instance(BASE, NovaBase): __tablename__ = 'instances' __prefix__ = 'i' id = Column(Integer, primary_key=True) - internal_id = Column(PickleType(mutable=False), unique=True) + internal_id = Column(Integer, unique=True) admin_pass = Column(String(255)) diff --git a/nova/utils.py b/nova/utils.py index 12afd388f..10b27ffec 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -128,7 +128,7 @@ def runthis(prompt, cmd, check_exit_code = True): def generate_uid(topic, size=8): if topic == "i": # Instances have integer internal ids. - return random.randint(0, 2**64-1) + return random.randint(0, 2**32-1) else: characters = '01234567890abcdefghijklmnopqrstuvwxyz' choices = [random.choice(characters) for x in xrange(size)] -- cgit