summaryrefslogtreecommitdiffstats
path: root/stap-client-connect.c
diff options
context:
space:
mode:
authorDave Brolley <brolley@redhat.com>2008-12-24 13:18:50 -0500
committerDave Brolley <brolley@redhat.com>2008-12-24 13:18:50 -0500
commit1cecb3c506475a0e0b0ee4180a91e1a9433d346b (patch)
treefc093dc7c74968a86a20ddbe2d9e11564e070339 /stap-client-connect.c
parente5976ba0af9b828dcc76b3937b5a98fe9c0f6cb8 (diff)
downloadsystemtap-steved-1cecb3c506475a0e0b0ee4180a91e1a9433d346b.tar.gz
systemtap-steved-1cecb3c506475a0e0b0ee4180a91e1a9433d346b.tar.xz
systemtap-steved-1cecb3c506475a0e0b0ee4180a91e1a9433d346b.zip
Systemtap compile server phase 2 (ssl) -- first cut.
Diffstat (limited to 'stap-client-connect.c')
-rw-r--r--stap-client-connect.c449
1 files changed, 449 insertions, 0 deletions
diff --git a/stap-client-connect.c b/stap-client-connect.c
new file mode 100644
index 00000000..29a8e18d
--- /dev/null
+++ b/stap-client-connect.c
@@ -0,0 +1,449 @@
+/*
+ SSL client program that sets up a connection to a SSL server, transmits
+ the given input file and then writes the reply to the given output file.
+
+ Copyright (C) 2008 Red Hat Inc.
+
+ This file is part of systemtap, and is free software. You can
+ redistribute it and/or modify it under the terms of the GNU General Public
+ License as published by the Free Software Foundation; either version 2 of the
+ License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program; if not, write to the Free Software
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+*/
+
+#include <stdio.h>
+
+#include <ssl.h>
+#include <nspr.h>
+#include <plgetopt.h>
+#include <nss.h>
+
+#define READ_BUFFER_SIZE (60 * 1024)
+
+static char *hostName = NULL;
+static unsigned short port = 0;
+static const char *infileName = NULL;
+static const char *outfileName = NULL;
+
+static void
+Usage(const char *progName)
+{
+ fprintf(stderr, "Usage: %s -h hostname -p port -d dbdir -i infile -o outfile\n",
+ progName);
+ exit(1);
+}
+
+static void
+errWarn(char *function)
+{
+ PRErrorCode errorNumber = PR_GetError();
+
+ printf("Error in function %s: %d\n\n", function, errorNumber);
+}
+
+static void
+exitErr(char *function)
+{
+ errWarn(function);
+ /* Exit gracefully. */
+ /* ignoring return value of NSS_Shutdown as code exits with 1*/
+ (void) NSS_Shutdown();
+ PR_Cleanup();
+ exit(1);
+}
+
+static PRFileDesc *
+setupSSLSocket(void)
+{
+ PRFileDesc *tcpSocket;
+ PRFileDesc *sslSocket;
+ PRSocketOptionData socketOption;
+ PRStatus prStatus;
+ SECStatus secStatus;
+
+ tcpSocket = PR_NewTCPSocket();
+ if (tcpSocket == NULL)
+ {
+ errWarn("PR_NewTCPSocket");
+ }
+
+ /* Make the socket blocking. */
+ socketOption.option = PR_SockOpt_Nonblocking;
+ socketOption.value.non_blocking = PR_FALSE;
+
+ prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
+ if (prStatus != PR_SUCCESS)
+ {
+ errWarn("PR_SetSocketOption");
+ goto loser;
+ }
+
+ /* Import the socket into the SSL layer. */
+ sslSocket = SSL_ImportFD(NULL, tcpSocket);
+ if (!sslSocket)
+ {
+ errWarn("SSL_ImportFD");
+ goto loser;
+ }
+
+ /* Set configuration options. */
+ secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_OptionSet:SSL_SECURITY");
+ goto loser;
+ }
+
+ secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT");
+ goto loser;
+ }
+
+ /* Set SSL callback routines. */
+#if 0 /* no client authentication */
+ secStatus = SSL_GetClientAuthDataHook(sslSocket,
+ (SSLGetClientAuthData)myGetClientAuthData,
+ (void *)certNickname);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_GetClientAuthDataHook");
+ goto loser;
+ }
+#endif
+#if 0 /* Use the default */
+ secStatus = SSL_AuthCertificateHook(sslSocket,
+ (SSLAuthCertificate)myAuthCertificate,
+ (void *)CERT_GetDefaultCertDB());
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_AuthCertificateHook");
+ goto loser;
+ }
+#endif
+#if 0 /* Use the default */
+ secStatus = SSL_BadCertHook(sslSocket,
+ (SSLBadCertHandler)myBadCertHandler, NULL);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_BadCertHook");
+ goto loser;
+ }
+#endif
+#if 0 /* No handshake callback */
+ secStatus = SSL_HandshakeCallback(sslSocket, myHandshakeCallback, NULL);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_HandshakeCallback");
+ goto loser;
+ }
+#endif
+
+ return sslSocket;
+
+ loser:
+ PR_Close(tcpSocket);
+ return NULL;
+}
+
+
+static SECStatus
+handle_connection(PRFileDesc *sslSocket)
+{
+#if DEBUG
+ int countRead = 0;
+#endif
+ PRInt32 numBytes;
+ char *readBuffer;
+ PRFileInfo info;
+ PRFileDesc *local_file_fd;
+ PRStatus prStatus;
+
+ /* read and send the data. */
+ /* Try to open the local file named.
+ * If successful, then write it to the server
+ */
+ prStatus = PR_GetFileInfo(infileName, &info);
+ if (prStatus != PR_SUCCESS ||
+ info.type != PR_FILE_FILE ||
+ info.size < 0)
+ {
+ fprintf (stderr, "could not find input file %s\n", infileName);
+ return SECFailure;
+ }
+
+ local_file_fd = PR_Open(infileName, PR_RDONLY, 0);
+ if (local_file_fd == NULL)
+ {
+ fprintf (stderr, "could not open input file %s\n", infileName);
+ return SECFailure;
+ }
+
+ /* Send the file size first, so the server knows when it has the entire file. */
+ numBytes = PR_Write(sslSocket, & info.size, sizeof (info.size));
+ /* Error in transmission? */
+ if (numBytes < 0)
+ {
+ errWarn("PR_TransmitFile");
+ return SECFailure;
+ }
+
+ /* Transmit the local file across the socket. */
+ numBytes = PR_TransmitFile(sslSocket, local_file_fd,
+ NULL, 0,
+ PR_TRANSMITFILE_KEEP_OPEN,
+ PR_INTERVAL_NO_TIMEOUT);
+ /* Error in transmission? */
+ if (numBytes < 0)
+ {
+ errWarn("PR_TransmitFile");
+ return SECFailure;
+ }
+
+#if DEBUG
+ /* Transmitted bytes successfully. */
+ fprintf(stderr, "PR_TransmitFile wrote %d bytes from %s\n",
+ numBytes, "stdin");
+#endif
+
+ PR_Close(local_file_fd);
+
+ /* read until EOF */
+ readBuffer = PORT_Alloc(READ_BUFFER_SIZE);
+ if (! readBuffer)
+ exitErr("PORT_Alloc");
+
+ local_file_fd = PR_Open(outfileName, PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE,
+ PR_IRUSR | PR_IWUSR | PR_IRGRP | PR_IWGRP | PR_IROTH);
+ if (local_file_fd == NULL)
+ {
+ fprintf (stderr, "could not open output file %s\n", outfileName);
+ return SECFailure;
+ }
+ while (PR_TRUE)
+ {
+ numBytes = PR_Read(sslSocket, readBuffer, READ_BUFFER_SIZE);
+ if (numBytes == 0)
+ break; /* EOF */
+
+ if (numBytes < 0)
+ {
+ errWarn("PR_Read");
+ break;
+ }
+#if DEBUG
+ countRead += numBytes;
+#endif
+ /* Write to output file */
+ numBytes = PR_Write(local_file_fd, readBuffer, numBytes);
+ if (numBytes < 0)
+ {
+ fprintf (stderr, "could not write to %s\n", outfileName);
+#if DEBUG
+ fprintf(stderr, "***** Connection read %d bytes (%d total).\n",
+ numBytes, countRead );
+ readBuffer[numBytes] = '\0';
+ fprintf(stderr, "************\n%s\n************\n", readBuffer);
+#endif
+ }
+ }
+
+ PR_Free(readBuffer);
+ PR_Close(local_file_fd);
+
+ /* Caller closes the socket. */
+#if DEBUG
+ fprintf(stderr, "***** Connection read %d bytes total.\n", countRead);
+#endif
+
+ return SECSuccess;
+}
+
+/* make the connection.
+*/
+static SECStatus
+do_connect(PRNetAddr *addr)
+{
+ PRFileDesc *sslSocket;
+ PRHostEnt hostEntry;
+ char buffer[PR_NETDB_BUF_SIZE];
+ PRStatus prStatus;
+ PRIntn hostenum;
+ SECStatus secStatus;
+
+ /* Set up SSL secure socket. */
+ sslSocket = setupSSLSocket();
+ if (sslSocket == NULL)
+ {
+ errWarn("setupSSLSocket");
+ return SECFailure;
+ }
+
+#if 0 /* no client authentication */
+ secStatus = SSL_SetPKCS11PinArg(sslSocket, password);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_SetPKCS11PinArg");
+ return secStatus;
+ }
+#endif
+
+ secStatus = SSL_SetURL(sslSocket, hostName);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_SetURL");
+ return secStatus;
+ }
+
+ /* Prepare and setup network connection. */
+ prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry);
+ if (prStatus != PR_SUCCESS)
+ {
+ errWarn("PR_GetHostByName");
+ return SECFailure;
+ }
+
+ hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr);
+ if (hostenum == -1)
+ {
+ errWarn("PR_EnumerateHostEnt");
+ return SECFailure;
+ }
+
+ prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT);
+ if (prStatus != PR_SUCCESS)
+ {
+ errWarn("PR_Connect");
+ return SECFailure;
+ }
+
+ /* Established SSL connection, ready to send data. */
+#if 0 /* Not necessary? */
+ secStatus = SSL_ForceHandshake(sslSocket);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_ForceHandshake");
+ return secStatus;
+ }
+#endif
+
+ secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("SSL_ResetHandshake");
+ prStatus = PR_Close(sslSocket);
+ if (prStatus != PR_SUCCESS)
+ errWarn("PR_Close");
+ return secStatus;
+ }
+
+ secStatus = handle_connection(sslSocket);
+ if (secStatus != SECSuccess)
+ {
+ errWarn("handle_connection");
+ return secStatus;
+ }
+
+ PR_Close(sslSocket);
+ return SECSuccess;
+}
+
+static void
+client_main(unsigned short port, const char *hostName)
+{
+ SECStatus secStatus;
+ PRStatus prStatus;
+ PRInt32 rv;
+ PRNetAddr addr;
+ PRHostEnt hostEntry;
+ char buffer[PR_NETDB_BUF_SIZE];
+
+ /* Setup network connection. */
+ prStatus = PR_GetHostByName(hostName, buffer, sizeof (buffer), &hostEntry);
+ if (prStatus != PR_SUCCESS)
+ exitErr("PR_GetHostByName");
+
+ rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr);
+ if (rv < 0)
+ exitErr("PR_EnumerateHostEnt");
+
+ secStatus = do_connect (&addr);
+ if (secStatus != SECSuccess)
+ exitErr("do_connect");
+}
+
+#if 0 /* No client authorization */
+static char *
+myPasswd(PK11SlotInfo *info, PRBool retry, void *arg)
+{
+ char * passwd = NULL;
+
+ if ( (!retry) && arg )
+ passwd = PORT_Strdup((char *)arg);
+
+ return passwd;
+}
+#endif
+
+int
+main(int argc, char **argv)
+{
+ char * certDir = NULL;
+ char * progName = NULL;
+ SECStatus secStatus;
+ PLOptState *optstate;
+ PLOptStatus status;
+
+ /* Call the NSPR initialization routines */
+ PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);
+
+ progName = PL_strdup(argv[0]);
+
+ hostName = NULL;
+ optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:");
+ while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK)
+ {
+ switch(optstate->option)
+ {
+ case 'd' : certDir = PL_strdup(optstate->value); break;
+ case 'h' : hostName = PL_strdup(optstate->value); break;
+ case 'i' : infileName = PL_strdup(optstate->value); break;
+ case 'o' : outfileName = PL_strdup(optstate->value); break;
+ case 'p' : port = PORT_Atoi(optstate->value); break;
+ case '?' :
+ default : Usage(progName);
+ }
+ }
+
+ if (port == 0 || hostName == NULL || infileName == NULL || outfileName == NULL || certDir == NULL)
+ Usage(progName);
+
+#if 0 /* no client authentication */
+ /* Set our password function callback. */
+ PK11_SetPasswordFunc(myPasswd);
+#endif
+
+ /* Initialize the NSS libraries. */
+ secStatus = NSS_Init(certDir);
+ if (secStatus != SECSuccess)
+ exitErr("NSS_Init");
+
+ /* All cipher suites except RSA_NULL_MD5 are enabled by Domestic Policy. */
+ NSS_SetDomesticPolicy();
+
+ client_main(port, hostName);
+
+ NSS_Shutdown();
+ PR_Cleanup();
+
+ return 0;
+}