summaryrefslogtreecommitdiffstats
path: root/client/red_channel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'client/red_channel.cpp')
-rw-r--r--client/red_channel.cpp714
1 files changed, 714 insertions, 0 deletions
diff --git a/client/red_channel.cpp b/client/red_channel.cpp
new file mode 100644
index 00000000..4c6f1f8f
--- /dev/null
+++ b/client/red_channel.cpp
@@ -0,0 +1,714 @@
+/*
+ Copyright (C) 2009 Red Hat, Inc.
+
+ This program is free software; you can redistribute it and/or
+ modify it under the terms of the GNU General Public License as
+ published by the Free Software Foundation; either version 2 of
+ the License, or (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include "common.h"
+#include "red_channel.h"
+#include "red_client.h"
+#include "application.h"
+#include "debug.h"
+#include "utils.h"
+
+#include "openssl/rsa.h"
+#include "openssl/evp.h"
+#include "openssl/x509.h"
+
+
+RedChannelBase::RedChannelBase(uint8_t type, uint8_t id, const ChannelCaps& common_caps,
+ const ChannelCaps& caps)
+ : RedPeer()
+ , _type (type)
+ , _id (id)
+ , _common_caps (common_caps)
+ , _caps (caps)
+{
+}
+
+RedChannelBase::~RedChannelBase()
+{
+}
+
+void RedChannelBase::link(uint32_t connection_id, const std::string& password)
+{
+ RedLinkHeader header;
+ RedLinkMess link_mess;
+ RedLinkReply* reply;
+ uint32_t link_res;
+ uint32_t i;
+
+ EVP_PKEY *pubkey;
+ int nRSASize;
+ BIO *bioKey;
+ RSA *rsa;
+
+ header.magic = RED_MAGIC;
+ header.size = sizeof(link_mess);
+ header.major_version = RED_VERSION_MAJOR;
+ header.minor_version = RED_VERSION_MINOR;
+ link_mess.connection_id = connection_id;
+ link_mess.channel_type = _type;
+ link_mess.channel_id = _id;
+ link_mess.num_common_caps = get_common_caps().size();
+ link_mess.num_channel_caps = get_caps().size();
+ link_mess.caps_offset = sizeof(link_mess);
+ header.size += (link_mess.num_common_caps + link_mess.num_channel_caps) * sizeof(uint32_t);
+ send((uint8_t*)&header, sizeof(header));
+ send((uint8_t*)&link_mess, sizeof(link_mess));
+
+ for (i = 0; i < _common_caps.size(); i++) {
+ send((uint8_t*)&_common_caps[i], sizeof(uint32_t));
+ }
+
+ for (i = 0; i < _caps.size(); i++) {
+ send((uint8_t*)&_caps[i], sizeof(uint32_t));
+ }
+
+ recive((uint8_t*)&header, sizeof(header));
+
+ if (header.magic != RED_MAGIC) {
+ THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "bad magic");
+ }
+
+ if (header.major_version != RED_VERSION_MAJOR) {
+ THROW_ERR(SPICEC_ERROR_CODE_VERSION_MISMATCH,
+ "version mismatch: expect %u got %u",
+ RED_VERSION_MAJOR,
+ header.major_version);
+ }
+
+ AutoArray<uint8_t> reply_buf(new uint8_t[header.size]);
+ recive(reply_buf.get(), header.size);
+
+ reply = (RedLinkReply *)reply_buf.get();
+
+ if (reply->error != RED_ERR_OK) {
+ THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "connect error %u", reply->error);
+ }
+
+ uint32_t num_caps = reply->num_channel_caps + reply->num_common_caps;
+ if ((uint8_t *)(reply + 1) > reply_buf.get() + header.size ||
+ (uint8_t *)reply + reply->caps_offset + num_caps * sizeof(uint32_t) >
+ reply_buf.get() + header.size) {
+ THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "access violation");
+ }
+
+ uint32_t *caps = (uint32_t *)((uint8_t *)reply + reply->caps_offset);
+
+ _remote_common_caps.clear();
+ for (i = 0; i < reply->num_common_caps; i++, caps++) {
+ _remote_common_caps.resize(i + 1);
+ _remote_common_caps[i] = *caps;
+ }
+
+ _remote_caps.clear();
+ for (i = 0; i < reply->num_channel_caps; i++, caps++) {
+ _remote_caps.resize(i + 1);
+ _remote_caps[i] = *caps;
+ }
+
+ bioKey = BIO_new(BIO_s_mem());
+ if (bioKey != NULL) {
+ BIO_write(bioKey, reply->pub_key, RED_TICKET_PUBKEY_BYTES);
+ pubkey = d2i_PUBKEY_bio(bioKey, NULL);
+ rsa = pubkey->pkey.rsa;
+ nRSASize = RSA_size(rsa);
+ AutoArray<unsigned char> bufEncrypted(new unsigned char[nRSASize]);
+
+ /*
+ The use of RSA encryption limit the potential maximum password length.
+ for RSA_PKCS1_OAEP_PADDING it is RSA_size(rsa) - 41.
+ */
+ if (RSA_public_encrypt(password.length() + 1, (unsigned char *)password.c_str(),
+ (uint8_t *)bufEncrypted.get(),
+ rsa, RSA_PKCS1_OAEP_PADDING) > 0 ) {
+ send((uint8_t*)bufEncrypted.get(), nRSASize);
+ } else {
+ THROW("could not encrypt password");
+ }
+
+ memset(bufEncrypted.get(), 0, nRSASize);
+ } else {
+ THROW("Could not initiate BIO");
+ }
+
+ BIO_free(bioKey);
+
+ recive((uint8_t*)&link_res, sizeof(link_res));
+ if (link_res != RED_ERR_OK) {
+ int error_code = (link_res == RED_ERR_PERMISSION_DENIED) ?
+ SPICEC_ERROR_CODE_CONNECT_FAILED : SPICEC_ERROR_CODE_CONNECT_FAILED;
+ THROW_ERR(error_code, "connect failed %u", link_res);
+ }
+}
+
+void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id,
+ uint32_t ip, std::string password)
+{
+ if (options.allow_unsecure()) {
+ try {
+ RedPeer::connect_unsecure(ip, options.unsecure_port);
+ link(connection_id, password);
+ return;
+ } catch (...) {
+ if (!options.allow_secure()) {
+ throw;
+ }
+ RedPeer::close();
+ }
+ }
+ ASSERT(options.allow_secure());
+ RedPeer::connect_secure(options, ip);
+ link(connection_id, password);
+}
+
+void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id,
+ const char* host, std::string password)
+{
+ connect(options, connection_id, host_by_name(host), password);
+}
+
+void RedChannelBase::set_capability(ChannelCaps& caps, uint32_t cap)
+{
+ uint32_t word_index = cap / 32;
+
+ if (caps.size() < word_index + 1) {
+ caps.resize(word_index + 1);
+ }
+ caps[word_index] |= 1 << (cap % 32);
+}
+
+void RedChannelBase::set_common_capability(uint32_t cap)
+{
+ set_capability(_common_caps, cap);
+}
+
+void RedChannelBase::set_capability(uint32_t cap)
+{
+ set_capability(_caps, cap);
+}
+
+bool RedChannelBase::test_capability(const ChannelCaps& caps, uint32_t cap)
+{
+ uint32_t word_index = cap / 32;
+
+ if (caps.size() < word_index + 1) {
+ return false;
+ }
+
+ return (caps[word_index] & (1 << (cap % 32))) != 0;
+}
+
+bool RedChannelBase::test_common_capability(uint32_t cap)
+{
+ return test_capability(_remote_common_caps, cap);
+}
+
+bool RedChannelBase::test_capability(uint32_t cap)
+{
+ return test_capability(_remote_caps, cap);
+}
+
+SendTrigger::SendTrigger(RedChannel& channel)
+ : _channel (channel)
+{
+}
+
+void SendTrigger::on_event()
+{
+ _channel.on_send_trigger();
+}
+
+void AbortTrigger::on_event()
+{
+ THROW("abort");
+}
+
+RedChannel::RedChannel(RedClient& client, uint8_t type, uint8_t id,
+ RedChannel::MessageHandler* handler,
+ Platform::ThreadPriority worker_priority)
+ : RedChannelBase(type, id, ChannelCaps(), ChannelCaps())
+ , _client (client)
+ , _state (PASSIVE_STATE)
+ , _action (WAIT_ACTION)
+ , _error (SPICEC_ERROR_CODE_SUCCESS)
+ , _wait_for_threads (true)
+ , _socket_in_loop (false)
+ , _worker (NULL)
+ , _worker_priority (worker_priority)
+ , _message_handler (handler)
+ , _outgoing_message (NULL)
+ , _incomming_header_pos (0)
+ , _incomming_message (NULL)
+ , _message_ack_count (0)
+ , _message_ack_window (0)
+ , _send_trigger (*this)
+ , _disconnect_stamp (0)
+ , _disconnect_reason (RED_ERR_OK)
+{
+ _loop.add_trigger(_send_trigger);
+ _loop.add_trigger(_abort_trigger);
+}
+
+RedChannel::~RedChannel()
+{
+ ASSERT(_state == TERMINATED_STATE || _state == PASSIVE_STATE);
+ delete _worker;
+}
+
+void* RedChannel::worker_main(void *data)
+{
+ try {
+ RedChannel* channel = static_cast<RedChannel*>(data);
+ channel->set_state(DISCONNECTED_STATE);
+ Platform::set_thread_priority(NULL, channel->get_worker_priority());
+ channel->run();
+ } catch (Exception& e) {
+ LOG_ERROR("unhandle exception: %s", e.what());
+ } catch (std::exception& e) {
+ LOG_ERROR("unhandle exception: %s", e.what());
+ } catch (...) {
+ LOG_ERROR("unhandled exception");
+ }
+ return NULL;
+}
+
+void RedChannel::post_message(RedChannel::OutMessage* message)
+{
+ Lock lock(_outgoing_lock);
+ _outgoing_messages.push_back(message);
+ lock.unlock();
+ _send_trigger.trigger();
+}
+
+RedPeer::CompundInMessage *RedChannel::recive()
+{
+ CompundInMessage *message = RedChannelBase::recive();
+ on_message_recived();
+ return message;
+}
+
+RedChannel::OutMessage* RedChannel::get_outgoing_message()
+{
+ if (_state != CONNECTED_STATE || _outgoing_messages.empty()) {
+ return NULL;
+ }
+ RedChannel::OutMessage* message = _outgoing_messages.front();
+ _outgoing_messages.pop_front();
+ return message;
+}
+
+class AutoMessage {
+public:
+ AutoMessage(RedChannel::OutMessage* message) : _message (message) {}
+ ~AutoMessage() {if (_message) _message->release();}
+ void set(RedChannel::OutMessage* message) { _message = message;}
+ RedChannel::OutMessage* get() { return _message;}
+ RedChannel::OutMessage* release();
+
+private:
+ RedChannel::OutMessage* _message;
+};
+
+RedChannel::OutMessage* AutoMessage::release()
+{
+ RedChannel::OutMessage* ret = _message;
+ _message = NULL;
+ return ret;
+}
+
+void RedChannel::start()
+{
+ ASSERT(!_worker);
+ _worker = new Thread(RedChannel::worker_main, this);
+ Lock lock(_state_lock);
+ while (_state == PASSIVE_STATE) {
+ _state_cond.wait(lock);
+ }
+}
+
+void RedChannel::set_state(int state)
+{
+ Lock lock(_state_lock);
+ _state = state;
+ _state_cond.notify_all();
+}
+
+void RedChannel::connect()
+{
+ Lock lock(_action_lock);
+
+ if (_state != DISCONNECTED_STATE && _state != PASSIVE_STATE) {
+ return;
+ }
+ _action = CONNECT_ACTION;
+ _action_cond.notify_one();
+}
+
+void RedChannel::disconnect()
+{
+ clear_outgoing_messages();
+
+ Lock lock(_action_lock);
+ if (_state != CONNECTING_STATE && _state != CONNECTED_STATE) {
+ return;
+ }
+ _action = DISCONNECT_ACTION;
+ RedPeer::disconnect();
+ _action_cond.notify_one();
+}
+
+void RedChannel::clear_outgoing_messages()
+{
+ Lock lock(_outgoing_lock);
+ while (!_outgoing_messages.empty()) {
+ RedChannel::OutMessage* message = _outgoing_messages.front();
+ _outgoing_messages.pop_front();
+ message->release();
+ }
+}
+
+void RedChannel::run()
+{
+ for (;;) {
+ Lock lock(_action_lock);
+ if (_action == WAIT_ACTION) {
+ _action_cond.wait(lock);
+ }
+ int action = _action;
+ _action = WAIT_ACTION;
+ lock.unlock();
+ switch (action) {
+ case CONNECT_ACTION:
+ try {
+ get_client().get_sync_info(get_type(), get_id(), _sync_info);
+ on_connecting();
+ set_state(CONNECTING_STATE);
+ ConnectionOptions con_options(_client.get_connection_options(get_type()),
+ _client.get_port(),
+ _client.get_sport());
+ RedChannelBase::connect(con_options, _client.get_connection_id(),
+ _client.get_host(), _client.get_password());
+ on_connect();
+ set_state(CONNECTED_STATE);
+ _loop.add_socket(*this);
+ _socket_in_loop = true;
+ on_event();
+ _loop.run();
+ } catch (RedPeer::DisconnectedException&) {
+ _error = SPICEC_ERROR_CODE_SUCCESS;
+ } catch (Exception& e) {
+ LOG_WARN("%s", e.what());
+ _error = e.get_error_code();
+ } catch (std::exception& e) {
+ LOG_WARN("%s", e.what());
+ _error = SPICEC_ERROR_CODE_ERROR;
+ }
+ if (_socket_in_loop) {
+ _socket_in_loop = false;
+ _loop.remove_socket(*this);
+ }
+ if (_outgoing_message) {
+ _outgoing_message->release();
+ _outgoing_message = NULL;
+ }
+ _incomming_header_pos = 0;
+ delete _incomming_message;
+ _incomming_message = NULL;
+ case DISCONNECT_ACTION:
+ close();
+ on_disconnect();
+ set_state(DISCONNECTED_STATE);
+ _client.on_channel_disconnected(*this);
+ continue;
+ case QUIT_ACTION:
+ set_state(TERMINATED_STATE);
+ return;
+ }
+ }
+}
+
+bool RedChannel::abort()
+{
+ clear_outgoing_messages();
+ Lock lock(_action_lock);
+ if (_state == TERMINATED_STATE) {
+ if (_wait_for_threads) {
+ _wait_for_threads = false;
+ _worker->join();
+ }
+ return true;
+ }
+
+ _action = QUIT_ACTION;
+ _action_cond.notify_one();
+ lock.unlock();
+ RedPeer::disconnect();
+ _abort_trigger.trigger();
+
+ for (;;) {
+ Lock state_lock(_state_lock);
+ if (_state == TERMINATED_STATE) {
+ break;
+ }
+ uint64_t timout = 1000 * 1000 * 100; // 100ms
+ if (!_state_cond.timed_wait(state_lock, timout)) {
+ return false;
+ }
+ }
+ if (_wait_for_threads) {
+ _wait_for_threads = false;
+ _worker->join();
+ }
+ return true;
+}
+
+void RedChannel::send_messages()
+{
+ if (_outgoing_message) {
+ return;
+ }
+
+ for (;;) {
+ Lock lock(_outgoing_lock);
+ AutoMessage message(get_outgoing_message());
+ if (!message.get()) {
+ return;
+ }
+ RedPeer::OutMessage& peer_message = message.get()->peer_message();
+ uint32_t n = send(peer_message);
+ if (n != peer_message.message_size()) {
+ _outgoing_message = message.release();
+ _outgoing_pos = n;
+ return;
+ }
+ }
+}
+
+void RedChannel::on_send_trigger()
+{
+ send_messages();
+}
+
+void RedChannel::on_message_recived()
+{
+ if (_message_ack_count && !--_message_ack_count) {
+ post_message(new Message(REDC_ACK, 0));
+ _message_ack_count = _message_ack_window;
+ }
+}
+
+void RedChannel::on_message_complition(uint64_t serial)
+{
+ Lock lock(*_sync_info.lock);
+ *_sync_info.message_serial = serial;
+ _sync_info.condition->notify_all();
+}
+
+void RedChannel::recive_messages()
+{
+ for (;;) {
+ uint32_t n = RedPeer::recive((uint8_t*)&_incomming_header, sizeof(RedDataHeader));
+ if (n != sizeof(RedDataHeader)) {
+ _incomming_header_pos = n;
+ return;
+ }
+ std::auto_ptr<CompundInMessage> message(new CompundInMessage(_incomming_header.serial,
+ _incomming_header.type,
+ _incomming_header.size,
+ _incomming_header.sub_list));
+ n = RedPeer::recive(message->data(), message->compund_size());
+ if (n != message->compund_size()) {
+ _incomming_message = message.release();
+ _incomming_message_pos = n;
+ return;
+ }
+ on_message_recived();
+ _message_handler->handle_message(*message.get());
+ on_message_complition(message->serial());
+ }
+}
+
+void RedChannel::on_event()
+{
+ if (_outgoing_message) {
+ RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();
+ _outgoing_pos += send(peer_message.base() + _outgoing_pos,
+ peer_message.message_size() - _outgoing_pos);
+ if (_outgoing_pos == peer_message.message_size()) {
+ _outgoing_message->release();
+ _outgoing_message = NULL;
+ }
+ }
+ send_messages();
+
+ if (_incomming_header_pos) {
+ _incomming_header_pos += RedPeer::recive(((uint8_t*)&_incomming_header) +
+ _incomming_header_pos,
+ sizeof(RedDataHeader) - _incomming_header_pos);
+ if (_incomming_header_pos != sizeof(RedDataHeader)) {
+ return;
+ }
+ _incomming_header_pos = 0;
+ _incomming_message = new CompundInMessage(_incomming_header.serial,
+ _incomming_header.type,
+ _incomming_header.size,
+ _incomming_header.sub_list);
+ _incomming_message_pos = 0;
+ }
+
+ if (_incomming_message) {
+ _incomming_message_pos += RedPeer::recive(_incomming_message->data() +
+ _incomming_message_pos,
+ _incomming_message->compund_size() -
+ _incomming_message_pos);
+ if (_incomming_message_pos != _incomming_message->compund_size()) {
+ return;
+ }
+ std::auto_ptr<CompundInMessage> message(_incomming_message);
+ _incomming_message = NULL;
+ on_message_recived();
+ _message_handler->handle_message(*message.get());
+ on_message_complition(message->serial());
+ }
+ recive_messages();
+}
+
+void RedChannel::send_migrate_flush_mark()
+{
+ if (_outgoing_message) {
+ RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();
+ send(peer_message.base() + _outgoing_pos, peer_message.message_size() - _outgoing_pos);
+ _outgoing_message->release();
+ _outgoing_message = NULL;
+ }
+ Lock lock(_outgoing_lock);
+ for (;;) {
+ AutoMessage message(get_outgoing_message());
+ if (!message.get()) {
+ break;
+ }
+ send(message.get()->peer_message());
+ }
+ lock.unlock();
+ std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_FLUSH_MARK, 0));
+ send(*message);
+}
+
+void RedChannel::handle_migrate(RedPeer::InMessage* message)
+{
+ DBG(0, "channel type %u id %u", get_type(), get_id());
+ _socket_in_loop = false;
+ _loop.remove_socket(*this);
+ RedMigrate* migrate = (RedMigrate*)message->data();
+ if (migrate->flags & RED_MIGRATE_NEED_FLUSH) {
+ send_migrate_flush_mark();
+ }
+ std::auto_ptr<RedPeer::CompundInMessage> data_message;
+ if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
+ data_message.reset(recive());
+ }
+ _client.migrate_channel(*this);
+ if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
+ if (data_message->type() != RED_MIGRATE_DATA) {
+ THROW("expect RED_MIGRATE_DATA");
+ }
+ std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_DATA,
+ data_message->size()));
+ memcpy(message->data(), data_message->data(), data_message->size());
+ send(*message);
+ }
+ _loop.add_socket(*this);
+ _socket_in_loop = true;
+ on_migrate();
+ set_state(CONNECTED_STATE);
+ on_event();
+}
+
+void RedChannel::handle_set_ack(RedPeer::InMessage* message)
+{
+ RedSetAck* ack = (RedSetAck*)message->data();
+ _message_ack_window = _message_ack_count = ack->window;
+ Message *responce = new Message(REDC_ACK_SYNC, sizeof(uint32_t));
+ *(uint32_t *)responce->data() = ack->generation;
+ post_message(responce);
+}
+
+void RedChannel::handle_ping(RedPeer::InMessage* message)
+{
+ RedPing *ping = (RedPing *)message->data();
+ Message *pong = new Message(REDC_PONG, sizeof(RedPing));
+ *(RedPing *)pong->data() = *ping;
+ post_message(pong);
+}
+
+void RedChannel::handle_disconnect(RedPeer::InMessage* message)
+{
+ RedDisconnect *disconnect = (RedDisconnect *)message->data();
+ _disconnect_stamp = disconnect->time_stamp;
+ _disconnect_reason = disconnect->reason;
+}
+
+void RedChannel::handle_notify(RedPeer::InMessage* message)
+{
+ RedNotify *notify = (RedNotify *)message->data();
+ const char *sevirity;
+ const char *visibility;
+ const char *message_str = "";
+ const char *message_prefix = "";
+
+ static const char* sevirity_strings[] = {"info", "warn", "error"};
+ static const char* visibility_strings[] = {"!", "!!", "!!!"};
+
+
+ if (notify->severty > RED_NOTIFY_SEVERITY_ERROR) {
+ THROW("bad sevirity");
+ }
+ sevirity = sevirity_strings[notify->severty];
+
+ if (notify->visibilty > RED_NOTIFY_VISIBILITY_HIGH) {
+ THROW("bad visibilty");
+ }
+ visibility = visibility_strings[notify->visibilty];
+
+
+ if (notify->message_len) {
+ if ((message->size() - sizeof(*notify) < notify->message_len + 1)) {
+ THROW("access violation");
+ }
+ message_str = (char *)(notify + 1);
+ if (message_str[notify->message_len] != 0) {
+ THROW("invalid message");
+ }
+ message_prefix = ": ";
+ }
+
+
+ LOG_INFO("remote channel %u:%u %s%s #%u%s%s",
+ get_type(), get_id(),
+ sevirity, visibility,
+ notify->what,
+ message_prefix, message_str);
+}
+
+void RedChannel::handle_wait_for_channels(RedPeer::InMessage* message)
+{
+ RedWaitForChannels *wait = (RedWaitForChannels *)message->data();
+ if (message->size() < sizeof(*wait) + wait->wait_count * sizeof(wait->wait_list[0])) {
+ THROW("access violation");
+ }
+ _client.wait_for_channels(wait->wait_count, wait->wait_list);
+}
+