diff options
| author | Mark McLoughlin <markmc@redhat.com> | 2011-10-14 12:25:30 +0100 |
|---|---|---|
| committer | Mark McLoughlin <markmc@redhat.com> | 2011-10-19 07:36:02 +0100 |
| commit | 5ee3e31eb189b7bc46bb009b99c12b8e58417a0d (patch) | |
| tree | 8cf2ca5204e9799777330a47b2023894dc86a901 | |
| parent | 2431b7848d633dc67ad684b4d1cc79468df24568 (diff) | |
| download | nova-5ee3e31eb189b7bc46bb009b99c12b8e58417a0d.tar.gz nova-5ee3e31eb189b7bc46bb009b99c12b8e58417a0d.tar.xz nova-5ee3e31eb189b7bc46bb009b99c12b8e58417a0d.zip | |
Start switching from gflags to optparse
Re-write the nova.flags module to use optparse instead of gflags.
This provides an easier path to switching completely to optparse.
Next steps are to:
1) Gradually switch each of the individual flags to optparse
2) Re-use config code from other projects via openstack-common
optparse was chosen instead of argparse purely because that's what
the other projects use and that's what makes most sense for
openstack-common. Switching to argparse is something that can be
done later in openstack-common.
Change-Id: Ia49d42b4c7cc208fba140db6b8fd8f33c0f89e04
| -rw-r--r-- | nova/flags.py | 337 | ||||
| -rw-r--r-- | nova/tests/test_flags.py | 90 |
2 files changed, 269 insertions, 158 deletions
diff --git a/nova/flags.py b/nova/flags.py index 58c37f939..a70a361a8 100644 --- a/nova/flags.py +++ b/nova/flags.py @@ -3,6 +3,7 @@ # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. +# Copyright 2011 Red Hat, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -18,14 +19,14 @@ """Command-line flag library. -Wraps gflags. +Emulates gflags by wrapping optparse. -Package-level global flags are defined here, the rest are defined -where they're used. +The idea is to move to optparse eventually, and this wrapper is a +stepping stone. """ -import getopt +import optparse import os import socket import string @@ -34,120 +35,165 @@ import sys import gflags -class FlagValues(gflags.FlagValues): - """Extension of gflags.FlagValues that allows undefined and runtime flags. +class FlagValues(object): + class Flag: + def __init__(self, name, value, update_default=None): + self.name = name + self.value = value + self._update_default = update_default - Unknown flags will be ignored when parsing the command line, but the - command line will be kept so that it can be replayed if new flags are - defined after the initial parsing. + def SetDefault(self, default): + if self._update_default: + self._update_default(self.name, default) - """ - - def __init__(self, extra_context=None): - gflags.FlagValues.__init__(self) - self.__dict__['__dirty'] = [] - self.__dict__['__was_already_parsed'] = False - self.__dict__['__stored_argv'] = [] - self.__dict__['__extra_context'] = extra_context - - def __call__(self, argv): - # We're doing some hacky stuff here so that we don't have to copy - # out all the code of the original verbatim and then tweak a few lines. - # We're hijacking the output of getopt so we can still return the - # leftover args at the end - sneaky_unparsed_args = {"value": None} - original_argv = list(argv) - - if self.IsGnuGetOpt(): - orig_getopt = getattr(getopt, 'gnu_getopt') - orig_name = 'gnu_getopt' - else: - orig_getopt = getattr(getopt, 'getopt') - orig_name = 'getopt' + class ErrorCatcher: + def __init__(self, orig_error): + self.orig_error = orig_error + self.reset() - def _sneaky(*args, **kw): - optlist, unparsed_args = orig_getopt(*args, **kw) - sneaky_unparsed_args['value'] = unparsed_args - return optlist, unparsed_args + def reset(self): + self._error_msg = None - try: - setattr(getopt, orig_name, _sneaky) - args = gflags.FlagValues.__call__(self, argv) - except gflags.UnrecognizedFlagError: - # Undefined args were found, for now we don't care so just - # act like everything went well - # (these three lines are copied pretty much verbatim from the end - # of the __call__ function we are wrapping) - unparsed_args = sneaky_unparsed_args['value'] - if unparsed_args: - if self.IsGnuGetOpt(): - args = argv[:1] + unparsed_args - else: - args = argv[:1] + original_argv[-len(unparsed_args):] + def catch(self, msg): + if ": --" in msg: + self._error_msg = msg else: - args = argv[:1] - finally: - setattr(getopt, orig_name, orig_getopt) + self.orig_error(msg) - # Store the arguments for later, we'll need them for new flags - # added at runtime - self.__dict__['__stored_argv'] = original_argv - self.__dict__['__was_already_parsed'] = True - self.ClearDirty() - return args + def get_unknown_arg(self, args): + if not self._error_msg: + return None + # Error message is e.g. "no such option: --runtime_answer" + a = self._error_msg[self._error_msg.rindex(": --") + 2:] + return filter(lambda i: i == a or i.startswith(a + "="), args)[0] - def Reset(self): - gflags.FlagValues.Reset(self) - self.__dict__['__dirty'] = [] - self.__dict__['__was_already_parsed'] = False - self.__dict__['__stored_argv'] = [] + def __init__(self, extra_context=None): + self._parser = optparse.OptionParser() + self._parser.disable_interspersed_args() + self._extra_context = extra_context + self.Reset() - def SetDirty(self, name): - """Mark a flag as dirty so that accessing it will case a reparse.""" - self.__dict__['__dirty'].append(name) + def _parse(self): + if not self._values is None: + return - def IsDirty(self, name): - return name in self.__dict__['__dirty'] + args = gflags.FlagValues().ReadFlagsFromFiles(self._args) - def ClearDirty(self): - self.__dict__['__dirty'] = [] + values = extra = None - def WasAlreadyParsed(self): - return self.__dict__['__was_already_parsed'] + # + # This horrendous hack allows us to stop optparse + # exiting when it encounters an unknown option + # + error_catcher = self.ErrorCatcher(self._parser.error) + self._parser.error = error_catcher.catch + try: + while True: + error_catcher.reset() - def ParseNewFlags(self): - if '__stored_argv' not in self.__dict__: - return - new_flags = FlagValues(self) - for k in self.FlagDict().iterkeys(): - new_flags[k] = gflags.FlagValues.__getitem__(self, k) + (values, extra) = self._parser.parse_args(args) + + unknown = error_catcher.get_unknown_arg(args) + if not unknown: + break - new_flags.Reset() - new_flags(self.__dict__['__stored_argv']) - for k in new_flags.FlagDict().iterkeys(): - setattr(self, k, getattr(new_flags, k)) - self.ClearDirty() + args.remove(unknown) + finally: + self._parser.error = error_catcher.orig_error - def __setitem__(self, name, flag): - gflags.FlagValues.__setitem__(self, name, flag) - if self.WasAlreadyParsed(): - self.SetDirty(name) + (self._values, self._extra) = (values, extra) - def __getitem__(self, name): - if self.IsDirty(name): - self.ParseNewFlags() - return gflags.FlagValues.__getitem__(self, name) + def __call__(self, argv): + self._args = argv[1:] + self._values = None + self._parse() + return [argv[0]] + self._extra def __getattr__(self, name): - if self.IsDirty(name): - self.ParseNewFlags() - val = gflags.FlagValues.__getattr__(self, name) + self._parse() + val = getattr(self._values, name) if type(val) is str: tmpl = string.Template(val) - context = [self, self.__dict__['__extra_context']] + context = [self, self._extra_context] return tmpl.substitute(StrWrapper(context)) return val + def get(self, name, default): + value = getattr(self, name) + if value is not None: # value might be '0' or "" + return value + else: + return default + + def __contains__(self, name): + self._parse() + return hasattr(self._values, name) + + def _update_default(self, name, default): + self._parser.set_default(name, default) + self._values = None + + def __iter__(self): + return self.FlagValuesDict().iterkeys() + + def __getitem__(self, name): + self._parse() + if not self.__contains__(name): + return None + return self.Flag(name, getattr(self, name), self._update_default) + + def Reset(self): + self._args = [] + self._values = None + self._extra = None + + def ParseNewFlags(self): + pass + + def FlagValuesDict(self): + ret = {} + for opt in self._parser.option_list: + if opt.dest: + ret[opt.dest] = getattr(self, opt.dest) + return ret + + def _add_option(self, name, default, help, prefix='--', **kwargs): + prefixed_name = prefix + name + for opt in self._parser.option_list: + if prefixed_name == opt.get_opt_string(): + return + self._parser.add_option(prefixed_name, dest=name, + default=default, help=help, **kwargs) + self._values = None + + def define_string(self, name, default, help): + self._add_option(name, default, help) + + def define_integer(self, name, default, help): + self._add_option(name, default, help, type='int') + + def define_float(self, name, default, help): + self._add_option(name, default, help, type='float') + + def define_bool(self, name, default, help): + # + # FIXME(markmc): this doesn't support --boolflag=true/false/t/f/1/0 + # + self._add_option(name, default, help, action='store_true') + self._add_option(name, default, help, + prefix="--no", action='store_false') + + def define_list(self, name, default, help): + def parse_list(option, opt, value, parser): + setattr(self._parser.values, name, value.split(',')) + self._add_option(name, default, help, type='string', + action='callback', callback=parse_list) + + def define_multistring(self, name, default, help): + self._add_option(name, default, help, action='append') + +FLAGS = FlagValues() + class StrWrapper(object): """Wrapper around FlagValues objects. @@ -167,85 +213,60 @@ class StrWrapper(object): raise KeyError(name) -# Copied from gflags with small mods to get the naming correct. -# Originally gflags checks for the first module that is not gflags that is -# in the call chain, we want to check for the first module that is not gflags -# and not this module. -def _GetCallingModule(): - """Returns the name of the module that's calling into this module. +def DEFINE_string(name, default, help, flag_values=FLAGS): + flag_values.define_string(name, default, help) - We generally use this function to get the name of the module calling a - DEFINE_foo... function. - """ - # Walk down the stack to find the first globals dict that's not ours. - for depth in range(1, sys.getrecursionlimit()): - if not sys._getframe(depth).f_globals is globals(): - module_name = __GetModuleName(sys._getframe(depth).f_globals) - if module_name == 'gflags': - continue - if module_name is not None: - return module_name - raise AssertionError("No module was found") +def DEFINE_integer(name, default, help, lower_bound=None, flag_values=FLAGS): + # FIXME(markmc): ignoring lower_bound + flag_values.define_integer(name, default, help) -# Copied from gflags because it is a private function -def __GetModuleName(globals_dict): - """Given a globals dict, returns the name of the module that defines it. +def DEFINE_bool(name, default, help, flag_values=FLAGS): + flag_values.define_bool(name, default, help) - Args: - globals_dict: A dictionary that should correspond to an environment - providing the values of the globals. - Returns: - A string (the name of the module) or None (if the module could not - be identified. +def DEFINE_boolean(name, default, help, flag_values=FLAGS): + DEFINE_bool(name, default, help, flag_values) - """ - for name, module in sys.modules.iteritems(): - if getattr(module, '__dict__', None) is globals_dict: - if name == '__main__': - return sys.argv[0] - return name - return None +def DEFINE_list(name, default, help, flag_values=FLAGS): + flag_values.define_list(name, default, help) -def _wrapper(func): - def _wrapped(*args, **kw): - kw.setdefault('flag_values', FLAGS) - func(*args, **kw) - _wrapped.func_name = func.func_name - return _wrapped +def DEFINE_float(name, default, help, flag_values=FLAGS): + flag_values.define_float(name, default, help) -FLAGS = FlagValues() -gflags.FLAGS = FLAGS -gflags._GetCallingModule = _GetCallingModule - - -DEFINE = _wrapper(gflags.DEFINE) -DEFINE_string = _wrapper(gflags.DEFINE_string) -DEFINE_integer = _wrapper(gflags.DEFINE_integer) -DEFINE_bool = _wrapper(gflags.DEFINE_bool) -DEFINE_boolean = _wrapper(gflags.DEFINE_boolean) -DEFINE_float = _wrapper(gflags.DEFINE_float) -DEFINE_enum = _wrapper(gflags.DEFINE_enum) -DEFINE_list = _wrapper(gflags.DEFINE_list) -DEFINE_spaceseplist = _wrapper(gflags.DEFINE_spaceseplist) -DEFINE_multistring = _wrapper(gflags.DEFINE_multistring) -DEFINE_multi_int = _wrapper(gflags.DEFINE_multi_int) -DEFINE_flag = _wrapper(gflags.DEFINE_flag) -HelpFlag = gflags.HelpFlag -HelpshortFlag = gflags.HelpshortFlag -HelpXMLFlag = gflags.HelpXMLFlag + +def DEFINE_multistring(name, default, help, flag_values=FLAGS): + flag_values.define_multistring(name, default, help) + + +class UnrecognizedFlag(Exception): + pass def DECLARE(name, module_string, flag_values=FLAGS): if module_string not in sys.modules: __import__(module_string, globals(), locals()) if name not in flag_values: - raise gflags.UnrecognizedFlag( - "%s not defined by %s" % (name, module_string)) + raise UnrecognizedFlag('%s not defined by %s' % (name, module_string)) + + +def DEFINE_flag(flag): + pass + + +class HelpFlag: + pass + + +class HelpshortFlag: + pass + + +class HelpXMLFlag: + pass def _get_my_ip(): diff --git a/nova/tests/test_flags.py b/nova/tests/test_flags.py index 05319d91f..dab11c5e0 100644 --- a/nova/tests/test_flags.py +++ b/nova/tests/test_flags.py @@ -3,6 +3,7 @@ # Copyright 2010 United States Government as represented by the # Administrator of the National Aeronautics and Space Administration. # All Rights Reserved. +# Copyright 2011 Red Hat, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -16,6 +17,10 @@ # License for the specific language governing permissions and limitations # under the License. +import exceptions +import os +import tempfile + from nova import exception from nova import flags from nova import test @@ -64,6 +69,37 @@ class FlagsTestCase(test.TestCase): self.assertEqual(self.FLAGS.false, True) self.assertEqual(self.FLAGS.true, False) + def test_define_float(self): + flags.DEFINE_float('float', 6.66, 'desc', flag_values=self.FLAGS) + self.assertEqual(self.FLAGS.float, 6.66) + + def test_define_multistring(self): + flags.DEFINE_multistring('multi', [], 'desc', flag_values=self.FLAGS) + + argv = ['flags_test', '--multi', 'foo', '--multi', 'bar'] + self.FLAGS(argv) + + self.assertEqual(self.FLAGS.multi, ['foo', 'bar']) + + def test_define_list(self): + flags.DEFINE_list('list', ['foo'], 'desc', flag_values=self.FLAGS) + + self.assert_(self.FLAGS['list']) + self.assertEqual(self.FLAGS.list, ['foo']) + + argv = ['flags_test', '--list=a,b,c,d'] + self.FLAGS(argv) + + self.assertEqual(self.FLAGS.list, ['a', 'b', 'c', 'd']) + + def test_error(self): + flags.DEFINE_integer('error', 1, 'desc', flag_values=self.FLAGS) + + self.assertEqual(self.FLAGS.error, 1) + + argv = ['flags_test', '--error=foo'] + self.assertRaises(exceptions.SystemExit, self.FLAGS, argv) + def test_declare(self): self.assert_('answer' not in self.global_FLAGS) flags.DECLARE('answer', 'nova.tests.declare_flags') @@ -76,6 +112,14 @@ class FlagsTestCase(test.TestCase): flags.DECLARE('answer', 'nova.tests.declare_flags') self.assertEqual(self.global_FLAGS.answer, 256) + def test_getopt_non_interspersed_args(self): + self.assert_('runtime_answer' not in self.global_FLAGS) + + argv = ['flags_test', 'extra_arg', '--runtime_answer=60'] + args = self.global_FLAGS(argv) + self.assertEqual(len(args), 3) + self.assertEqual(argv, args) + def test_runtime_and_unknown_flags(self): self.assert_('runtime_answer' not in self.global_FLAGS) @@ -114,3 +158,49 @@ class FlagsTestCase(test.TestCase): self.assertEqual(FLAGS.flags_unittest, 'foo') FLAGS.flags_unittest = 'bar' self.assertEqual(FLAGS.flags_unittest, 'bar') + + def test_flag_overrides(self): + self.assertEqual(FLAGS.flags_unittest, 'foo') + self.flags(flags_unittest='bar') + self.assertEqual(FLAGS.flags_unittest, 'bar') + self.assertEqual(FLAGS['flags_unittest'].value, 'bar') + self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'bar') + self.reset_flags() + self.assertEqual(FLAGS.flags_unittest, 'foo') + self.assertEqual(FLAGS['flags_unittest'].value, 'foo') + self.assertEqual(FLAGS.FlagValuesDict()['flags_unittest'], 'foo') + + def test_flagfile(self): + flags.DEFINE_string('string', 'default', 'desc', + flag_values=self.FLAGS) + flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS) + flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS) + flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS) + + (fd, path) = tempfile.mkstemp(prefix='nova', suffix='.flags') + + try: + os.write(fd, '--string=foo\n--int=2\n--false\n--notrue\n') + os.close(fd) + + self.FLAGS(['flags_test', '--flagfile=' + path]) + + self.assertEqual(self.FLAGS.string, 'foo') + self.assertEqual(self.FLAGS.int, 2) + self.assertEqual(self.FLAGS.false, True) + self.assertEqual(self.FLAGS.true, False) + finally: + os.remove(path) + + def test_defaults(self): + flags.DEFINE_string('foo', 'bar', 'help', flag_values=self.FLAGS) + self.assertEqual(self.FLAGS.foo, 'bar') + + self.FLAGS['foo'].SetDefault('blaa') + self.assertEqual(self.FLAGS.foo, 'blaa') + + def test_templated_values(self): + flags.DEFINE_string('foo', 'foo', 'help', flag_values=self.FLAGS) + flags.DEFINE_string('bar', 'bar', 'help', flag_values=self.FLAGS) + flags.DEFINE_string('blaa', '$foo$bar', 'help', flag_values=self.FLAGS) + self.assertEqual(self.FLAGS.blaa, 'foobar') |
