diff options
Diffstat (limited to 'ipatests/test_integration/transport.py')
-rw-r--r-- | ipatests/test_integration/transport.py | 303 |
1 files changed, 303 insertions, 0 deletions
diff --git a/ipatests/test_integration/transport.py b/ipatests/test_integration/transport.py new file mode 100644 index 00000000..52b689a1 --- /dev/null +++ b/ipatests/test_integration/transport.py @@ -0,0 +1,303 @@ +# Authors: +# Petr Viktorin <pviktori@redhat.com> +# +# Copyright (C) 2013 Red Hat +# see file 'COPYING' for use and warranty information +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +"""Objects for communicating with remote hosts""" + +import os +import socket +import threading +import subprocess +from contextlib import contextmanager +import errno + +from ipapython.ipa_log_manager import log_mgr + +try: + import paramiko + have_paramiko = True +except ImportError: + have_paramiko = False + + +class Transport(object): + """Mechanism for communicating with remote hosts + + The Transport can manipulate files on a remote host, and open a Command. + + The base class defines an interface that specific subclasses implement. + """ + def __init__(self, host): + self.host = host + self.logger_name = '%s.%s' % (host.logger_name, type(self).__name__) + self.log = log_mgr.get_logger(self.logger_name) + self._command_index = 0 + + def get_file_contents(self, filename): + """Read the named remote file and return the contents as a string""" + raise NotImplementedError('Transport.get_file_contents') + + def put_file_contents(self, filename, contents): + """Write the given string to the named remote file""" + raise NotImplementedError('Transport.put_file_contents') + + def file_exists(self, filename): + """Return true if the named remote file exists""" + raise NotImplementedError('Transport.file_exists') + + def mkdir(self, path): + """Make the named directory""" + raise NotImplementedError('Transport.mkdir') + + def start_shell(self, argv, log_stdout=True): + """Start a Shell + + :param argv: The command this shell is intended to run (used for + logging only) + :param log_stdout: If false, the stdout will not be logged (useful when + binary output is expected) + + Given a `shell` from this method, the caller can then use + ``shell.stdin.write()`` to input any command(s), call ``shell.wait()`` + to let the command run, and then inspect ``returncode``, + ``stdout_text`` or ``stderr_text``. + """ + raise NotImplementedError('Transport.start_shell') + + def mkdir_recursive(self, path): + """`mkdir -p` on the remote host""" + if not path or path == '/': + raise ValueError('Invalid path') + if not self.file_exists(path or '/'): + self.mkdir_recursive(os.path.dirname(path)) + self.mkdir(path) + + def get_file(self, remotepath, localpath): + """Copy a file from the remote host to a local file""" + contents = self.get_file_contents(remotepath) + with open(localpath, 'wb') as local_file: + local_file.write(contents) + + def put_file(self, localpath, remotepath): + """Copy a local file to the remote host""" + with open(localpath, 'rb') as local_file: + contents = local_file.read() + self.put_file_contents(remotepath, contents) + + def get_next_command_logger_name(self): + self._command_index += 1 + return '%s.cmd%s' % (self.host.logger_name, self._command_index) + + +class Command(object): + """A Popen-style object representing a remote command + + Instances of this class should only be created via method of a concrete + Transport, such as start_shell. + + The standard error and output are handled by this class. They're not + available for file-like reading, and are logged by default. + To make sure reading doesn't stall after one buffer fills up, they are read + in parallel using threads. + + After calling wait(), ``stdout_text`` and ``stderr_text`` attributes will + be strings containing the output, and ``returncode`` will contain the + exit code. + """ + def __init__(self, argv, logger_name=None, log_stdout=True): + self.returncode = None + self.argv = argv + self._done = False + + if logger_name: + self.logger_name = logger_name + else: + self.logger_name = '%s.%s' % (self.__module__, type(self).__name__) + self.log = log_mgr.get_logger(self.logger_name) + + def wait(self, raiseonerr=True): + """Wait for the remote process to exit + + Raises an excption if the exit code is not 0, unless raiseonerr is + true. + """ + if self._done: + return self.returncode + + self._end_process() + + self._done = True + + if raiseonerr and self.returncode: + self.log.error('Exit code: %s', self.returncode) + raise subprocess.CalledProcessError(self.returncode, self.argv) + else: + self.log.debug('Exit code: %s', self.returncode) + return self.returncode + + def _end_process(self): + """Wait until the process exits and output is received, close channel + + Called from wait() + """ + raise NotImplementedError() + + +class ParamikoTransport(Transport): + """Transport that uses the Paramiko SSH2 library""" + def __init__(self, host): + super(ParamikoTransport, self).__init__(host) + sock = socket.create_connection((host.external_hostname, + host.ssh_port)) + self._transport = transport = paramiko.Transport(sock) + transport.connect(hostkey=host.host_key) + if host.root_ssh_key_filename: + self.log.debug('Authenticating with private RSA key') + filename = os.path.expanduser(host.root_ssh_key_filename) + key = paramiko.RSAKey.from_private_key_file(filename) + transport.auth_publickey(username='root', key=key) + elif host.root_password: + self.log.debug('Authenticating with password') + transport.auth_password(username='root', + password=host.root_password) + else: + self.log.critical('No SSH credentials configured') + raise RuntimeError('No SSH credentials configured') + + @contextmanager + def sftp_open(self, filename, mode='r'): + """Context manager that provides a file-like object over a SFTP channel + + This provides compatibility with older Paramiko versions. + (In Paramiko 1.10+, file objects from `sftp.open` are directly usable + as context managers). + """ + file = self.sftp.open(filename, mode) + try: + yield file + finally: + file.close() + + @property + def sftp(self): + """Paramiko SFTPClient connected to this host""" + try: + return self._sftp + except AttributeError: + transport = self._transport + self._sftp = paramiko.SFTPClient.from_transport(transport) + return self._sftp + + def get_file_contents(self, filename): + """Read the named remote file and return the contents as a string""" + self.log.debug('READ %s', filename) + with self.sftp_open(filename) as f: + return f.read() + + def put_file_contents(self, filename, contents): + """Write the given string to the named remote file""" + self.log.info('WRITE %s', filename) + with self.sftp_open(filename, 'w') as f: + f.write(contents) + + def file_exists(self, filename): + """Return true if the named remote file exists""" + self.log.debug('STAT %s', filename) + try: + self.sftp.stat(filename) + except IOError, e: + if e.errno == errno.ENOENT: + return False + else: + raise + return True + + def mkdir(self, path): + self.log.info('MKDIR %s', path) + self.sftp.mkdir(path) + + def start_shell(self, argv, log_stdout=True): + logger_name = self.get_next_command_logger_name() + ssh = self._transport.open_channel('session') + self.log.info('RUN %s', argv) + return SSHCommand(ssh, argv, logger_name=logger_name, + log_stdout=log_stdout) + + def get_file(self, remotepath, localpath): + self.log.debug('GET %s', remotepath) + self.sftp.get(remotepath, localpath) + + def put_file(self, localpath, remotepath): + self.log.info('PUT %s', remotepath) + self.sftp.put(localpath, remotepath) + + +class SSHCommand(Command): + """Command implementation for ParamikoTransport and OpenSSHTranspport""" + def __init__(self, ssh, argv, logger_name, log_stdout=True, + collect_output=True): + super(SSHCommand, self).__init__(argv, logger_name, + log_stdout=log_stdout) + self._stdout_lines = [] + self._stderr_lines = [] + self.running_threads = set() + + self._ssh = ssh + + self.log.debug('RUN %s', argv) + + self._ssh.invoke_shell() + stdin = self.stdin = self._ssh.makefile('wb') + stdout = self._ssh.makefile('rb') + stderr = self._ssh.makefile_stderr('rb') + + if collect_output: + self._start_pipe_thread(self._stdout_lines, stdout, 'out', + log_stdout) + self._start_pipe_thread(self._stderr_lines, stderr, 'err', True) + + def _end_process(self, raiseonerr=True): + self._ssh.shutdown_write() + + while self.running_threads: + self.running_threads.pop().join() + + self.stdout_text = ''.join(self._stdout_lines) + self.stderr_text = ''.join(self._stderr_lines) + self.returncode = self._ssh.recv_exit_status() + self._ssh.close() + + def _start_pipe_thread(self, result_list, stream, name, do_log=True): + """Start a thread that copies lines from ``stream`` to ``result_list`` + + If do_log is true, also logs the lines under ``name`` + + The thread is added to ``self.running_threads``. + """ + log = log_mgr.get_logger('%s.%s' % (self.logger_name, name)) + + def read_stream(): + for line in stream: + if do_log: + log.debug(line.rstrip('\n')) + result_list.append(line) + + thread = threading.Thread(target=read_stream) + self.running_threads.add(thread) + thread.start() + return thread |