From 45d45eabcd0b71791f81cb6c34c0af92a7da9732 Mon Sep 17 00:00:00 2001 From: Petr Viktorin Date: Tue, 17 Sep 2013 12:24:40 +0200 Subject: test_integration.host: Move transport-related functionality to a new module This will make it possible to use a different mechanism for cases like - Paramiko is not available - Hosts without SSH servers (e.g. Windows) Add BaseHost, Transport & Command base classes that define the interface and common functionality, and Host, ParamikoTransport & SSHCommand with specific details. The {get,put}_file_contents methods are left on Host for convenience; all other Transport methods must be now accessed through the transport. Part of the work for https://fedorahosted.org/freeipa/ticket/3890 --- ipatests/test_integration/host.py | 263 ++++++--------------------- ipatests/test_integration/tasks.py | 8 +- ipatests/test_integration/test_caless.py | 11 +- ipatests/test_integration/transport.py | 303 +++++++++++++++++++++++++++++++ 4 files changed, 371 insertions(+), 214 deletions(-) create mode 100644 ipatests/test_integration/transport.py (limited to 'ipatests') diff --git a/ipatests/test_integration/host.py b/ipatests/test_integration/host.py index 1614b5fd..38d9ae6f 100644 --- a/ipatests/test_integration/host.py +++ b/ipatests/test_integration/host.py @@ -21,121 +21,17 @@ import os import socket -import threading -import subprocess -from contextlib import contextmanager -import errno - -import paramiko from ipapython.ipaldap import IPAdmin from ipapython import ipautil from ipapython.ipa_log_manager import log_mgr +from ipatests.test_integration import transport -class RemoteCommand(object): - """A Popen-style object representing a remote command - - Unlike subprocess.Popen, this does not run the given command; instead - it only starts a shell. The command must be written to stdin manually. - - The standard error and output are handled by this class. They're not - available for file-like reading. They are logged by default. - To make sure reading doesn't stall after one buffer fills up, they are read - in parallel using threads. - - After calling wait(), stdout_text and stderr_text attributes will be - strings containing the output, and returncode will contain the - exit code. - - :param host: The Host on which the command is run - :param argv: The command that will be run (for logging only) - :param index: An identification number added to the logs - :param log_stdout: If false, stdout will not be logged - """ - def __init__(self, host, argv, index, log_stdout=True): - self.returncode = None - self.host = host - self.argv = argv - self._stdout_lines = [] - self._stderr_lines = [] - self.running_threads = set() - - self.logger_name = '%s.cmd%s' % (self.host.logger_name, index) - self.log = log_mgr.get_logger(self.logger_name) - - self.log.info('RUN %s', argv) - - self._ssh = host.transport.open_channel('session') - - self._ssh.invoke_shell() - stdin = self.stdin = self._ssh.makefile('wb') - stdout = self._ssh.makefile('rb') - stderr = self._ssh.makefile_stderr('rb') - - self._start_pipe_thread(self._stdout_lines, stdout, 'out', log_stdout) - self._start_pipe_thread(self._stderr_lines, stderr, 'err', True) - - self._done = False - - def wait(self, raiseonerr=True): - """Wait for the remote process to exit - - Raises an excption if the exit code is not 0. - """ - if self._done: - return self.returncode - - self._ssh.shutdown_write() - while self.running_threads: - self.running_threads.pop().join() - - self.stdout_text = ''.join(self._stdout_lines) - self.stderr_text = ''.join(self._stderr_lines) - self.returncode = self._ssh.recv_exit_status() - self._ssh.close() - - self._done = True - - if raiseonerr and self.returncode: - self.log.error('Exit code: %s', self.returncode) - raise subprocess.CalledProcessError(self.returncode, self.argv) - else: - self.log.debug('Exit code: %s', self.returncode) - return self.returncode - - def _start_pipe_thread(self, result_list, stream, name, do_log=True): - log = log_mgr.get_logger('%s.%s' % (self.logger_name, name)) - - def read_stream(): - for line in stream: - if do_log: - log.debug(line.rstrip('\n')) - result_list.append(line) - - thread = threading.Thread(target=read_stream) - self.running_threads.add(thread) - thread.start() - return thread - - -@contextmanager -def sftp_open(sftp, filename, mode='r'): - """Context manager that provides a file-like object over a SFTP channel - - This provides compatibility with older Paramiko versions. - (In Paramiko 1.10+, file objects from `sftp.open` are directly usable as - context managers). - """ - file = sftp.open(filename, mode) - try: - yield file - finally: - file.close() - - -class Host(object): +class BaseHost(object): """Representation of a remote IPA host""" + transport_class = None + def __init__(self, domain, hostname, role, index, ip=None): self.domain = domain self.role = role @@ -175,8 +71,6 @@ class Host(object): self.env_sh_path = os.path.join(domain.config.test_dir, 'env.sh') - self._command_index = 0 - self.log_collectors = [] def __str__(self): @@ -224,13 +118,48 @@ class Host(object): return env + @property + def transport(self): + try: + return self._transport + except AttributeError: + cls = self.transport_class + if cls: + # transport_class is None in the base class and must be + # set in subclasses. + # Pylint reports that calling None will fail + self._transport = cls(self) # pylint: disable=E1102 + else: + raise NotImplementedError('transport class not available') + return self._transport + + def get_file_contents(self, filename): + """Shortcut for transport.get_file_contents""" + return self.transport.get_file_contents(filename) + + def put_file_contents(self, filename, contents): + """Shortcut for transport.put_file_contents""" + self.transport.put_file_contents(filename, contents) + + def ldap_connect(self): + """Return an LDAPClient authenticated to this host as directory manager + """ + ldap = IPAdmin(self.external_hostname) + ldap.do_simple_bind(self.config.dirman_dn, + self.config.dirman_password) + return ldap + + def collect_log(self, filename): + for collector in self.log_collectors: + collector(self, filename) + def run_command(self, argv, set_env=True, stdin_text=None, log_stdout=True, raiseonerr=True, cwd=None): """Run the given command on this host - Returns a RemoteCommand instance. The command will have already run - when this method returns, so its stdout_text, stderr_text, and + Returns a Shell instance. The command will have already run in the + shell when this method returns, so its stdout_text, stderr_text, and returncode attributes will be available. :param argv: Command to run, as either a Popen-style list, or a string @@ -242,30 +171,43 @@ class Host(object): (but will still be available as cmd.stdout_text) :param raiseonerr: If true, an exception will be raised if the command does not exit with return code 0 + :param cwd: The working directory for the command """ - assert self.transport + raise NotImplementedError() + + +class Host(BaseHost): + """A Unix host""" + transport_class = transport.ParamikoTransport - self._command_index += 1 - command = RemoteCommand(self, argv, index=self._command_index, - log_stdout=log_stdout) + def run_command(self, argv, set_env=True, stdin_text=None, + log_stdout=True, raiseonerr=True, + cwd=None): + # This will give us a Bash shell + command = self.transport.start_shell(argv, log_stdout=log_stdout) + # Set working directory if cwd is None: cwd = self.config.test_dir command.stdin.write('cd %s\n' % ipautil.shell_quote(cwd)) + # Set the environment if set_env: command.stdin.write('. %s\n' % ipautil.shell_quote(self.env_sh_path)) command.stdin.write('set -e\n') if isinstance(argv, basestring): + # Run a shell command given as a string command.stdin.write('(') command.stdin.write(argv) command.stdin.write(')') else: + # Run a command given as a popen-style list (no shell expansion) for arg in argv: command.stdin.write(ipautil.shell_quote(arg)) command.stdin.write(' ') + command.stdin.write(';exit\n') if stdin_text: command.stdin.write(stdin_text) @@ -273,92 +215,3 @@ class Host(object): command.wait(raiseonerr=raiseonerr) return command - - @property - def transport(self): - """Paramiko Transport connected to this host""" - try: - return self._transport - except AttributeError: - sock = socket.create_connection((self.external_hostname, - self.ssh_port)) - self._transport = transport = paramiko.Transport(sock) - transport.connect(hostkey=self.host_key) - if self.root_ssh_key_filename: - self.log.debug('Authenticating with private RSA key') - filename = os.path.expanduser(self.root_ssh_key_filename) - key = paramiko.RSAKey.from_private_key_file(filename) - transport.auth_publickey(username='root', key=key) - elif self.root_password: - self.log.debug('Authenticating with password') - transport.auth_password(username='root', - password=self.root_password) - else: - self.log.critical('No SSH credentials configured') - raise RuntimeError('No SSH credentials configured') - return transport - - @property - def sftp(self): - """Paramiko SFTPClient connected to this host""" - try: - return self._sftp - except AttributeError: - transport = self.transport - self._sftp = paramiko.SFTPClient.from_transport(transport) - return self._sftp - - def ldap_connect(self): - """Return an LDAPClient authenticated to this host as directory manager - """ - ldap = IPAdmin(self.external_hostname) - ldap.do_simple_bind(self.config.dirman_dn, - self.config.dirman_password) - return ldap - - def mkdir_recursive(self, path): - """`mkdir -p` on the remote host""" - try: - self.sftp.chdir(path or '/') - except IOError as e: - if not path or path == '/': - raise - self.mkdir_recursive(os.path.dirname(path)) - self.sftp.mkdir(path) - self.sftp.chdir(path) - - def get_file_contents(self, filename): - """Read the named remote file and return the contents as a string""" - self.log.debug('READ %s', filename) - with sftp_open(self.sftp, filename) as f: - return f.read() - - def put_file_contents(self, filename, contents): - """Write the given string to the named remote file""" - self.log.info('WRITE %s', filename) - with sftp_open(self.sftp, filename, 'w') as f: - f.write(contents) - - def file_exists(self, filename): - """Return true if the named remote file exists""" - self.log.debug('STAT %s', filename) - try: - self.sftp.stat(filename) - except IOError, e: - if e.errno == errno.ENOENT: - return False - else: - raise - return True - - def get_file(self, remotepath, localpath): - self.log.debug('GET %s', remotepath) - self.sftp.get(remotepath, localpath) - - def put_file(self, localpath, remotepath): - self.log.info('PUT %s', remotepath) - self.sftp.put(localpath, remotepath) - - def collect_log(self, filename): - for collector in self.log_collectors: - collector(self, filename) diff --git a/ipatests/test_integration/tasks.py b/ipatests/test_integration/tasks.py index 69a34a28..7ea0ce4f 100644 --- a/ipatests/test_integration/tasks.py +++ b/ipatests/test_integration/tasks.py @@ -40,7 +40,7 @@ log = log_mgr.get_logger(__name__) def prepare_host(host): env_filename = os.path.join(host.config.test_dir, 'env.sh') host.collect_log(env_filename) - host.mkdir_recursive(host.config.test_dir) + host.transport.mkdir_recursive(host.config.test_dir) host.put_file_contents(env_filename, env_to_script(host.to_env())) @@ -51,10 +51,10 @@ def apply_common_fixes(host): def backup_file(host, filename): - if host.file_exists(filename): + if host.transport.file_exists(filename): backupname = os.path.join(host.config.test_dir, 'file_backup', filename.lstrip('/')) - host.mkdir_recursive(os.path.dirname(backupname)) + host.transport.mkdir_recursive(os.path.dirname(backupname)) host.run_command(['cp', '-af', filename, backupname]) return True else: @@ -63,7 +63,7 @@ def backup_file(host, filename): ipautil.shell_quote(filename), ipautil.shell_quote(rmname))) contents = host.get_file_contents(rmname) - host.mkdir_recursive(os.path.dirname(rmname)) + host.transport.mkdir_recursive(os.path.dirname(rmname)) return False diff --git a/ipatests/test_integration/test_caless.py b/ipatests/test_integration/test_caless.py index 643c9c48..12faf296 100644 --- a/ipatests/test_integration/test_caless.py +++ b/ipatests/test_integration/test_caless.py @@ -103,7 +103,7 @@ class CALessBase(IntegrationTest): host.mkdir_recursive(cls.crl_path) for source in glob.glob(os.path.join(base, '*.crl')): dest = os.path.join(cls.crl_path, os.path.basename(source)) - host.put_file(source, dest) + host.transport.put_file(source, dest) @classmethod def uninstall(cls): @@ -174,8 +174,8 @@ class CALessBase(IntegrationTest): @classmethod def copy_cert(cls, host, filename): - host.put_file(os.path.join(cls.cert_dir, filename), - os.path.join(host.config.test_dir, filename)) + host.transport.put_file(os.path.join(cls.cert_dir, filename), + os.path.join(host.config.test_dir, filename)) @classmethod def uninstall_server(self, host=None): @@ -211,8 +211,9 @@ class CALessBase(IntegrationTest): if dirsrv_pkcs12_exists: files_to_copy.append(dirsrv_pkcs12) for filename in set(files_to_copy): - master.put_file(os.path.join(self.cert_dir, filename), - os.path.join(master.config.test_dir, filename)) + master.transport.put_file( + os.path.join(self.cert_dir, filename), + os.path.join(master.config.test_dir, filename)) self.collect_log(replica, '/var/log/ipareplica-install.log') self.collect_log(replica, '/var/log/ipaclient-install.log') diff --git a/ipatests/test_integration/transport.py b/ipatests/test_integration/transport.py new file mode 100644 index 00000000..52b689a1 --- /dev/null +++ b/ipatests/test_integration/transport.py @@ -0,0 +1,303 @@ +# Authors: +# Petr Viktorin +# +# Copyright (C) 2013 Red Hat +# see file 'COPYING' for use and warranty information +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Objects for communicating with remote hosts""" + +import os +import socket +import threading +import subprocess +from contextlib import contextmanager +import errno + +from ipapython.ipa_log_manager import log_mgr + +try: + import paramiko + have_paramiko = True +except ImportError: + have_paramiko = False + + +class Transport(object): + """Mechanism for communicating with remote hosts + + The Transport can manipulate files on a remote host, and open a Command. + + The base class defines an interface that specific subclasses implement. + """ + def __init__(self, host): + self.host = host + self.logger_name = '%s.%s' % (host.logger_name, type(self).__name__) + self.log = log_mgr.get_logger(self.logger_name) + self._command_index = 0 + + def get_file_contents(self, filename): + """Read the named remote file and return the contents as a string""" + raise NotImplementedError('Transport.get_file_contents') + + def put_file_contents(self, filename, contents): + """Write the given string to the named remote file""" + raise NotImplementedError('Transport.put_file_contents') + + def file_exists(self, filename): + """Return true if the named remote file exists""" + raise NotImplementedError('Transport.file_exists') + + def mkdir(self, path): + """Make the named directory""" + raise NotImplementedError('Transport.mkdir') + + def start_shell(self, argv, log_stdout=True): + """Start a Shell + + :param argv: The command this shell is intended to run (used for + logging only) + :param log_stdout: If false, the stdout will not be logged (useful when + binary output is expected) + + Given a `shell` from this method, the caller can then use + ``shell.stdin.write()`` to input any command(s), call ``shell.wait()`` + to let the command run, and then inspect ``returncode``, + ``stdout_text`` or ``stderr_text``. + """ + raise NotImplementedError('Transport.start_shell') + + def mkdir_recursive(self, path): + """`mkdir -p` on the remote host""" + if not path or path == '/': + raise ValueError('Invalid path') + if not self.file_exists(path or '/'): + self.mkdir_recursive(os.path.dirname(path)) + self.mkdir(path) + + def get_file(self, remotepath, localpath): + """Copy a file from the remote host to a local file""" + contents = self.get_file_contents(remotepath) + with open(localpath, 'wb') as local_file: + local_file.write(contents) + + def put_file(self, localpath, remotepath): + """Copy a local file to the remote host""" + with open(localpath, 'rb') as local_file: + contents = local_file.read() + self.put_file_contents(remotepath, contents) + + def get_next_command_logger_name(self): + self._command_index += 1 + return '%s.cmd%s' % (self.host.logger_name, self._command_index) + + +class Command(object): + """A Popen-style object representing a remote command + + Instances of this class should only be created via method of a concrete + Transport, such as start_shell. + + The standard error and output are handled by this class. They're not + available for file-like reading, and are logged by default. + To make sure reading doesn't stall after one buffer fills up, they are read + in parallel using threads. + + After calling wait(), ``stdout_text`` and ``stderr_text`` attributes will + be strings containing the output, and ``returncode`` will contain the + exit code. + """ + def __init__(self, argv, logger_name=None, log_stdout=True): + self.returncode = None + self.argv = argv + self._done = False + + if logger_name: + self.logger_name = logger_name + else: + self.logger_name = '%s.%s' % (self.__module__, type(self).__name__) + self.log = log_mgr.get_logger(self.logger_name) + + def wait(self, raiseonerr=True): + """Wait for the remote process to exit + + Raises an excption if the exit code is not 0, unless raiseonerr is + true. + """ + if self._done: + return self.returncode + + self._end_process() + + self._done = True + + if raiseonerr and self.returncode: + self.log.error('Exit code: %s', self.returncode) + raise subprocess.CalledProcessError(self.returncode, self.argv) + else: + self.log.debug('Exit code: %s', self.returncode) + return self.returncode + + def _end_process(self): + """Wait until the process exits and output is received, close channel + + Called from wait() + """ + raise NotImplementedError() + + +class ParamikoTransport(Transport): + """Transport that uses the Paramiko SSH2 library""" + def __init__(self, host): + super(ParamikoTransport, self).__init__(host) + sock = socket.create_connection((host.external_hostname, + host.ssh_port)) + self._transport = transport = paramiko.Transport(sock) + transport.connect(hostkey=host.host_key) + if host.root_ssh_key_filename: + self.log.debug('Authenticating with private RSA key') + filename = os.path.expanduser(host.root_ssh_key_filename) + key = paramiko.RSAKey.from_private_key_file(filename) + transport.auth_publickey(username='root', key=key) + elif host.root_password: + self.log.debug('Authenticating with password') + transport.auth_password(username='root', + password=host.root_password) + else: + self.log.critical('No SSH credentials configured') + raise RuntimeError('No SSH credentials configured') + + @contextmanager + def sftp_open(self, filename, mode='r'): + """Context manager that provides a file-like object over a SFTP channel + + This provides compatibility with older Paramiko versions. + (In Paramiko 1.10+, file objects from `sftp.open` are directly usable + as context managers). + """ + file = self.sftp.open(filename, mode) + try: + yield file + finally: + file.close() + + @property + def sftp(self): + """Paramiko SFTPClient connected to this host""" + try: + return self._sftp + except AttributeError: + transport = self._transport + self._sftp = paramiko.SFTPClient.from_transport(transport) + return self._sftp + + def get_file_contents(self, filename): + """Read the named remote file and return the contents as a string""" + self.log.debug('READ %s', filename) + with self.sftp_open(filename) as f: + return f.read() + + def put_file_contents(self, filename, contents): + """Write the given string to the named remote file""" + self.log.info('WRITE %s', filename) + with self.sftp_open(filename, 'w') as f: + f.write(contents) + + def file_exists(self, filename): + """Return true if the named remote file exists""" + self.log.debug('STAT %s', filename) + try: + self.sftp.stat(filename) + except IOError, e: + if e.errno == errno.ENOENT: + return False + else: + raise + return True + + def mkdir(self, path): + self.log.info('MKDIR %s', path) + self.sftp.mkdir(path) + + def start_shell(self, argv, log_stdout=True): + logger_name = self.get_next_command_logger_name() + ssh = self._transport.open_channel('session') + self.log.info('RUN %s', argv) + return SSHCommand(ssh, argv, logger_name=logger_name, + log_stdout=log_stdout) + + def get_file(self, remotepath, localpath): + self.log.debug('GET %s', remotepath) + self.sftp.get(remotepath, localpath) + + def put_file(self, localpath, remotepath): + self.log.info('PUT %s', remotepath) + self.sftp.put(localpath, remotepath) + + +class SSHCommand(Command): + """Command implementation for ParamikoTransport and OpenSSHTranspport""" + def __init__(self, ssh, argv, logger_name, log_stdout=True, + collect_output=True): + super(SSHCommand, self).__init__(argv, logger_name, + log_stdout=log_stdout) + self._stdout_lines = [] + self._stderr_lines = [] + self.running_threads = set() + + self._ssh = ssh + + self.log.debug('RUN %s', argv) + + self._ssh.invoke_shell() + stdin = self.stdin = self._ssh.makefile('wb') + stdout = self._ssh.makefile('rb') + stderr = self._ssh.makefile_stderr('rb') + + if collect_output: + self._start_pipe_thread(self._stdout_lines, stdout, 'out', + log_stdout) + self._start_pipe_thread(self._stderr_lines, stderr, 'err', True) + + def _end_process(self, raiseonerr=True): + self._ssh.shutdown_write() + + while self.running_threads: + self.running_threads.pop().join() + + self.stdout_text = ''.join(self._stdout_lines) + self.stderr_text = ''.join(self._stderr_lines) + self.returncode = self._ssh.recv_exit_status() + self._ssh.close() + + def _start_pipe_thread(self, result_list, stream, name, do_log=True): + """Start a thread that copies lines from ``stream`` to ``result_list`` + + If do_log is true, also logs the lines under ``name`` + + The thread is added to ``self.running_threads``. + """ + log = log_mgr.get_logger('%s.%s' % (self.logger_name, name)) + + def read_stream(): + for line in stream: + if do_log: + log.debug(line.rstrip('\n')) + result_list.append(line) + + thread = threading.Thread(target=read_stream) + self.running_threads.add(thread) + thread.start() + return thread -- cgit