summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--ipalib/__init__.py3
-rw-r--r--ipalib/backend.py1
-rw-r--r--ipalib/cli.py24
-rw-r--r--tests/test_cmdline/test_cli.py184
4 files changed, 207 insertions, 5 deletions
diff --git a/ipalib/__init__.py b/ipalib/__init__.py
index 1efeeab4a..dd861a826 100644
--- a/ipalib/__init__.py
+++ b/ipalib/__init__.py
@@ -916,5 +916,8 @@ def create_api(mode='dummy'):
api = create_api(mode=None)
if os.environ.get('IPA_UNIT_TEST_MODE', None) == 'cli_test':
+ from cli import cli_plugins
+ for klass in cli_plugins:
+ api.register(klass)
api.bootstrap(context='cli', in_server=False, in_tree=True)
api.finalize()
diff --git a/ipalib/backend.py b/ipalib/backend.py
index 0232fa536..7be38ecc8 100644
--- a/ipalib/backend.py
+++ b/ipalib/backend.py
@@ -102,7 +102,6 @@ class Connectible(Backend):
class Executioner(Backend):
-
def create_context(self, ccache=None, client_ip=None):
"""
client_ip: The IP address of the remote client.
diff --git a/ipalib/cli.py b/ipalib/cli.py
index f72cca587..8279345a9 100644
--- a/ipalib/cli.py
+++ b/ipalib/cli.py
@@ -123,7 +123,7 @@ class textui(backend.Backend):
def __get_encoding(self, stream):
assert stream in (sys.stdin, sys.stdout)
- if stream.encoding is None:
+ if getattr(stream, 'encoding', None) is None:
return 'UTF-8'
return stream.encoding
@@ -1007,7 +1007,11 @@ class cli(backend.Executioner):
Backend plugin for executing from command line interface.
"""
- def run(self, argv):
+ def get_command(self, argv):
+ """Given CLI arguments, return the Command to use
+
+ On incorrect invocation, prints out a help message and returns None
+ """
if len(argv) == 0:
self.Command.help()
return
@@ -1022,15 +1026,27 @@ class cli(backend.Executioner):
if name not in self.Command or self.Command[name].NO_CLI:
raise CommandError(name=key)
cmd = self.Command[name]
- if not isinstance(cmd, frontend.Local):
- self.create_context()
+ return cmd
+
+ def argv_to_keyword_arguments(self, cmd, argv):
+ """Get the keyword arguments for a Command"""
kw = self.parse(cmd, argv)
if self.env.interactive:
self.prompt_interactively(cmd, kw)
kw = cmd.split_csv(**kw)
kw['version'] = API_VERSION
self.load_files(cmd, kw)
+ return kw
+
+ def run(self, argv):
+ cmd = self.get_command(argv)
+ if cmd is None:
+ return
+ name = cmd.name
+ if not isinstance(cmd, frontend.Local):
+ self.create_context()
try:
+ kw = self.argv_to_keyword_arguments(cmd, argv[1:])
result = self.execute(name, **kw)
if callable(cmd.output_for_cli):
for param in cmd.params():
diff --git a/tests/test_cmdline/test_cli.py b/tests/test_cmdline/test_cli.py
new file mode 100644
index 000000000..889aae413
--- /dev/null
+++ b/tests/test_cmdline/test_cli.py
@@ -0,0 +1,184 @@
+import shlex
+import sys
+import contextlib
+import StringIO
+
+import nose
+
+from tests import util
+from ipalib import api, errors
+from ipapython.version import API_VERSION
+
+
+class TestCLIParsing(object):
+ """Tests that commandlines are correctly parsed to Command keyword args
+ """
+ def check_command(self, commandline, expected_command_name, **kw_expected):
+ argv = shlex.split(commandline)
+ executioner = api.Backend.cli
+
+ cmd = executioner.get_command(argv)
+ kw_got = executioner.argv_to_keyword_arguments(cmd, argv[1:])
+ util.assert_deepequal(expected_command_name, cmd.name, 'Command name')
+ util.assert_deepequal(kw_expected, kw_got)
+
+ def run_command(self, command_name, **kw):
+ """Run a command on the server"""
+ if not api.Backend.xmlclient.isconnected():
+ api.Backend.xmlclient.connect(fallback=False)
+ api.Command[command_name](**kw)
+
+ @contextlib.contextmanager
+ def fake_stdin(self, string_in):
+ """Context manager that temporarily replaces stdin to read a string"""
+ old_stdin = sys.stdin
+ sys.stdin = StringIO.StringIO(string_in)
+ yield
+ sys.stdin = old_stdin
+
+ def test_ping(self):
+ self.check_command('ping', 'ping',
+ version=API_VERSION)
+
+ def test_user_show(self):
+ self.check_command('user-show admin', 'user_show',
+ uid=u'admin',
+ rights=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_user_show_underscore(self):
+ self.check_command('user_show admin', 'user_show',
+ uid=u'admin',
+ rights=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_group_add(self):
+ self.check_command('group-add tgroup1 --desc="Test group"',
+ 'group_add',
+ cn=u'tgroup1',
+ description=u'Test group',
+ nonposix=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_sudocmdgroup_add_member(self):
+ # Test CSV splitting is done correctly
+ self.check_command(
+ # The following is as it would appear on the command line:
+ r'sudocmdgroup-add-member tcmdgroup1 --sudocmds=abc,\"de,f\",\\,g',
+ 'sudocmdgroup_add_member',
+ cn=u'tcmdgroup1',
+ sudocmd=[u'abc', u'de,f', u'\\', u'g'],
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_group_add_nonposix(self):
+ self.check_command('group-add tgroup1 --desc="Test group" --nonposix',
+ 'group_add',
+ cn=u'tgroup1',
+ description=u'Test group',
+ nonposix=True,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_group_add_gid(self):
+ self.check_command('group-add tgroup1 --desc="Test group" --gid=1234',
+ 'group_add',
+ cn=u'tgroup1',
+ description=u'Test group',
+ gidnumber=u'1234',
+ nonposix=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_group_add_interactive(self):
+ with self.fake_stdin('Test group\n'):
+ self.check_command('group-add tgroup1', 'group_add',
+ cn=u'tgroup1',
+ description=u'Test group',
+ nonposix=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_dnsrecord_add(self):
+ self.check_command('dnsrecord-add test-example.com ns --a-rec=1.2.3.4',
+ 'dnsrecord_add',
+ dnszoneidnsname=u'test-example.com',
+ idnsname=u'ns',
+ arecord=[u'1.2.3.4'],
+ structured=False,
+ force=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+
+ def test_dnsrecord_del_all(self):
+ try:
+ self.run_command('dnszone_add', idnsname=u'test-example.com',
+ idnssoamname=u'ns.test-example.com',
+ admin_email=u'devnull@test-example.com', force=True)
+ except errors.NotFound:
+ raise nose.SkipTest('DNS is not configured')
+ try:
+ self.run_command('dnsrecord_add',
+ dnszoneidnsname=u'test-example.com',
+ idnsname=u'ns', arecord=u'1.2.3.4')
+ with self.fake_stdin('yes\n'):
+ self.check_command('dnsrecord_del test-example.com ns',
+ 'dnsrecord_del',
+ dnszoneidnsname=u'test-example.com',
+ idnsname=u'ns',
+ del_all=True,
+ structured=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+ with self.fake_stdin('YeS\n'):
+ self.check_command('dnsrecord_del test-example.com ns',
+ 'dnsrecord_del',
+ dnszoneidnsname=u'test-example.com',
+ idnsname=u'ns',
+ del_all=True,
+ structured=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+ finally:
+ self.run_command('dnszone_del', idnsname=u'test-example.com')
+
+ def test_dnsrecord_del_one_by_one(self):
+ try:
+ self.run_command('dnszone_add', idnsname=u'test-example.com',
+ idnssoamname=u'ns.test-example.com',
+ admin_email=u'devnull@test-example.com', force=True)
+ except errors.NotFound:
+ raise nose.SkipTest('DNS is not configured')
+ try:
+ records = (u'1 1 E3B72BA346B90570EED94BE9334E34AA795CED23',
+ u'2 1 FD2693C1EFFC11A8D2BE57229212A04B45663791')
+ for record in records:
+ self.run_command('dnsrecord_add',
+ dnszoneidnsname=u'test-example.com', idnsname=u'ns',
+ sshfprecord=record)
+ with self.fake_stdin('no\nyes\nyes\n'):
+ self.check_command('dnsrecord_del test-example.com ns',
+ 'dnsrecord_del',
+ dnszoneidnsname=u'test-example.com',
+ idnsname=u'ns',
+ del_all=False,
+ sshfprecord=records,
+ structured=False,
+ raw=False,
+ all=False,
+ version=API_VERSION)
+ finally:
+ self.run_command('dnszone_del', idnsname=u'test-example.com')