summaryrefslogtreecommitdiffstats
path: root/stap-client-connect.c
diff options
context:
space:
mode:
authorDave Brolley <brolley@redhat.com>2010-03-03 15:40:05 -0500
committerDave Brolley <brolley@redhat.com>2010-03-03 15:44:28 -0500
commit97a342faf30bf8cf7aac8d20ef52b066570074ef (patch)
tree0340d224f5fbf0e22690eb76a3cceb194b159ef7 /stap-client-connect.c
parenta2f05a98d7ed062f293f40f88fc1237448438b15 (diff)
downloadsystemtap-steved-97a342faf30bf8cf7aac8d20ef52b066570074ef.tar.gz
systemtap-steved-97a342faf30bf8cf7aac8d20ef52b066570074ef.tar.xz
systemtap-steved-97a342faf30bf8cf7aac8d20ef52b066570074ef.zip
PR 10331: Improved certificate management -- client side.
stap-client-connect.c: use SSL_BadCertHoook to provide an opportunity for the user to trust and/or import the server's certificate. stap-client: Reorganized so that newly trusted certificates can be used. Also does the actual prompting.
Diffstat (limited to 'stap-client-connect.c')
-rw-r--r--stap-client-connect.c271
1 files changed, 169 insertions, 102 deletions
diff --git a/stap-client-connect.c b/stap-client-connect.c
index 8d48225d..4a68b437 100644
--- a/stap-client-connect.c
+++ b/stap-client-connect.c
@@ -2,7 +2,7 @@
SSL client program that sets up a connection to a SSL server, transmits
the given input file and then writes the reply to the given output file.
- Copyright (C) 2008, 2009 Red Hat Inc.
+ Copyright (C) 2008-2010 Red Hat Inc.
This file is part of systemtap, and is free software. You can
redistribute it and/or modify it under the terms of the GNU General Public
@@ -20,11 +20,13 @@
*/
#include <stdio.h>
+#include <unistd.h>
#include <ssl.h>
#include <nspr.h>
#include <plgetopt.h>
#include <nss.h>
+#include <pk11pub.h>
#include <prerror.h>
#include <secerr.h>
#include <sslerr.h>
@@ -36,6 +38,12 @@ static char *hostName = NULL;
static unsigned short port = 0;
static const char *infileName = NULL;
static const char *outfileName = NULL;
+static char *certDir = NULL;
+static const char *trustNewServer_p = NULL;
+
+/* Exit error codes */
+#define GENERAL_ERROR 1
+#define CA_CERT_INVALID_ERROR 2
static void
Usage(const char *progName)
@@ -46,22 +54,94 @@ Usage(const char *progName)
}
static void
-errWarn(char *function)
+exitErr(const char* errorStr, int rc)
{
- fprintf(stderr, "Error in function %s: ", function);
+ fprintf (stderr, "%s: ", errorStr);
nssError();
+ /* Exit gracefully. */
+ /* ignoring return value of NSS_Shutdown. */
+ (void) NSS_Shutdown();
+ PR_Cleanup();
+ exit(rc);
}
+/* Add the server's certificate to our database of trusted servers. */
+static SECStatus
+trustNewServer (CERTCertificate *serverCert)
+{
+ SECStatus secStatus;
+ CERTCertTrust *trust = NULL;
+ PK11SlotInfo *slot;
-static void
-exitErr(char *function)
+ /* Import the certificate. */
+ slot = PK11_GetInternalKeySlot();;
+ secStatus = PK11_ImportCert(slot, serverCert, CK_INVALID_HANDLE, "stap-server", PR_FALSE);
+ if (secStatus != SECSuccess)
+ goto done;
+
+ /* Make it a trusted peer. */
+ trust = (CERTCertTrust *)PORT_ZAlloc(sizeof(CERTCertTrust));
+ if (! trust)
+ {
+ secStatus = SECFailure;
+ goto done;
+ }
+
+ secStatus = CERT_DecodeTrustString(trust, "P,P,P");
+ if (secStatus != SECSuccess)
+ goto done;
+
+ secStatus = CERT_ChangeCertTrust(CERT_GetDefaultCertDB(), serverCert, trust);
+ if (secStatus != SECSuccess)
+ goto done;
+
+done:
+ if (trust)
+ PORT_Free(trust);
+ return secStatus;
+}
+
+/* Called when the server certificate verification fails. This gives us
+ the chance to trust the server anyway and add the certificate to the
+ local database. */
+static SECStatus
+badCertHandler(void *arg, PRFileDesc *sslSocket)
{
- errWarn(function);
- /* Exit gracefully. */
- /* ignoring return value of NSS_Shutdown as code exits with 1*/
- (void) NSS_Shutdown();
- PR_Cleanup();
- exit(1);
+ SECStatus secStatus;
+ PRErrorCode errorNumber;
+ CERTCertificate *serverCert;
+
+ /* By default, don't trust the certificate. */
+ secStatus = SECFailure;
+
+ errorNumber = PR_GetError ();
+ if (errorNumber == SEC_ERROR_CA_CERT_INVALID)
+ {
+ /* The server's certificate is not trusted. Should we trust it? */
+ if (trustNewServer_p == NULL)
+ {
+ /* Don't trust the cert, but print information about it. */
+ SEC_PrintCertificateAndTrust(serverCert, "Certificate",
+ serverCert->trust);
+ return SECFailure; /* Do not trust this server */
+ }
+
+ /* Trust it for this session only? */
+ if (strcmp (trustNewServer_p, "session") == 0)
+ return SECSuccess;
+
+ /* Trust it permanently? */
+ if (strcmp (trustNewServer_p, "permanent") == 0)
+ {
+ /* The user wants to trust this server. Get the server's certificate so
+ and add it to our database. */
+ serverCert = SSL_PeerCertificate (sslSocket);
+ if (serverCert != NULL)
+ secStatus = trustNewServer (serverCert);
+ }
+ }
+
+ return secStatus;
}
static PRFileDesc *
@@ -75,9 +155,7 @@ setupSSLSocket(void)
tcpSocket = PR_NewTCPSocket();
if (tcpSocket == NULL)
- {
- errWarn("PR_NewTCPSocket");
- }
+ goto loser;
/* Make the socket blocking. */
socketOption.option = PR_SockOpt_Nonblocking;
@@ -85,33 +163,21 @@ setupSSLSocket(void)
prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
if (prStatus != PR_SUCCESS)
- {
- errWarn("PR_SetSocketOption");
- goto loser;
- }
+ goto loser;
/* Import the socket into the SSL layer. */
sslSocket = SSL_ImportFD(NULL, tcpSocket);
if (!sslSocket)
- {
- errWarn("SSL_ImportFD");
- goto loser;
- }
+ goto loser;
/* Set configuration options. */
secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_OptionSet:SSL_SECURITY");
- goto loser;
- }
+ goto loser;
secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT");
- goto loser;
- }
+ goto loser;
/* Set SSL callback routines. */
#if 0 /* no client authentication */
@@ -119,43 +185,31 @@ setupSSLSocket(void)
(SSLGetClientAuthData)myGetClientAuthData,
(void *)certNickname);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_GetClientAuthDataHook");
- goto loser;
- }
+ goto loser;
#endif
#if 0 /* Use the default */
secStatus = SSL_AuthCertificateHook(sslSocket,
(SSLAuthCertificate)myAuthCertificate,
(void *)CERT_GetDefaultCertDB());
if (secStatus != SECSuccess)
- {
- errWarn("SSL_AuthCertificateHook");
- goto loser;
- }
+ goto loser;
#endif
-#if 0 /* Use the default */
- secStatus = SSL_BadCertHook(sslSocket,
- (SSLBadCertHandler)myBadCertHandler, NULL);
+
+ secStatus = SSL_BadCertHook(sslSocket, (SSLBadCertHandler)badCertHandler, NULL);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_BadCertHook");
- goto loser;
- }
-#endif
+ goto loser;
+
#if 0 /* No handshake callback */
secStatus = SSL_HandshakeCallback(sslSocket, myHandshakeCallback, NULL);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_HandshakeCallback");
- goto loser;
- }
+ goto loser;
#endif
return sslSocket;
loser:
- PR_Close(tcpSocket);
+ if (tcpSocket)
+ PR_Close(tcpSocket);
return NULL;
}
@@ -171,6 +225,7 @@ handle_connection(PRFileDesc *sslSocket)
PRFileInfo info;
PRFileDesc *local_file_fd;
PRStatus prStatus;
+ SECStatus secStatus = SECSuccess;
/* read and send the data. */
/* Try to open the local file named.
@@ -195,10 +250,7 @@ handle_connection(PRFileDesc *sslSocket)
/* Send the file size first, so the server knows when it has the entire file. */
numBytes = PR_Write(sslSocket, & info.size, sizeof (info.size));
if (numBytes < 0)
- {
- errWarn("PR_Write");
- return SECFailure;
- }
+ return SECFailure;
/* Transmit the local file across the socket. */
numBytes = PR_TransmitFile(sslSocket, local_file_fd,
@@ -206,10 +258,7 @@ handle_connection(PRFileDesc *sslSocket)
PR_TRANSMITFILE_KEEP_OPEN,
PR_INTERVAL_NO_TIMEOUT);
if (numBytes < 0)
- {
- errWarn("PR_TransmitFile");
- return SECFailure;
- }
+ return SECFailure;
#if DEBUG
/* Transmitted bytes successfully. */
@@ -222,7 +271,7 @@ handle_connection(PRFileDesc *sslSocket)
/* read until EOF */
readBuffer = PORT_Alloc(READ_BUFFER_SIZE);
if (! readBuffer)
- exitErr("PORT_Alloc");
+ exitErr("Out of memory", GENERAL_ERROR);
local_file_fd = PR_Open(outfileName, PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE,
PR_IRUSR | PR_IWUSR | PR_IRGRP | PR_IWGRP | PR_IROTH);
@@ -239,7 +288,7 @@ handle_connection(PRFileDesc *sslSocket)
if (numBytes < 0)
{
- errWarn("PR_Read");
+ secStatus = SECFailure;
break;
}
#if DEBUG
@@ -250,6 +299,7 @@ handle_connection(PRFileDesc *sslSocket)
if (numBytes < 0)
{
fprintf (stderr, "could not write to %s\n", outfileName);
+ secStatus = SECFailure;
break;
}
#if DEBUG
@@ -268,7 +318,7 @@ handle_connection(PRFileDesc *sslSocket)
fprintf(stderr, "***** Connection read %d bytes total.\n", countRead);
#endif
- return SECSuccess;
+ return secStatus;
}
/* make the connection.
@@ -290,32 +340,23 @@ do_connect(PRNetAddr *addr)
/* Set up SSL secure socket. */
sslSocket = setupSSLSocket();
if (sslSocket == NULL)
- {
- errWarn("setupSSLSocket");
- return SECFailure;
- }
+ return SECFailure;
#if 0 /* no client authentication */
secStatus = SSL_SetPKCS11PinArg(sslSocket, password);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_SetPKCS11PinArg");
- goto done;
- }
+ goto done;
#endif
secStatus = SSL_SetURL(sslSocket, hostName);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_SetURL");
- goto done;
- }
-#if 0 /* Already done? */
+ goto done;
+
+#if 0 /* Already done */
/* Prepare and setup network connection. */
prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry);
if (prStatus != PR_SUCCESS)
{
- errWarn("PR_GetHostByName");
secStatus = SECFailure;
goto done;
}
@@ -323,7 +364,6 @@ do_connect(PRNetAddr *addr)
hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr);
if (hostenum == -1)
{
- errWarn("PR_EnumerateHostEnt");
secStatus = SECFailure;
goto done;
}
@@ -331,7 +371,6 @@ do_connect(PRNetAddr *addr)
prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT);
if (prStatus != PR_SUCCESS)
{
- errWarn("PR_Connect");
secStatus = SECFailure;
goto done;
}
@@ -339,32 +378,20 @@ do_connect(PRNetAddr *addr)
/* Established SSL connection, ready to send data. */
secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_ResetHandshake");
- goto done;
- }
+ goto done;
/* This is normally done automatically on the first I/O operation,
but doing it here catches any authentication problems early. */
secStatus = SSL_ForceHandshake(sslSocket);
if (secStatus != SECSuccess)
- {
- errWarn("SSL_ForceHandshake");
- goto done;
- }
+ goto done;
secStatus = handle_connection(sslSocket);
if (secStatus != SECSuccess)
- {
- errWarn("handle_connection");
- goto done;
- }
+ goto done;
done:
prStatus = PR_Close(sslSocket);
- if (prStatus != PR_SUCCESS)
- errWarn("PR_Close");
-
return secStatus;
}
@@ -376,20 +403,55 @@ client_main(unsigned short port)
PRInt32 rv;
PRNetAddr addr;
PRHostEnt hostEntry;
+ PRErrorCode errorNumber;
char buffer[PR_NETDB_BUF_SIZE];
+ int attempt;
+ int errCode = GENERAL_ERROR;
/* Setup network connection. */
prStatus = PR_GetHostByName(hostName, buffer, sizeof (buffer), &hostEntry);
if (prStatus != PR_SUCCESS)
- exitErr("PR_GetHostByName");
+ exitErr("Unable to resolve server host name", GENERAL_ERROR);
rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr);
if (rv < 0)
- exitErr("PR_EnumerateHostEnt");
+ exitErr("Unable to resolve server host address", GENERAL_ERROR);
- secStatus = do_connect (&addr);
- if (secStatus != SECSuccess)
- exitErr("do_connect");
+ /* Some errors (see below) represent a situation in which trying again
+ should succeed. However, don't try forever. */
+ for (attempt = 0; attempt < 5; ++attempt)
+ {
+ secStatus = do_connect (&addr);
+ if (secStatus == SECSuccess)
+ return;
+
+ errorNumber = PR_GetError ();
+ switch (errorNumber)
+ {
+ case PR_CONNECT_RESET_ERROR:
+ /* Server was not ready. */
+ sleep (1);
+ break; /* Try again */
+ case SEC_ERROR_EXPIRED_CERTIFICATE:
+ /* The server's certificate has expired. It should
+ generate a new certificate. Give the server a chance to recover
+ and try again. */
+ sleep (2);
+ break; /* Try again */
+ case SEC_ERROR_CA_CERT_INVALID:
+ /* The server's certificate is not trusted. The exit code must
+ reflect this. */
+ errCode = CA_CERT_INVALID_ERROR;
+ goto failed; /* break switch and loop */
+ default:
+ /* This error is fatal. */
+ goto failed; /* break switch and loop */
+ }
+ }
+
+ failed:
+ /* Unrecoverable error */
+ exitErr("Unable to connect to server", errCode);
}
#if 0 /* No client authorization */
@@ -408,7 +470,6 @@ myPasswd(PK11SlotInfo *info, PRBool retry, void *arg)
int
main(int argc, char **argv)
{
- char * certDir = NULL;
char * progName = NULL;
SECStatus secStatus;
PLOptState *optstate;
@@ -420,7 +481,7 @@ main(int argc, char **argv)
progName = PL_strdup(argv[0]);
hostName = NULL;
- optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:");
+ optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:t:");
while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK)
{
switch(optstate->option)
@@ -430,6 +491,7 @@ main(int argc, char **argv)
case 'i' : infileName = PL_strdup(optstate->value); break;
case 'o' : outfileName = PL_strdup(optstate->value); break;
case 'p' : port = PORT_Atoi(optstate->value); break;
+ case 't' : trustNewServer_p = PL_strdup(optstate->value); break;
case '?' :
default : Usage(progName);
}
@@ -444,9 +506,14 @@ main(int argc, char **argv)
#endif
/* Initialize the NSS libraries. */
- secStatus = NSS_Init(certDir);
+ secStatus = NSS_InitReadWrite(certDir);
if (secStatus != SECSuccess)
- exitErr("NSS_Init");
+ {
+ /* Try it again, readonly. */
+ secStatus = NSS_Init(certDir);
+ if (secStatus != SECSuccess)
+ exitErr("Error initializing NSS", GENERAL_ERROR);
+ }
/* All cipher suites except RSA_NULL_MD5 are enabled by Domestic Policy. */
NSS_SetDomesticPolicy();