summaryrefslogtreecommitdiffstats
path: root/client/red_peer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'client/red_peer.cpp')
-rw-r--r--client/red_peer.cpp426
1 files changed, 426 insertions, 0 deletions
diff --git a/client/red_peer.cpp b/client/red_peer.cpp
new file mode 100644
index 00000000..e20d5ca6
--- /dev/null
+++ b/client/red_peer.cpp
@@ -0,0 +1,426 @@
+/*
+ Copyright (C) 2009 Red Hat, Inc.
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License as
+ published by the Free Software Foundation; either version 2 of
+ the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include "common.h"
+#ifdef _WIN32
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#define SHUT_RDWR SD_BOTH
+#else
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <netinet/tcp.h>
+#include <netdb.h>
+
+#define INVALID_SOCKET -1
+#define SOCKET_ERROR -1
+#define closesocket(sock) ::close(sock)
+#endif
+#include "red.h"
+#include "red_peer.h"
+#include "utils.h"
+#include "debug.h"
+#include "platform_utils.h"
+
+#ifdef _WIN32
+
+int inet_aton(const char *ip, struct in_addr *in_addr)
+{
+ unsigned long addr = inet_addr(ip);
+
+ if (addr == INADDR_NONE) {
+ return 0;
+ }
+ in_addr->S_un.S_addr = addr;
+ return 1;
+}
+
+#define SHUTDOWN_ERR WSAESHUTDOWN
+#define INTERRUPTED_ERR WSAEINTR
+#define WOULDBLOCK_ERR WSAEWOULDBLOCK
+#define sock_error() WSAGetLastError()
+#define sock_err_message(err) sys_err_to_str(err)
+#else
+#define SHUTDOWN_ERR EPIPE
+#define INTERRUPTED_ERR EINTR
+#define WOULDBLOCK_ERR EAGAIN
+#define sock_error() errno
+#define sock_err_message(err) strerror(err)
+#endif
+
+static void ssl_error()
+{
+ ERR_print_errors_fp(stderr);
+ THROW_ERR(SPICEC_ERROR_CODE_SSL_ERROR, "SSL Error");
+}
+
+RedPeer::RedPeer()
+ : _peer (INVALID_SOCKET)
+ , _shut (false)
+ , _ctx (NULL)
+ , _ssl (NULL)
+{
+}
+
+RedPeer::~RedPeer()
+{
+ cleanup();
+}
+
+void RedPeer::cleanup()
+{
+ if (_ssl) {
+ SSL_free(_ssl);
+ _ssl = NULL;
+ }
+
+ if (_ctx) {
+ SSL_CTX_free(_ctx);
+ _ctx = NULL;
+ }
+
+ if (_peer != INVALID_SOCKET) {
+ closesocket(_peer);
+ _peer = INVALID_SOCKET;
+ }
+}
+
+uint32_t RedPeer::host_by_name(const char* host)
+{
+ struct addrinfo *result = NULL;
+ struct sockaddr_in *addr;
+ uint32_t return_value;
+ int rc;
+
+ rc = getaddrinfo(host, NULL, NULL, &result);
+ if (rc != 0 || result == NULL) {
+ THROW_ERR(SPICEC_ERROR_CODE_GETHOSTBYNAME_FAILED, "cannot resolve host address %s", host);
+ }
+
+ addr = (sockaddr_in *)result->ai_addr;
+ return_value = addr->sin_addr.s_addr;
+
+ freeaddrinfo(result);
+
+ DBG(0, "%s = %u", host, return_value);
+ return ntohl(return_value);
+}
+
+void RedPeer::connect_unsecure(uint32_t ip, int port)
+{
+ struct sockaddr_in addr;
+ int no_delay;
+
+ ASSERT(_ctx == NULL && _ssl == NULL && _peer == INVALID_SOCKET);
+ try {
+ addr.sin_port = htons(port);
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(ip);
+
+ Lock lock(_lock);
+ if ((_peer = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) == INVALID_SOCKET) {
+ int err = sock_error();
+ THROW_ERR(SPICEC_ERROR_CODE_SOCKET_FAILED, "failed to create socket: %s (%d)",
+ sock_err_message(err), err);
+ }
+
+ no_delay = 1;
+ if (setsockopt(_peer, IPPROTO_TCP, TCP_NODELAY, (const char*)&no_delay, sizeof(no_delay)) ==
+ SOCKET_ERROR) {
+ LOG_WARN("set TCP_NODELAY failed");
+ }
+
+ LOG_INFO("Connecting %s %d", inet_ntoa(addr.sin_addr), port);
+ lock.unlock();
+ if (::connect(_peer, (struct sockaddr *)&addr, sizeof(sockaddr_in)) == SOCKET_ERROR) {
+ int err = sock_error();
+ closesocket(_peer);
+ THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "failed to connect: %s (%d)",
+ sock_err_message(err), err);
+ }
+ _serial = 0;
+ } catch (...) {
+ Lock lock(_lock);
+ cleanup();
+ throw;
+ }
+}
+
+void RedPeer::connect_unsecure(const char* host, int port)
+{
+ connect_unsecure(host_by_name(host), port);
+}
+
+// todo: use SSL_CTX_set_cipher_list, SSL_CTX_load_verify_location etc.
+void RedPeer::connect_secure(const ConnectionOptions& options, uint32_t ip)
+{
+ connect_unsecure(ip, options.secure_port);
+ ASSERT(_ctx == NULL && _ssl == NULL && _peer != INVALID_SOCKET);
+
+ try {
+ SSL_METHOD *ssl_method = TLSv1_method();
+
+ _ctx = SSL_CTX_new(ssl_method);
+ if (_ctx == NULL) {
+ ssl_error();
+ }
+
+ _ssl = SSL_new(_ctx);
+ if (!_ssl) {
+ THROW("create ssl failed");
+ }
+
+ BIO* sbio = BIO_new_socket(_peer, BIO_NOCLOSE);
+ if (!sbio) {
+ THROW("alloc new socket bio failed");
+ }
+
+ SSL_set_bio(_ssl, sbio, sbio);
+
+ int return_code = SSL_connect(_ssl);
+ if (return_code <= 0) {
+ SSL_get_error(_ssl, return_code);
+ ssl_error();
+ }
+ } catch (...) {
+ Lock lock(_lock);
+ cleanup();
+ throw;
+ }
+}
+
+void RedPeer::connect_secure(const ConnectionOptions& options, const char* host)
+{
+ connect_secure(options, host_by_name(host));
+}
+
+void RedPeer::shutdown()
+{
+ if (_peer != INVALID_SOCKET) {
+ if (_ssl) {
+ SSL_shutdown(_ssl);
+ }
+ ::shutdown(_peer, SHUT_RDWR);
+ }
+ _shut = true;
+}
+
+void RedPeer::disconnect()
+{
+ Lock lock(_lock);
+ shutdown();
+}
+
+void RedPeer::close()
+{
+ Lock lock(_lock);
+ if (_peer != INVALID_SOCKET) {
+ if (_ctx) {
+ SSL_free(_ssl);
+ _ssl = NULL;
+ SSL_CTX_free(_ctx);
+ _ctx = NULL;
+ }
+
+ closesocket(_peer);
+ _peer = INVALID_SOCKET;
+ }
+}
+
+void RedPeer::swap(RedPeer* other)
+{
+ Lock lock(_lock);
+ SOCKET temp_peer = _peer;
+ SSL_CTX *temp_ctx = _ctx;
+ SSL *temp_ssl = _ssl;
+
+ _peer = other->_peer;
+ other->_peer = temp_peer;
+
+ if (_ctx) {
+ _ctx = other->_ctx;
+ _ssl = other->_ssl;
+
+ other->_ctx = temp_ctx;
+ other->_ssl = temp_ssl;
+ }
+
+ if (_shut) {
+ shutdown();
+ }
+}
+
+uint32_t RedPeer::recive(uint8_t *buf, uint32_t size)
+{
+ uint8_t *pos = buf;
+ while (size) {
+ int now;
+ if (_ctx == NULL) {
+ if ((now = recv(_peer, (char *)pos, size, 0)) <= 0) {
+ int err = sock_error();
+ if (now == SOCKET_ERROR && err == WOULDBLOCK_ERR) {
+ break;
+ }
+
+ if (now == 0 || err == SHUTDOWN_ERR) {
+ throw RedPeer::DisconnectedException();
+ }
+
+ if (err == INTERRUPTED_ERR) {
+ continue;
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "%s (%d)", sock_err_message(err), err);
+ }
+ size -= now;
+ pos += now;
+ } else {
+ if ((now = SSL_read(_ssl, pos, size)) <= 0) {
+ int ssl_error = SSL_get_error(_ssl, now);
+
+ if (ssl_error == SSL_ERROR_WANT_READ) {
+ break;
+ }
+
+ if (ssl_error == SSL_ERROR_SYSCALL) {
+ int err = sock_error();
+ if (now == -1) {
+ if (err == WOULDBLOCK_ERR) {
+ break;
+ }
+ if (err == INTERRUPTED_ERR) {
+ continue;
+ }
+ }
+ if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) {
+ throw RedPeer::DisconnectedException();
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err);
+ } else if (ssl_error == SSL_ERROR_ZERO_RETURN) {
+ throw RedPeer::DisconnectedException();
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_RECV_FAILED, "ssl error %d", ssl_error);
+ }
+ size -= now;
+ pos += now;
+ }
+ }
+ return pos - buf;
+}
+
+RedPeer::CompundInMessage* RedPeer::recive()
+{
+ RedDataHeader header;
+ std::auto_ptr<CompundInMessage> message;
+
+ recive((uint8_t*)&header, sizeof(RedDataHeader));
+ message.reset(new CompundInMessage(header.serial, header.type, header.size, header.sub_list));
+ recive(message->data(), message->compund_size());
+ return message.release();
+}
+
+uint32_t RedPeer::send(uint8_t *buf, uint32_t size)
+{
+ uint8_t *pos = buf;
+ while (size) {
+ int now;
+
+ if (_ctx == NULL) {
+ if ((now = ::send(_peer, (char *)pos, size, 0)) == SOCKET_ERROR) {
+ int err = sock_error();
+ if (err == WOULDBLOCK_ERR) {
+ break;
+ }
+ if (err == SHUTDOWN_ERR) {
+ throw RedPeer::DisconnectedException();
+ }
+ if (err == INTERRUPTED_ERR) {
+ continue;
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err);
+ }
+ size -= now;
+ pos += now;
+ } else {
+ if ((now = SSL_write(_ssl, pos, size)) <= 0) {
+ int ssl_error = SSL_get_error(_ssl, now);
+
+ if (ssl_error == SSL_ERROR_WANT_WRITE) {
+ break;
+ }
+ if (ssl_error == SSL_ERROR_SYSCALL) {
+ int err = sock_error();
+ if (now == -1) {
+ if (err == WOULDBLOCK_ERR) {
+ break;
+ }
+ if (err == INTERRUPTED_ERR) {
+ continue;
+ }
+ }
+ if (now == 0 || (now == -1 && err == SHUTDOWN_ERR)) {
+ throw RedPeer::DisconnectedException();
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "%s (%d)", sock_err_message(err), err);
+ } else if (ssl_error == SSL_ERROR_ZERO_RETURN) {
+ throw RedPeer::DisconnectedException();
+ }
+ THROW_ERR(SPICEC_ERROR_CODE_SEND_FAILED, "ssl error %d", ssl_error);
+ }
+ size -= now;
+ pos += now;
+ }
+ }
+ return pos - buf;
+}
+
+uint32_t RedPeer::send(RedPeer::OutMessage& message)
+{
+ message.header().serial = ++_serial;
+ return send(message.base(), message.message_size());
+}
+
+RedPeer::OutMessage::OutMessage(uint32_t type, uint32_t size)
+ : _data (new uint8_t[size + sizeof(RedDataHeader)])
+ , _size (size)
+{
+ header().type = type;
+ header().size = size;
+}
+
+RedPeer::OutMessage::~OutMessage()
+{
+ delete[] _data;
+}
+
+void RedPeer::OutMessage::resize(uint32_t size)
+{
+ if (size <= _size) {
+ header().size = size;
+ return;
+ }
+ uint32_t type = header().type;
+ delete[] _data;
+ _data = NULL;
+ _size = 0;
+ _data = new uint8_t[size + sizeof(RedDataHeader)];
+ _size = size;
+ header().type = type;
+ header().size = size;
+}
+