/*
   Copyright (C) 2009 Red Hat, Inc.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation; either version 2 of
   the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "common.h"
// NOTE(review): system header list reconstructed from the socket APIs used in
// this file (socket/connect/send/recv, TCP_NODELAY, inet_ntoa, getaddrinfo);
// verify against version control.
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#define SHUT_RDWR SD_BOTH
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#define INVALID_SOCKET -1
#define SOCKET_ERROR -1
#define closesocket(sock) ::close(sock)
#endif
#include "red.h"
#include "red_peer.h"
#include "utils.h"
#include "debug.h"
#include "platform_utils.h"

#ifdef _WIN32

// inet_aton() replacement for Winsock, which only provides inet_addr().
// Stores the parsed address (as returned by inet_addr) in |in_addr| and returns
// 1 on success, 0 on failure. Because inet_addr() signals failure with
// INADDR_NONE, "255.255.255.255" is rejected as well.
int inet_aton(const char *ip, struct in_addr *in_addr)
{
    unsigned long addr = inet_addr(ip);

    if (addr == INADDR_NONE) {
        return 0;
    }
    in_addr->S_un.S_addr = addr;
    return 1;
}

#define SHUTDOWN_ERR WSAESHUTDOWN
#define INTERRUPTED_ERR WSAEINTR
#define WOULDBLOCK_ERR WSAEWOULDBLOCK
#define sock_error() WSAGetLastError()
#define sock_err_message(err) sys_err_to_str(err)

#else

#define SHUTDOWN_ERR EPIPE
#define INTERRUPTED_ERR EINTR
#define WOULDBLOCK_ERR EAGAIN
#define sock_error() errno
#define sock_err_message(err) strerror(err)

#endif

// Prints the pending OpenSSL error queue to stderr, then throws
// SPICEC_ERROR_CODE_SSL_ERROR. Never returns normally.
static void ssl_error()
{
    ERR_print_errors_fp(stderr);
    THROW_ERR(SPICEC_ERROR_CODE_SSL_ERROR, "SSL Error");
}

RedPeer::RedPeer()
    : _peer (INVALID_SOCKET)
    , _shut (false)
    , _ctx (NULL)
    , _ssl (NULL)
{
}

RedPeer::~RedPeer()
{
    cleanup();
}

// Releases the SSL session, the SSL context and the socket, in that order,
// resetting each handle so that calling cleanup() again is a no-op.
void RedPeer::cleanup()
{
    if (_ssl) {
        SSL_free(_ssl);
        _ssl = NULL;
    }

    if (_ctx) {
        SSL_CTX_free(_ctx);
        _ctx = NULL;
    }

    if (_peer != INVALID_SOCKET) {
        closesocket(_peer);
        _peer = INVALID_SOCKET;
    }
}

// Resolves |host| to an IPv4 address and returns it in host byte order.
// Throws SPICEC_ERROR_CODE_GETHOSTBYNAME_FAILED if no IPv4 address is found.
uint32_t RedPeer::host_by_name(const char* host)
{
    struct addrinfo hints = addrinfo();
    struct addrinfo *result = NULL;
    struct sockaddr_in *addr;
    uint32_t return_value;
    int rc;

    // The result is read as a sockaddr_in, so only IPv4 entries may be
    // returned; without this hint the first entry can be AF_INET6.
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    rc = getaddrinfo(host, NULL, &hints, &result);
    if (rc != 0 || result == NULL) {
        THROW_ERR(SPICEC_ERROR_CODE_GETHOSTBYNAME_FAILED, "cannot resolve host address %s", host);
    }
    addr = (sockaddr_in *)result->ai_addr;
    return_value = addr->sin_addr.s_addr;
    freeaddrinfo(result);

    DBG(0, "%s = %u", host, return_value);
    return ntohl(return_value);
}

// Opens a plain TCP connection (TCP_NODELAY requested) to |ip|:|port|, where
// |ip| is in host byte order. On failure, everything acquired so far is
// released under _lock via cleanup() and the exception is rethrown.
void RedPeer::connect_unsecure(uint32_t ip, int port)
{
    struct sockaddr_in addr = sockaddr_in(); // zero-fill sin_zero and platform fields
    int no_delay;

    ASSERT(_ctx == NULL && _ssl == NULL && _peer == INVALID_SOCKET);
    try {
        addr.sin_port = htons(port);
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = htonl(ip);

        Lock lock(_lock);
        if ((_peer = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) == INVALID_SOCKET) {
            int err = sock_error();
            THROW_ERR(SPICEC_ERROR_CODE_SOCKET_FAILED, "failed to create socket: %s (%d)",
                      sock_err_message(err), err);
        }

        no_delay = 1;
        if (setsockopt(_peer, IPPROTO_TCP, TCP_NODELAY, (const char*)&no_delay,
                       sizeof(no_delay)) == SOCKET_ERROR) {
            LOG_WARN("set TCP_NODELAY failed");
        }

        LOG_INFO("Connecting %s %d", inet_ntoa(addr.sin_addr), port);
        // connect() runs without _lock held -- presumably so another thread
        // can call disconnect() meanwhile; confirm.
        lock.unlock();
        if (::connect(_peer, (struct sockaddr *)&addr, sizeof(sockaddr_in)) == SOCKET_ERROR) {
            int err = sock_error();
            // _peer is closed by cleanup() in the catch block below; closing it
            // here as well would close the descriptor twice.
            THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "failed to connect: %s (%d)",
                      sock_err_message(err), err);
        }
        _serial = 0;
    } catch (...) {
        Lock lock(_lock);
        cleanup();
        throw;
    }
}

// Resolves |host| (IPv4 only) and opens a plain TCP connection to it.
void RedPeer::connect_unsecure(const char* host, int port)
{
    connect_unsecure(host_by_name(host), port);
}

// TODO: use SSL_CTX_set_cipher_list, SSL_CTX_load_verify_locations etc.
void RedPeer::connect_secure(const ConnectionOptions& options, uint32_t ip) { connect_unsecure(ip, options.secure_port); ASSERT(_ctx == NULL && _ssl == NULL && _peer != INVALID_SOCKET); try { SSL_METHOD *ssl_method = TLSv1_method(); _ctx = SSL_CTX_new(ssl_method); if (_ctx == NULL) { ssl_error(); } _ssl = SSL_new(_ctx); if (!_ssl) { THROW("create ssl failed"); } BIO* sbio = BIO_new_socket(_peer, BIO_NOCLOSE); if (!sbio) { THROW("alloc new socket bio failed"); } SSL_set_bio(_ssl, sbio, sbio); int return_code = SSL_connect(_ssl); if (return_code <= 0) { SSL_get_error(_ssl, return_code); ssl_error(); } } catch (...) { Lock lock(_lock); cleanup(); throw; } } void RedPeer::connect_secure(const ConnectionOptions& options, const char* host) { connect_secure(options, host_by_name(host)); } void RedPeer::shutdown() { if (_peer != INVALID_SOCKET) { if (_ssl) { SSL_shutdown(_ssl); } ::shutdown(_peer, SHUT_RDWR); } _shut = true; } void RedPeer::disconnect() { Lock lock(_lock); shutdown(); } void RedPeer::close() { Lock lock(_lock); if (_peer != INVALID_SOCKET) { if (_ctx) { SSL_free(_ssl); _ssl = NULL; SSL_CTX_free(_ctx); _ctx = NULL; } closesocket(_peer); _peer = INVALID_SOCKET; } } void RedPeer::swap(RedPeer* other) { Lock lock(_lock); SOCKET temp_peer = _peer; SSL_CTX *temp_ctx = _ctx; SSL *temp_ssl = _ssl; _peer = other->_peer; other->_peer = temp_peer; if (_ctx) { _ctx = other->_ctx; _ssl = other->_ssl; other->_ctx = temp_ctx; other->_ssl = temp_ssl; } if (_shut) { shutdown(); } } uint32_t RedPeer::recive(uint8_t *buf, uint32_t size) { uint8_t *pos = buf; while (size) { int now; if (_ctx == NULL) { if ((now = recv(_peer, (char *)pos, size, 0)) <= 0) { int err = sock_error(); if (now == SOCKET_ERROR && err == WOULDBLOCK_ERR) { break; } if (now == 0 || err == SHUTDOWN_ERR) { throw RedPeer::DisconnectedException(); } if (err == INTERRUPTED_ERR) { continue; } THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "%s (%d)", sock_err_message(err), err); } size -= now; pos += now; } else { 
if ((now = SSL_read(_ssl, pos, size)) <= 0) { int ssl_error = SSL_get_error(_ssl, now); if (ssl_error == SSL_ERROR_WANT_READ) { break; } if (ssl_error == SSL_ERROR_SYSCALL) { int err = sock_error(); if (now == -1) { if (err == WOULDBLOCK_ERR) { break; } if (err == INTERRUPTED_ERR) { continue; } } if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) { throw RedPeer::DisconnectedException(); } THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); } else if (ssl_error == SSL_ERROR_ZERO_RETURN) { throw RedPeer::DisconnectedException(); } THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "ssl error %d", ssl_error); } size -= now; pos += now; } } return pos - buf; } RedPeer::CompundInMessage* RedPeer::recive() { RedDataHeader header; std::auto_ptr message; recive((uint8_t*)&header, sizeof(RedDataHeader)); message.reset(new CompundInMessage(header.serial, header.type, header.size, header.sub_list)); recive(message->data(), message->compund_size()); return message.release(); } uint32_t RedPeer::send(uint8_t *buf, uint32_t size) { uint8_t *pos = buf; while (size) { int now; if (_ctx == NULL) { if ((now = ::send(_peer, (char *)pos, size, 0)) == SOCKET_ERROR) { int err = sock_error(); if (err == WOULDBLOCK_ERR) { break; } if (err == SHUTDOWN_ERR) { throw RedPeer::DisconnectedException(); } if (err == INTERRUPTED_ERR) { continue; } THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); } size -= now; pos += now; } else { if ((now = SSL_write(_ssl, pos, size)) <= 0) { int ssl_error = SSL_get_error(_ssl, now); if (ssl_error == SSL_ERROR_WANT_WRITE) { break; } if (ssl_error == SSL_ERROR_SYSCALL) { int err = sock_error(); if (now == -1) { if (err == WOULDBLOCK_ERR) { break; } if (err == INTERRUPTED_ERR) { continue; } } if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) { throw RedPeer::DisconnectedException(); } THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); } else if (ssl_error == SSL_ERROR_ZERO_RETURN) 
{ throw RedPeer::DisconnectedException(); } THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "ssl error %d", ssl_error); } size -= now; pos += now; } } return pos - buf; } uint32_t RedPeer::send(RedPeer::OutMessage& message) { message.header().serial = ++_serial; return send(message.base(), message.message_size()); } RedPeer::OutMessage::OutMessage(uint32_t type, uint32_t size) : _data (new uint8_t[size + sizeof(RedDataHeader)]) , _size (size) { header().type = type; header().size = size; } RedPeer::OutMessage::~OutMessage() { delete[] _data; } void RedPeer::OutMessage::resize(uint32_t size) { if (size <= _size) { header().size = size; return; } uint32_t type = header().type; delete[] _data; _data = NULL; _size = 0; _data = new uint8_t[size + sizeof(RedDataHeader)]; _size = size; header().type = type; header().size = size; }