diff options
Diffstat (limited to 'client/red_peer.cpp')
-rw-r--r-- | client/red_peer.cpp | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/client/red_peer.cpp b/client/red_peer.cpp new file mode 100644 index 00000000..e20d5ca6 --- /dev/null +++ b/client/red_peer.cpp @@ -0,0 +1,426 @@ +/* + Copyright (C) 2009 Red Hat, Inc. + + This program is free software; you can redistribute it and/or + modify it under the terms of the GNU General Public License as + published by the Free Software Foundation; either version 2 of + the License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +#include "common.h" +#ifdef _WIN32 +#include <winsock2.h> +#include <ws2tcpip.h> + +#define SHUT_RDWR SD_BOTH +#else +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <netinet/tcp.h> +#include <netdb.h> + +#define INVALID_SOCKET -1 +#define SOCKET_ERROR -1 +#define closesocket(sock) ::close(sock) +#endif +#include "red.h" +#include "red_peer.h" +#include "utils.h" +#include "debug.h" +#include "platform_utils.h" + +#ifdef _WIN32 + +int inet_aton(const char *ip, struct in_addr *in_addr) +{ + unsigned long addr = inet_addr(ip); + + if (addr == INADDR_NONE) { + return 0; + } + in_addr->S_un.S_addr = addr; + return 1; +} + +#define SHUTDOWN_ERR WSAESHUTDOWN +#define INTERRUPTED_ERR WSAEINTR +#define WOULDBLOCK_ERR WSAEWOULDBLOCK +#define sock_error() WSAGetLastError() +#define sock_err_message(err) sys_err_to_str(err) +#else +#define SHUTDOWN_ERR EPIPE +#define INTERRUPTED_ERR EINTR +#define WOULDBLOCK_ERR EAGAIN +#define sock_error() errno +#define sock_err_message(err) strerror(err) +#endif + +static void ssl_error() +{ + ERR_print_errors_fp(stderr); + THROW_ERR(SPICEC_ERROR_CODE_SSL_ERROR, "SSL Error"); +} + +RedPeer::RedPeer() + : _peer (INVALID_SOCKET) + , _shut (false) + , _ctx (NULL) + , _ssl (NULL) +{ +} + +RedPeer::~RedPeer() +{ + cleanup(); +} + +void RedPeer::cleanup() +{ + if (_ssl) { + SSL_free(_ssl); + _ssl = NULL; + } + + if (_ctx) { + SSL_CTX_free(_ctx); + _ctx = NULL; + } + + if (_peer != INVALID_SOCKET) { + closesocket(_peer); + _peer = INVALID_SOCKET; + } +} + +uint32_t RedPeer::host_by_name(const char* host) +{ + struct addrinfo *result = NULL; + struct sockaddr_in *addr; + uint32_t return_value; + int rc; + + rc = getaddrinfo(host, NULL, NULL, &result); + if (rc != 0 || result == NULL) { + THROW_ERR(SPICEC_ERROR_CODE_GETHOSTBYNAME_FAILED, "cannot resolve host address %s", host); + } + + addr = (sockaddr_in *)result->ai_addr; + return_value = addr->sin_addr.s_addr; + + freeaddrinfo(result); + + DBG(0, "%s = %u", host, return_value); + return ntohl(return_value); +} + +void RedPeer::connect_unsecure(uint32_t ip, int port) +{ + struct sockaddr_in addr; + int no_delay; + + ASSERT(_ctx == NULL && _ssl == NULL && _peer == INVALID_SOCKET); + try { + addr.sin_port = htons(port); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(ip); + + Lock lock(_lock); + if ((_peer = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) == INVALID_SOCKET) { + int err = sock_error(); + THROW_ERR(SPICEC_ERROR_CODE_SOCKET_FAILED, "failed to create socket: %s (%d)", + sock_err_message(err), err); + } + + no_delay = 1; + if (setsockopt(_peer, IPPROTO_TCP, TCP_NODELAY, (const char*)&no_delay, sizeof(no_delay)) == + SOCKET_ERROR) { + LOG_WARN("set TCP_NODELAY failed"); + } + + LOG_INFO("Connecting %s %d", inet_ntoa(addr.sin_addr), port); + lock.unlock(); + if (::connect(_peer, (struct sockaddr *)&addr, sizeof(sockaddr_in)) == SOCKET_ERROR) { + int err = sock_error(); + closesocket(_peer); + THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "failed to connect: %s (%d)", + sock_err_message(err), err); + } + _serial = 0; + } catch (...) { + Lock lock(_lock); + cleanup(); + throw; + } +} + +void RedPeer::connect_unsecure(const char* host, int port) +{ + connect_unsecure(host_by_name(host), port); +} + +// todo: use SSL_CTX_set_cipher_list, SSL_CTX_load_verify_location etc. +void RedPeer::connect_secure(const ConnectionOptions& options, uint32_t ip) +{ + connect_unsecure(ip, options.secure_port); + ASSERT(_ctx == NULL && _ssl == NULL && _peer != INVALID_SOCKET); + + try { + SSL_METHOD *ssl_method = TLSv1_method(); + + _ctx = SSL_CTX_new(ssl_method); + if (_ctx == NULL) { + ssl_error(); + } + + _ssl = SSL_new(_ctx); + if (!_ssl) { + THROW("create ssl failed"); + } + + BIO* sbio = BIO_new_socket(_peer, BIO_NOCLOSE); + if (!sbio) { + THROW("alloc new socket bio failed"); + } + + SSL_set_bio(_ssl, sbio, sbio); + + int return_code = SSL_connect(_ssl); + if (return_code <= 0) { + SSL_get_error(_ssl, return_code); + ssl_error(); + } + } catch (...) { + Lock lock(_lock); + cleanup(); + throw; + } +} + +void RedPeer::connect_secure(const ConnectionOptions& options, const char* host) +{ + connect_secure(options, host_by_name(host)); +} + +void RedPeer::shutdown() +{ + if (_peer != INVALID_SOCKET) { + if (_ssl) { + SSL_shutdown(_ssl); + } + ::shutdown(_peer, SHUT_RDWR); + } + _shut = true; +} + +void RedPeer::disconnect() +{ + Lock lock(_lock); + shutdown(); +} + +void RedPeer::close() +{ + Lock lock(_lock); + if (_peer != INVALID_SOCKET) { + if (_ctx) { + SSL_free(_ssl); + _ssl = NULL; + SSL_CTX_free(_ctx); + _ctx = NULL; + } + + closesocket(_peer); + _peer = INVALID_SOCKET; + } +} + +void RedPeer::swap(RedPeer* other) +{ + Lock lock(_lock); + SOCKET temp_peer = _peer; + SSL_CTX *temp_ctx = _ctx; + SSL *temp_ssl = _ssl; + + _peer = other->_peer; + other->_peer = temp_peer; + + if (_ctx) { + _ctx = other->_ctx; + _ssl = other->_ssl; + + other->_ctx = temp_ctx; + other->_ssl = temp_ssl; + } + + if (_shut) { + shutdown(); + } +} + +uint32_t RedPeer::recive(uint8_t *buf, uint32_t size) +{ + uint8_t *pos = buf; + while (size) { + int now; + if (_ctx == NULL) { + if ((now = recv(_peer, (char *)pos, size, 0)) <= 0) { + int err = sock_error(); + if (now == SOCKET_ERROR && err == WOULDBLOCK_ERR) { + break; + } + + if (now == 0 || err == SHUTDOWN_ERR) { + throw RedPeer::DisconnectedException(); + } + + if (err == INTERRUPTED_ERR) { + continue; + } + THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "%s (%d)", sock_err_message(err), err); + } + size -= now; + pos += now; + } else { + if ((now = SSL_read(_ssl, pos, size)) <= 0) { + int ssl_error = SSL_get_error(_ssl, now); + + if (ssl_error == SSL_ERROR_WANT_READ) { + break; + } + + if (ssl_error == SSL_ERROR_SYSCALL) { + int err = sock_error(); + if (now == -1) { + if (err == WOULDBLOCK_ERR) { + break; + } + if (err == INTERRUPTED_ERR) { + continue; + } + } + if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) { + throw RedPeer::DisconnectedException(); + } + THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); + } else if (ssl_error == SSL_ERROR_ZERO_RETURN) { + throw RedPeer::DisconnectedException(); + } + THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "ssl error %d", ssl_error); + } + size -= now; + pos += now; + } + } + return pos - buf; +} + +RedPeer::CompundInMessage* RedPeer::recive() +{ + RedDataHeader header; + std::auto_ptr<CompundInMessage> message; + + recive((uint8_t*)&header, sizeof(RedDataHeader)); + message.reset(new CompundInMessage(header.serial, header.type, header.size, header.sub_list)); + recive(message->data(), message->compund_size()); + return message.release(); +} + +uint32_t RedPeer::send(uint8_t *buf, uint32_t size) +{ + uint8_t *pos = buf; + while (size) { + int now; + + if (_ctx == NULL) { + if ((now = ::send(_peer, (char *)pos, size, 0)) == SOCKET_ERROR) { + int err = sock_error(); + if (err == WOULDBLOCK_ERR) { + break; + } + if (err == SHUTDOWN_ERR) { + throw RedPeer::DisconnectedException(); + } + if (err == INTERRUPTED_ERR) { + continue; + } + THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); + } + size -= now; + pos += now; + } else { + if ((now = SSL_write(_ssl, pos, size)) <= 0) { + int ssl_error = SSL_get_error(_ssl, now); + + if (ssl_error == SSL_ERROR_WANT_WRITE) { + break; + } + if (ssl_error == SSL_ERROR_SYSCALL) { + int err = sock_error(); + if (now == -1) { + if (err == WOULDBLOCK_ERR) { + break; + } + if (err == INTERRUPTED_ERR) { + continue; + } + } + if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) { + throw RedPeer::DisconnectedException(); + } + THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err); + } else if (ssl_error == SSL_ERROR_ZERO_RETURN) { + throw RedPeer::DisconnectedException(); + } + THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "ssl error %d", ssl_error); + } + size -= now; + pos += now; + } + } + return pos - buf; +} + +uint32_t RedPeer::send(RedPeer::OutMessage& message) +{ + message.header().serial = ++_serial; + return send(message.base(), message.message_size()); +} + +RedPeer::OutMessage::OutMessage(uint32_t type, uint32_t size) + : _data (new uint8_t[size + sizeof(RedDataHeader)]) + , _size (size) +{ + header().type = type; + header().size = size; +} + +RedPeer::OutMessage::~OutMessage() +{ + delete[] _data; +} + +void RedPeer::OutMessage::resize(uint32_t size) +{ + if (size <= _size) { + header().size = size; + return; + } + uint32_t type = header().type; + delete[] _data; + _data = NULL; + _size = 0; + _data = new uint8_t[size + sizeof(RedDataHeader)]; + _size = size; + header().type = type; + header().size = size; +} + |