From 202415829afcbcb48827e15db725050d19eb1b84 Mon Sep 17 00:00:00 2001 From: Rick Harris Date: Thu, 6 Jun 2013 21:48:21 +0000 Subject: More KeypairAPI cleanups * DRY up quota checking for new keys * DRY up formatting of get_key/get_all_keys results * Add tests for get all keypairs Change-Id: I14a97c0f3cb3aa9b827d14002c21076a41942023 --- nova/compute/api.py | 45 ++++++++++++++----------------------- nova/tests/compute/test_keypairs.py | 15 ++++++++++--- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/nova/compute/api.py b/nova/compute/api.py index 4e49b5186..c8df86f6d 100644 --- a/nova/compute/api.py +++ b/nova/compute/api.py @@ -2752,13 +2752,11 @@ class AggregateAPI(base.Base): class KeypairAPI(base.Base): - """Sub-set of the Compute Manager API for managing key pairs.""" - def __init__(self, **kwargs): - super(KeypairAPI, self).__init__(**kwargs) + """Subset of the Compute Manager API for managing key pairs.""" - def _validate_keypair_name(self, context, user_id, key_name): - safechars = "_- " + string.digits + string.ascii_letters - clean_value = "".join(x for x in key_name if x in safechars) + def _validate_new_key_pair(self, context, user_id, key_name): + safe_chars = "_- " + string.digits + string.ascii_letters + clean_value = "".join(x for x in key_name if x in safe_chars) if clean_value != key_name: raise exception.InvalidKeypair( _("Keypair name contains unsafe characters")) @@ -2767,16 +2765,16 @@ class KeypairAPI(base.Base): raise exception.InvalidKeypair( _('Keypair name must be between 1 and 255 characters long')) - def import_key_pair(self, context, user_id, key_name, public_key): - """Import a key pair using an existing public key.""" - self._validate_keypair_name(context, user_id, key_name) - count = QUOTAS.count(context, 'key_pairs', user_id) try: QUOTAS.limit_check(context, key_pairs=count + 1) except exception.OverQuota: raise exception.KeypairLimitExceeded() + def import_key_pair(self, context, user_id, key_name, public_key): + """Import a key pair using an existing public key.""" + self._validate_new_key_pair(context, user_id, key_name) + fingerprint = crypto.generate_fingerprint(public_key) keypair = {'user_id': user_id, @@ -2789,13 +2787,7 @@ class KeypairAPI(base.Base): def create_key_pair(self, context, user_id, key_name): """Create a new key pair.""" - self._validate_keypair_name(context, user_id, key_name) - - count = QUOTAS.count(context, 'key_pairs', user_id) - try: - QUOTAS.limit_check(context, key_pairs=count + 1) - except exception.OverQuota: - raise exception.KeypairLimitExceeded() + self._validate_new_key_pair(context, user_id, key_name) private_key, public_key, fingerprint = crypto.generate_key_pair() @@ -2804,6 +2796,7 @@ class KeypairAPI(base.Base): 'fingerprint': fingerprint, 'public_key': public_key, 'private_key': private_key} + self.db.key_pair_create(context, keypair) return keypair @@ -2811,24 +2804,20 @@ class KeypairAPI(base.Base): """Delete a keypair by name.""" self.db.key_pair_destroy(context, user_id, key_name) + def _get_key_pair(self, key_pair): + return {'name': key_pair['name'], + 'public_key': key_pair['public_key'], + 'fingerprint': key_pair['fingerprint']} + def get_key_pairs(self, context, user_id): """List key pairs.""" key_pairs = self.db.key_pair_get_all_by_user(context, user_id) - rval = [] - for key_pair in key_pairs: - rval.append({ - 'name': key_pair['name'], - 'public_key': key_pair['public_key'], - 'fingerprint': key_pair['fingerprint'], - }) - return rval + return [self._get_key_pair(k) for k in key_pairs] def get_key_pair(self, context, user_id, key_name): """Get a keypair by name.""" key_pair = self.db.key_pair_get(context, user_id, key_name) - return {'name': key_pair['name'], - 'public_key': key_pair['public_key'], - 'fingerprint': key_pair['fingerprint']} + return self._get_key_pair(key_pair) class SecurityGroupAPI(base.Base, security_group_base.SecurityGroupBase): diff --git a/nova/tests/compute/test_keypairs.py b/nova/tests/compute/test_keypairs.py index fcb21b3e6..f82d69ccb 100644 --- a/nova/tests/compute/test_keypairs.py +++ b/nova/tests/compute/test_keypairs.py @@ -49,10 +49,12 @@ class KeypairAPITestCase(test_compute.BaseTestCase): def _keypair_db_call_stubs(self): - def db_key_pair_get_all_by_user(self, user_id): - return [] + def db_key_pair_get_all_by_user(context, user_id): + return [{'name': self.existing_key_name, + 'public_key': self.pub_key, + 'fingerprint': self.fingerprint}] - def db_key_pair_create(self, keypair): + def db_key_pair_create(context, keypair): pass def db_key_pair_destroy(context, user_id, name): @@ -163,3 +165,10 @@ class GetKeypairTestCase(KeypairAPITestCase): self.ctxt.user_id, self.existing_key_name) self.assertEqual(self.existing_key_name, keypair['name']) + + +class GetKeypairsTestCase(KeypairAPITestCase): + def test_success(self): + keypairs = self.keypair_api.get_key_pairs(self.ctxt, self.ctxt.user_id) + self.assertEqual([self.existing_key_name], + [k['name'] for k in keypairs]) -- cgit