From 64aa100f39dca60999028f83feb31983728ea4d4 Mon Sep 17 00:00:00 2001 From: Dave Brolley Date: Fri, 9 Jan 2009 15:11:04 -0500 Subject: New framework for creating/using certificate databases for client/server. --- stap-client-connect.c | 103 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 32 deletions(-) (limited to 'stap-client-connect.c') diff --git a/stap-client-connect.c b/stap-client-connect.c index 29a8e18d..9466b566 100644 --- a/stap-client-connect.c +++ b/stap-client-connect.c @@ -2,7 +2,7 @@ SSL client program that sets up a connection to a SSL server, transmits the given input file and then writes the reply to the given output file. - Copyright (C) 2008 Red Hat Inc. + Copyright (C) 2008, 2009 Red Hat Inc. This file is part of systemtap, and is free software. You can redistribute it and/or modify it under the terms of the GNU General Public @@ -25,9 +25,10 @@ #include #include #include +#include +#include #define READ_BUFFER_SIZE (60 * 1024) - static char *hostName = NULL; static unsigned short port = 0; static const char *infileName = NULL; @@ -44,9 +45,42 @@ Usage(const char *progName) static void errWarn(char *function) { - PRErrorCode errorNumber = PR_GetError(); + PRErrorCode errorNumber; + PRInt32 errorTextLength; + PRInt32 rc; + char *errorText; + + errorNumber = PR_GetError(); + fprintf(stderr, "Error in function %s: %d: ", function, errorNumber); + + /* See if PR_GetErrorText can tell us what the error is. */ + if (errorNumber >= PR_NSPR_ERROR_BASE && errorNumber <= PR_MAX_ERROR) + { + errorTextLength = PR_GetErrorTextLength (); + if (errorTextLength != 0) { + errorText = PORT_Alloc(errorTextLength); + rc = PR_GetErrorText (errorText); + if (rc != 0) + fprintf (stderr, "%s\n", errorText); + PR_Free (errorText); + if (rc != 0) + return; + } + } - printf("Error in function %s: %d\n\n", function, errorNumber); + /* Otherwise handle common errors ourselves. */ + switch (errorNumber) + { + case SEC_ERROR_CA_CERT_INVALID: + fputs ("The issuer's certificate is invalid\n", stderr); + break; + case PR_CONNECT_RESET_ERROR: + fputs ("Connection reset by peer\n", stderr); + break; + default: + fputs ("Unknown error\n", stderr); + break; + } } static void @@ -190,10 +224,9 @@ handle_connection(PRFileDesc *sslSocket) /* Send the file size first, so the server knows when it has the entire file. */ numBytes = PR_Write(sslSocket, & info.size, sizeof (info.size)); - /* Error in transmission? */ if (numBytes < 0) { - errWarn("PR_TransmitFile"); + errWarn("PR_Write"); return SECFailure; } @@ -202,7 +235,6 @@ handle_connection(PRFileDesc *sslSocket) NULL, 0, PR_TRANSMITFILE_KEEP_OPEN, PR_INTERVAL_NO_TIMEOUT); - /* Error in transmission? */ if (numBytes < 0) { errWarn("PR_TransmitFile"); @@ -212,7 +244,7 @@ handle_connection(PRFileDesc *sslSocket) #if DEBUG /* Transmitted bytes successfully. */ fprintf(stderr, "PR_TransmitFile wrote %d bytes from %s\n", - numBytes, "stdin"); + numBytes, infileName); #endif PR_Close(local_file_fd); @@ -248,13 +280,14 @@ handle_connection(PRFileDesc *sslSocket) if (numBytes < 0) { fprintf (stderr, "could not write to %s\n", outfileName); + break; + } #if DEBUG - fprintf(stderr, "***** Connection read %d bytes (%d total).\n", - numBytes, countRead ); - readBuffer[numBytes] = '\0'; - fprintf(stderr, "************\n%s\n************\n", readBuffer); + fprintf(stderr, "***** Connection read %d bytes (%d total).\n", + numBytes, countRead ); + readBuffer[numBytes] = '\0'; + fprintf(stderr, "************\n%s\n************\n", readBuffer); #endif - } } PR_Free(readBuffer); @@ -280,6 +313,8 @@ do_connect(PRNetAddr *addr) PRIntn hostenum; SECStatus secStatus; + secStatus = SECSuccess; + /* Set up SSL secure socket. */ sslSocket = setupSSLSocket(); if (sslSocket == NULL) @@ -293,7 +328,7 @@ do_connect(PRNetAddr *addr) if (secStatus != SECSuccess) { errWarn("SSL_SetPKCS11PinArg"); - return secStatus; + goto done; } #endif @@ -301,7 +336,7 @@ do_connect(PRNetAddr *addr) if (secStatus != SECSuccess) { errWarn("SSL_SetURL"); - return secStatus; + goto done; } /* Prepare and setup network connection. */ @@ -309,52 +344,56 @@ do_connect(PRNetAddr *addr) if (prStatus != PR_SUCCESS) { errWarn("PR_GetHostByName"); - return SECFailure; + secStatus = SECFailure; + goto done; } hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr); if (hostenum == -1) { errWarn("PR_EnumerateHostEnt"); - return SECFailure; + secStatus = SECFailure; + goto done; } prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT); if (prStatus != PR_SUCCESS) { errWarn("PR_Connect"); - return SECFailure; + secStatus = SECFailure; + goto done; } /* Established SSL connection, ready to send data. */ -#if 0 /* Not necessary? */ - secStatus = SSL_ForceHandshake(sslSocket); + secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE); if (secStatus != SECSuccess) { - errWarn("SSL_ForceHandshake"); - return secStatus; + errWarn("SSL_ResetHandshake"); + goto done; } -#endif - secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE); + /* This is normally done automatically on the first I/O operation, + but doing it here catches any authentication problems early. */ + secStatus = SSL_ForceHandshake(sslSocket); if (secStatus != SECSuccess) { - errWarn("SSL_ResetHandshake"); - prStatus = PR_Close(sslSocket); - if (prStatus != PR_SUCCESS) - errWarn("PR_Close"); - return secStatus; + errWarn("SSL_ForceHandshake"); + goto done; } secStatus = handle_connection(sslSocket); if (secStatus != SECSuccess) { errWarn("handle_connection"); - return secStatus; + goto done; } - PR_Close(sslSocket); - return SECSuccess; + done: + prStatus = PR_Close(sslSocket); + if (prStatus != PR_SUCCESS) + errWarn("PR_Close"); + + return secStatus; } static void -- cgit