From 97a342faf30bf8cf7aac8d20ef52b066570074ef Mon Sep 17 00:00:00 2001 From: Dave Brolley Date: Wed, 3 Mar 2010 15:40:05 -0500 Subject: PR 10331: Improved certificate management -- client side. stap-client-connect.c: use SSL_BadCertHoook to provide an opportunity for the user to trust and/or import the server's certificate. stap-client: Reorganized so that newly trusted certificates can be used. Also does the actual prompting. --- stap-client-connect.c | 271 +++++++++++++++++++++++++++++++------------------- 1 file changed, 169 insertions(+), 102 deletions(-) (limited to 'stap-client-connect.c') diff --git a/stap-client-connect.c b/stap-client-connect.c index 8d48225d..4a68b437 100644 --- a/stap-client-connect.c +++ b/stap-client-connect.c @@ -2,7 +2,7 @@ SSL client program that sets up a connection to a SSL server, transmits the given input file and then writes the reply to the given output file. - Copyright (C) 2008, 2009 Red Hat Inc. + Copyright (C) 2008-2010 Red Hat Inc. This file is part of systemtap, and is free software. You can redistribute it and/or modify it under the terms of the GNU General Public @@ -20,11 +20,13 @@ */ #include +#include #include #include #include #include +#include #include #include #include @@ -36,6 +38,12 @@ static char *hostName = NULL; static unsigned short port = 0; static const char *infileName = NULL; static const char *outfileName = NULL; +static char *certDir = NULL; +static const char *trustNewServer_p = NULL; + +/* Exit error codes */ +#define GENERAL_ERROR 1 +#define CA_CERT_INVALID_ERROR 2 static void Usage(const char *progName) @@ -46,22 +54,94 @@ Usage(const char *progName) } static void -errWarn(char *function) +exitErr(const char* errorStr, int rc) { - fprintf(stderr, "Error in function %s: ", function); + fprintf (stderr, "%s: ", errorStr); nssError(); + /* Exit gracefully. */ + /* ignoring return value of NSS_Shutdown. */ + (void) NSS_Shutdown(); + PR_Cleanup(); + exit(rc); } +/* Add the server's certificate to our database of trusted servers. */ +static SECStatus +trustNewServer (CERTCertificate *serverCert) +{ + SECStatus secStatus; + CERTCertTrust *trust = NULL; + PK11SlotInfo *slot; -static void -exitErr(char *function) + /* Import the certificate. */ + slot = PK11_GetInternalKeySlot();; + secStatus = PK11_ImportCert(slot, serverCert, CK_INVALID_HANDLE, "stap-server", PR_FALSE); + if (secStatus != SECSuccess) + goto done; + + /* Make it a trusted peer. */ + trust = (CERTCertTrust *)PORT_ZAlloc(sizeof(CERTCertTrust)); + if (! trust) + { + secStatus = SECFailure; + goto done; + } + + secStatus = CERT_DecodeTrustString(trust, "P,P,P"); + if (secStatus != SECSuccess) + goto done; + + secStatus = CERT_ChangeCertTrust(CERT_GetDefaultCertDB(), serverCert, trust); + if (secStatus != SECSuccess) + goto done; + +done: + if (trust) + PORT_Free(trust); + return secStatus; +} + +/* Called when the server certificate verification fails. This gives us + the chance to trust the server anyway and add the certificate to the + local database. */ +static SECStatus +badCertHandler(void *arg, PRFileDesc *sslSocket) { - errWarn(function); - /* Exit gracefully. */ - /* ignoring return value of NSS_Shutdown as code exits with 1*/ - (void) NSS_Shutdown(); - PR_Cleanup(); - exit(1); + SECStatus secStatus; + PRErrorCode errorNumber; + CERTCertificate *serverCert; + + /* By default, don't trust the certificate. */ + secStatus = SECFailure; + + errorNumber = PR_GetError (); + if (errorNumber == SEC_ERROR_CA_CERT_INVALID) + { + /* The server's certificate is not trusted. Should we trust it? */ + if (trustNewServer_p == NULL) + { + /* Don't trust the cert, but print information about it. */ + SEC_PrintCertificateAndTrust(serverCert, "Certificate", + serverCert->trust); + return SECFailure; /* Do not trust this server */ + } + + /* Trust it for this session only? */ + if (strcmp (trustNewServer_p, "session") == 0) + return SECSuccess; + + /* Trust it permanently? */ + if (strcmp (trustNewServer_p, "permanent") == 0) + { + /* The user wants to trust this server. Get the server's certificate so + and add it to our database. */ + serverCert = SSL_PeerCertificate (sslSocket); + if (serverCert != NULL) + secStatus = trustNewServer (serverCert); + } + } + + return secStatus; } static PRFileDesc * @@ -75,9 +155,7 @@ setupSSLSocket(void) tcpSocket = PR_NewTCPSocket(); if (tcpSocket == NULL) - { - errWarn("PR_NewTCPSocket"); - } + goto loser; /* Make the socket blocking. */ socketOption.option = PR_SockOpt_Nonblocking; @@ -85,33 +163,21 @@ setupSSLSocket(void) prStatus = PR_SetSocketOption(tcpSocket, &socketOption); if (prStatus != PR_SUCCESS) - { - errWarn("PR_SetSocketOption"); - goto loser; - } + goto loser; /* Import the socket into the SSL layer. */ sslSocket = SSL_ImportFD(NULL, tcpSocket); if (!sslSocket) - { - errWarn("SSL_ImportFD"); - goto loser; - } + goto loser; /* Set configuration options. */ secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE); if (secStatus != SECSuccess) - { - errWarn("SSL_OptionSet:SSL_SECURITY"); - goto loser; - } + goto loser; secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE); if (secStatus != SECSuccess) - { - errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT"); - goto loser; - } + goto loser; /* Set SSL callback routines. */ #if 0 /* no client authentication */ @@ -119,43 +185,31 @@ setupSSLSocket(void) (SSLGetClientAuthData)myGetClientAuthData, (void *)certNickname); if (secStatus != SECSuccess) - { - errWarn("SSL_GetClientAuthDataHook"); - goto loser; - } + goto loser; #endif #if 0 /* Use the default */ secStatus = SSL_AuthCertificateHook(sslSocket, (SSLAuthCertificate)myAuthCertificate, (void *)CERT_GetDefaultCertDB()); if (secStatus != SECSuccess) - { - errWarn("SSL_AuthCertificateHook"); - goto loser; - } + goto loser; #endif -#if 0 /* Use the default */ - secStatus = SSL_BadCertHook(sslSocket, - (SSLBadCertHandler)myBadCertHandler, NULL); + + secStatus = SSL_BadCertHook(sslSocket, (SSLBadCertHandler)badCertHandler, NULL); if (secStatus != SECSuccess) - { - errWarn("SSL_BadCertHook"); - goto loser; - } -#endif + goto loser; + #if 0 /* No handshake callback */ secStatus = SSL_HandshakeCallback(sslSocket, myHandshakeCallback, NULL); if (secStatus != SECSuccess) - { - errWarn("SSL_HandshakeCallback"); - goto loser; - } + goto loser; #endif return sslSocket; loser: - PR_Close(tcpSocket); + if (tcpSocket) + PR_Close(tcpSocket); return NULL; } @@ -171,6 +225,7 @@ handle_connection(PRFileDesc *sslSocket) PRFileInfo info; PRFileDesc *local_file_fd; PRStatus prStatus; + SECStatus secStatus = SECSuccess; /* read and send the data. */ /* Try to open the local file named. @@ -195,10 +250,7 @@ handle_connection(PRFileDesc *sslSocket) /* Send the file size first, so the server knows when it has the entire file. */ numBytes = PR_Write(sslSocket, & info.size, sizeof (info.size)); if (numBytes < 0) - { - errWarn("PR_Write"); - return SECFailure; - } + return SECFailure; /* Transmit the local file across the socket. */ numBytes = PR_TransmitFile(sslSocket, local_file_fd, @@ -206,10 +258,7 @@ handle_connection(PRFileDesc *sslSocket) PR_TRANSMITFILE_KEEP_OPEN, PR_INTERVAL_NO_TIMEOUT); if (numBytes < 0) - { - errWarn("PR_TransmitFile"); - return SECFailure; - } + return SECFailure; #if DEBUG /* Transmitted bytes successfully. */ @@ -222,7 +271,7 @@ handle_connection(PRFileDesc *sslSocket) /* read until EOF */ readBuffer = PORT_Alloc(READ_BUFFER_SIZE); if (! readBuffer) - exitErr("PORT_Alloc"); + exitErr("Out of memory", GENERAL_ERROR); local_file_fd = PR_Open(outfileName, PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE, PR_IRUSR | PR_IWUSR | PR_IRGRP | PR_IWGRP | PR_IROTH); @@ -239,7 +288,7 @@ handle_connection(PRFileDesc *sslSocket) if (numBytes < 0) { - errWarn("PR_Read"); + secStatus = SECFailure; break; } #if DEBUG @@ -250,6 +299,7 @@ handle_connection(PRFileDesc *sslSocket) if (numBytes < 0) { fprintf (stderr, "could not write to %s\n", outfileName); + secStatus = SECFailure; break; } #if DEBUG @@ -268,7 +318,7 @@ handle_connection(PRFileDesc *sslSocket) fprintf(stderr, "***** Connection read %d bytes total.\n", countRead); #endif - return SECSuccess; + return secStatus; } /* make the connection. @@ -290,32 +340,23 @@ do_connect(PRNetAddr *addr) /* Set up SSL secure socket. */ sslSocket = setupSSLSocket(); if (sslSocket == NULL) - { - errWarn("setupSSLSocket"); - return SECFailure; - } + return SECFailure; #if 0 /* no client authentication */ secStatus = SSL_SetPKCS11PinArg(sslSocket, password); if (secStatus != SECSuccess) - { - errWarn("SSL_SetPKCS11PinArg"); - goto done; - } + goto done; #endif secStatus = SSL_SetURL(sslSocket, hostName); if (secStatus != SECSuccess) - { - errWarn("SSL_SetURL"); - goto done; - } -#if 0 /* Already done? */ + goto done; + +#if 0 /* Already done */ /* Prepare and setup network connection. */ prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry); if (prStatus != PR_SUCCESS) { - errWarn("PR_GetHostByName"); secStatus = SECFailure; goto done; } @@ -323,7 +364,6 @@ do_connect(PRNetAddr *addr) hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr); if (hostenum == -1) { - errWarn("PR_EnumerateHostEnt"); secStatus = SECFailure; goto done; } @@ -331,7 +371,6 @@ do_connect(PRNetAddr *addr) prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT); if (prStatus != PR_SUCCESS) { - errWarn("PR_Connect"); secStatus = SECFailure; goto done; } @@ -339,32 +378,20 @@ do_connect(PRNetAddr *addr) /* Established SSL connection, ready to send data. */ secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE); if (secStatus != SECSuccess) - { - errWarn("SSL_ResetHandshake"); - goto done; - } + goto done; /* This is normally done automatically on the first I/O operation, but doing it here catches any authentication problems early. */ secStatus = SSL_ForceHandshake(sslSocket); if (secStatus != SECSuccess) - { - errWarn("SSL_ForceHandshake"); - goto done; - } + goto done; secStatus = handle_connection(sslSocket); if (secStatus != SECSuccess) - { - errWarn("handle_connection"); - goto done; - } + goto done; done: prStatus = PR_Close(sslSocket); - if (prStatus != PR_SUCCESS) - errWarn("PR_Close"); - return secStatus; } @@ -376,20 +403,55 @@ client_main(unsigned short port) PRInt32 rv; PRNetAddr addr; PRHostEnt hostEntry; + PRErrorCode errorNumber; char buffer[PR_NETDB_BUF_SIZE]; + int attempt; + int errCode = GENERAL_ERROR; /* Setup network connection. */ prStatus = PR_GetHostByName(hostName, buffer, sizeof (buffer), &hostEntry); if (prStatus != PR_SUCCESS) - exitErr("PR_GetHostByName"); + exitErr("Unable to resolve server host name", GENERAL_ERROR); rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr); if (rv < 0) - exitErr("PR_EnumerateHostEnt"); + exitErr("Unable to resolve server host address", GENERAL_ERROR); - secStatus = do_connect (&addr); - if (secStatus != SECSuccess) - exitErr("do_connect"); + /* Some errors (see below) represent a situation in which trying again + should succeed. However, don't try forever. */ + for (attempt = 0; attempt < 5; ++attempt) + { + secStatus = do_connect (&addr); + if (secStatus == SECSuccess) + return; + + errorNumber = PR_GetError (); + switch (errorNumber) + { + case PR_CONNECT_RESET_ERROR: + /* Server was not ready. */ + sleep (1); + break; /* Try again */ + case SEC_ERROR_EXPIRED_CERTIFICATE: + /* The server's certificate has expired. It should + generate a new certificate. Give the server a chance to recover + and try again. */ + sleep (2); + break; /* Try again */ + case SEC_ERROR_CA_CERT_INVALID: + /* The server's certificate is not trusted. The exit code must + reflect this. */ + errCode = CA_CERT_INVALID_ERROR; + goto failed; /* break switch and loop */ + default: + /* This error is fatal. */ + goto failed; /* break switch and loop */ + } + } + + failed: + /* Unrecoverable error */ + exitErr("Unable to connect to server", errCode); } #if 0 /* No client authorization */ @@ -408,7 +470,6 @@ myPasswd(PK11SlotInfo *info, PRBool retry, void *arg) int main(int argc, char **argv) { - char * certDir = NULL; char * progName = NULL; SECStatus secStatus; PLOptState *optstate; @@ -420,7 +481,7 @@ main(int argc, char **argv) progName = PL_strdup(argv[0]); hostName = NULL; - optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:"); + optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:t:"); while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK) { switch(optstate->option) @@ -430,6 +491,7 @@ main(int argc, char **argv) case 'i' : infileName = PL_strdup(optstate->value); break; case 'o' : outfileName = PL_strdup(optstate->value); break; case 'p' : port = PORT_Atoi(optstate->value); break; + case 't' : trustNewServer_p = PL_strdup(optstate->value); break; case '?' : default : Usage(progName); } @@ -444,9 +506,14 @@ main(int argc, char **argv) #endif /* Initialize the NSS libraries. */ - secStatus = NSS_Init(certDir); + secStatus = NSS_InitReadWrite(certDir); if (secStatus != SECSuccess) - exitErr("NSS_Init"); + { + /* Try it again, readonly. */ + secStatus = NSS_Init(certDir); + if (secStatus != SECSuccess) + exitErr("Error initializing NSS", GENERAL_ERROR); + } /* All cipher suites except RSA_NULL_MD5 are enabled by Domestic Policy. */ NSS_SetDomesticPolicy(); -- cgit