diff options
-rw-r--r-- | ipalib/__init__.py | 3 | ||||
-rw-r--r-- | ipalib/backend.py | 1 | ||||
-rw-r--r-- | ipalib/cli.py | 24 | ||||
-rw-r--r-- | tests/test_cmdline/test_cli.py | 184 |
4 files changed, 207 insertions, 5 deletions
diff --git a/ipalib/__init__.py b/ipalib/__init__.py index 1efeeab4a..dd861a826 100644 --- a/ipalib/__init__.py +++ b/ipalib/__init__.py @@ -916,5 +916,8 @@ def create_api(mode='dummy'): api = create_api(mode=None) if os.environ.get('IPA_UNIT_TEST_MODE', None) == 'cli_test': + from cli import cli_plugins + for klass in cli_plugins: + api.register(klass) api.bootstrap(context='cli', in_server=False, in_tree=True) api.finalize() diff --git a/ipalib/backend.py b/ipalib/backend.py index 0232fa536..7be38ecc8 100644 --- a/ipalib/backend.py +++ b/ipalib/backend.py @@ -102,7 +102,6 @@ class Connectible(Backend): class Executioner(Backend): - def create_context(self, ccache=None, client_ip=None): """ client_ip: The IP address of the remote client. diff --git a/ipalib/cli.py b/ipalib/cli.py index f72cca587..8279345a9 100644 --- a/ipalib/cli.py +++ b/ipalib/cli.py @@ -123,7 +123,7 @@ class textui(backend.Backend): def __get_encoding(self, stream): assert stream in (sys.stdin, sys.stdout) - if stream.encoding is None: + if getattr(stream, 'encoding', None) is None: return 'UTF-8' return stream.encoding @@ -1007,7 +1007,11 @@ class cli(backend.Executioner): Backend plugin for executing from command line interface. """ - def run(self, argv): + def get_command(self, argv): + """Given CLI arguments, return the Command to use + + On incorrect invocation, prints out a help message and returns None + """ if len(argv) == 0: self.Command.help() return @@ -1022,15 +1026,27 @@ class cli(backend.Executioner): if name not in self.Command or self.Command[name].NO_CLI: raise CommandError(name=key) cmd = self.Command[name] - if not isinstance(cmd, frontend.Local): - self.create_context() + return cmd + + def argv_to_keyword_arguments(self, cmd, argv): + """Get the keyword arguments for a Command""" kw = self.parse(cmd, argv) if self.env.interactive: self.prompt_interactively(cmd, kw) kw = cmd.split_csv(**kw) kw['version'] = API_VERSION self.load_files(cmd, kw) + return kw + + def run(self, argv): + cmd = self.get_command(argv) + if cmd is None: + return + name = cmd.name + if not isinstance(cmd, frontend.Local): + self.create_context() try: + kw = self.argv_to_keyword_arguments(cmd, argv[1:]) result = self.execute(name, **kw) if callable(cmd.output_for_cli): for param in cmd.params(): diff --git a/tests/test_cmdline/test_cli.py b/tests/test_cmdline/test_cli.py new file mode 100644 index 000000000..889aae413 --- /dev/null +++ b/tests/test_cmdline/test_cli.py @@ -0,0 +1,184 @@ +import shlex +import sys +import contextlib +import StringIO + +import nose + +from tests import util +from ipalib import api, errors +from ipapython.version import API_VERSION + + +class TestCLIParsing(object): + """Tests that commandlines are correctly parsed to Command keyword args + """ + def check_command(self, commandline, expected_command_name, **kw_expected): + argv = shlex.split(commandline) + executioner = api.Backend.cli + + cmd = executioner.get_command(argv) + kw_got = executioner.argv_to_keyword_arguments(cmd, argv[1:]) + util.assert_deepequal(expected_command_name, cmd.name, 'Command name') + util.assert_deepequal(kw_expected, kw_got) + + def run_command(self, command_name, **kw): + """Run a command on the server""" + if not api.Backend.xmlclient.isconnected(): + api.Backend.xmlclient.connect(fallback=False) + api.Command[command_name](**kw) + + @contextlib.contextmanager + def fake_stdin(self, string_in): + """Context manager that temporarily replaces stdin to read a string""" + old_stdin = sys.stdin + sys.stdin = StringIO.StringIO(string_in) + yield + sys.stdin = old_stdin + + def test_ping(self): + self.check_command('ping', 'ping', + version=API_VERSION) + + def test_user_show(self): + self.check_command('user-show admin', 'user_show', + uid=u'admin', + rights=False, + raw=False, + all=False, + version=API_VERSION) + + def test_user_show_underscore(self): + self.check_command('user_show admin', 'user_show', + uid=u'admin', + rights=False, + raw=False, + all=False, + version=API_VERSION) + + def test_group_add(self): + self.check_command('group-add tgroup1 --desc="Test group"', + 'group_add', + cn=u'tgroup1', + description=u'Test group', + nonposix=False, + raw=False, + all=False, + version=API_VERSION) + + def test_sudocmdgroup_add_member(self): + # Test CSV splitting is done correctly + self.check_command( + # The following is as it would appear on the command line: + r'sudocmdgroup-add-member tcmdgroup1 --sudocmds=abc,\"de,f\",\\,g', + 'sudocmdgroup_add_member', + cn=u'tcmdgroup1', + sudocmd=[u'abc', u'de,f', u'\\', u'g'], + raw=False, + all=False, + version=API_VERSION) + + def test_group_add_nonposix(self): + self.check_command('group-add tgroup1 --desc="Test group" --nonposix', + 'group_add', + cn=u'tgroup1', + description=u'Test group', + nonposix=True, + raw=False, + all=False, + version=API_VERSION) + + def test_group_add_gid(self): + self.check_command('group-add tgroup1 --desc="Test group" --gid=1234', + 'group_add', + cn=u'tgroup1', + description=u'Test group', + gidnumber=u'1234', + nonposix=False, + raw=False, + all=False, + version=API_VERSION) + + def test_group_add_interactive(self): + with self.fake_stdin('Test group\n'): + self.check_command('group-add tgroup1', 'group_add', + cn=u'tgroup1', + description=u'Test group', + nonposix=False, + raw=False, + all=False, + version=API_VERSION) + + def test_dnsrecord_add(self): + self.check_command('dnsrecord-add test-example.com ns --a-rec=1.2.3.4', + 'dnsrecord_add', + dnszoneidnsname=u'test-example.com', + idnsname=u'ns', + arecord=[u'1.2.3.4'], + structured=False, + force=False, + raw=False, + all=False, + version=API_VERSION) + + def test_dnsrecord_del_all(self): + try: + self.run_command('dnszone_add', idnsname=u'test-example.com', + idnssoamname=u'ns.test-example.com', + admin_email=u'devnull@test-example.com', force=True) + except errors.NotFound: + raise nose.SkipTest('DNS is not configured') + try: + self.run_command('dnsrecord_add', + dnszoneidnsname=u'test-example.com', + idnsname=u'ns', arecord=u'1.2.3.4') + with self.fake_stdin('yes\n'): + self.check_command('dnsrecord_del test-example.com ns', + 'dnsrecord_del', + dnszoneidnsname=u'test-example.com', + idnsname=u'ns', + del_all=True, + structured=False, + raw=False, + all=False, + version=API_VERSION) + with self.fake_stdin('YeS\n'): + self.check_command('dnsrecord_del test-example.com ns', + 'dnsrecord_del', + dnszoneidnsname=u'test-example.com', + idnsname=u'ns', + del_all=True, + structured=False, + raw=False, + all=False, + version=API_VERSION) + finally: + self.run_command('dnszone_del', idnsname=u'test-example.com') + + def test_dnsrecord_del_one_by_one(self): + try: + self.run_command('dnszone_add', idnsname=u'test-example.com', + idnssoamname=u'ns.test-example.com', + admin_email=u'devnull@test-example.com', force=True) + except errors.NotFound: + raise nose.SkipTest('DNS is not configured') + try: + records = (u'1 1 E3B72BA346B90570EED94BE9334E34AA795CED23', + u'2 1 FD2693C1EFFC11A8D2BE57229212A04B45663791') + for record in records: + self.run_command('dnsrecord_add', + dnszoneidnsname=u'test-example.com', idnsname=u'ns', + sshfprecord=record) + with self.fake_stdin('no\nyes\nyes\n'): + self.check_command('dnsrecord_del test-example.com ns', + 'dnsrecord_del', + dnszoneidnsname=u'test-example.com', + idnsname=u'ns', + del_all=False, + sshfprecord=records, + structured=False, + raw=False, + all=False, + version=API_VERSION) + finally: + self.run_command('dnszone_del', idnsname=u'test-example.com') |