/*
   Copyright (C) 2009 Red Hat, Inc.

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License as
   published by the Free Software Foundation; either version 2 of
   the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "common.h"
#include "red_channel.h"
#include "red_client.h"
#include "application.h"
#include "debug.h"
#include "utils.h"
#include "openssl/rsa.h"
#include "openssl/evp.h"
#include "openssl/x509.h"

RedChannelBase::RedChannelBase(uint8_t type, uint8_t id, const ChannelCaps& common_caps,
                               const ChannelCaps& caps)
    : RedPeer()
    , _type (type)
    , _id (id)
    , _common_caps (common_caps)
    , _caps (caps)
{
}

RedChannelBase::~RedChannelBase()
{
}

// Client side of the link handshake on an already connected socket:
//   1. send the link header, the link message and the local capability words;
//   2. receive and validate the server's link reply and record the server's
//      common/channel capabilities and minor version;
//   3. send the NUL-terminated password encrypted (RSA OAEP) with the public
//      key carried in the reply;
//   4. receive the final link result.
// Every protocol, validation or crypto failure is reported with THROW/THROW_ERR.
void RedChannelBase::link(uint32_t connection_id, const std::string& password)
{
    RedLinkHeader header;
    RedLinkMess link_mess;
    uint32_t i;

    header.magic = RED_MAGIC;
    header.size = sizeof(link_mess);
    header.major_version = RED_VERSION_MAJOR;
    header.minor_version = RED_VERSION_MINOR;
    link_mess.connection_id = connection_id;
    link_mess.channel_type = _type;
    link_mess.channel_id = _id;
    link_mess.num_common_caps = get_common_caps().size();
    link_mess.num_channel_caps = get_caps().size();
    link_mess.caps_offset = sizeof(link_mess);
    // The capability words follow the link message, so they count towards
    // the size announced in the header.
    header.size += (link_mess.num_common_caps + link_mess.num_channel_caps) * sizeof(uint32_t);
    send((uint8_t*)&header, sizeof(header));
    send((uint8_t*)&link_mess, sizeof(link_mess));

    for (i = 0; i < _common_caps.size(); i++) {
        send((uint8_t*)&_common_caps[i], sizeof(uint32_t));
    }

    for (i = 0; i < _caps.size(); i++) {
        send((uint8_t*)&_caps[i], sizeof(uint32_t));
    }

    recive((uint8_t*)&header, sizeof(header));

    if (header.magic != RED_MAGIC) {
        THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "bad magic");
    }

    if (header.major_version != RED_VERSION_MAJOR) {
        THROW_ERR(SPICEC_ERROR_CODE_VERSION_MISMATCH,
                  "version mismatch: expect %u got %u",
                  RED_VERSION_MAJOR,
                  header.major_version);
    }

    _remote_minor = header.minor_version;

    // header.size is supplied by the peer: the fixed part of the reply must be
    // present before any of its fields (including "error") are read.
    if (header.size < sizeof(RedLinkReply)) {
        THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "access violation");
    }

    AutoArray<uint8_t> reply_buf(new uint8_t[header.size]);
    recive(reply_buf.get(), header.size);

    RedLinkReply* reply = (RedLinkReply *)reply_buf.get();

    if (reply->error != RED_ERR_OK) {
        THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "connect error %u", reply->error);
    }

    // Bounds-check the capability array using 64-bit offsets so that large
    // peer-supplied counts or offsets cannot wrap around.
    uint64_t num_caps = (uint64_t)reply->num_channel_caps + reply->num_common_caps;
    if ((uint64_t)reply->caps_offset + num_caps * sizeof(uint32_t) > header.size) {
        THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "access violation");
    }

    // caps_offset is arbitrary, so the words may be unaligned: copy them out
    // with memcpy instead of dereferencing a uint32_t pointer.
    const uint8_t *caps = reply_buf.get() + reply->caps_offset;

    _remote_common_caps.clear();
    _remote_common_caps.resize(reply->num_common_caps);
    for (i = 0; i < reply->num_common_caps; i++, caps += sizeof(uint32_t)) {
        uint32_t cap;
        memcpy(&cap, caps, sizeof(cap));
        _remote_common_caps[i] = cap;
    }

    _remote_caps.clear();
    _remote_caps.resize(reply->num_channel_caps);
    for (i = 0; i < reply->num_channel_caps; i++, caps += sizeof(uint32_t)) {
        uint32_t cap;
        memcpy(&cap, caps, sizeof(cap));
        _remote_caps[i] = cap;
    }

    // Decode the DER public key from the reply. Each OpenSSL object is freed
    // as soon as it is no longer needed and before any THROW, so no failure
    // path leaks it.
    BIO *bioKey = BIO_new(BIO_s_mem());
    if (bioKey == NULL) {
        THROW("Could not initiate BIO");
    }
    BIO_write(bioKey, reply->pub_key, RED_TICKET_PUBKEY_BYTES);
    EVP_PKEY *pubkey = d2i_PUBKEY_bio(bioKey, NULL);
    BIO_free(bioKey);
    if (pubkey == NULL) {
        THROW("could not read public key");
    }

    // EVP_PKEY_get1_RSA takes its own reference (released with RSA_free) and
    // returns NULL when the key is not an RSA key.
    RSA *rsa = EVP_PKEY_get1_RSA(pubkey);
    EVP_PKEY_free(pubkey);
    if (rsa == NULL) {
        THROW("public key is not an RSA key");
    }

    int nRSASize = RSA_size(rsa);
    AutoArray<unsigned char> bufEncrypted(new unsigned char[nRSASize]);

    /* The use of RSA encryption limits the potential maximum password length.
       For RSA_PKCS1_OAEP_PADDING it is RSA_size(rsa) - 41.
    */
    int encrypted_len = RSA_public_encrypt((int)(password.length() + 1),
                                           (unsigned char *)password.c_str(),
                                           (uint8_t *)bufEncrypted.get(), rsa,
                                           RSA_PKCS1_OAEP_PADDING);
    RSA_free(rsa);
    if (encrypted_len <= 0) {
        THROW("could not encrypt password");
    }
    send((uint8_t*)bufEncrypted.get(), nRSASize);
    memset(bufEncrypted.get(), 0, nRSASize);

    uint32_t link_res;
    recive((uint8_t*)&link_res, sizeof(link_res));
    if (link_res != RED_ERR_OK) {
        // RED_ERR_PERMISSION_DENIED has no dedicated client error code here:
        // every link failure is reported as a connect failure.
        THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "connect failed %u", link_res);
    }
}

// Opens the channel socket and links it. When unsecure connections are
// allowed they are tried first; if that attempt fails and secure connections
// are allowed too, the socket is closed and a secure connection is used
// instead. Otherwise the failure propagates to the caller.
void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id,
                             const char* host, std::string password)
{
    if (options.allow_unsecure()) {
        try {
            RedPeer::connect_unsecure(host, options.unsecure_port);
            link(connection_id, password);
            return;
        } catch (...) {
            if (!options.allow_secure()) {
                throw;
            }
            RedPeer::close();
        }
    }
    ASSERT(options.allow_secure());
    RedPeer::connect_secure(options, host);
    link(connection_id, password);
}

// Sets bit `cap` in the capability bitmap `caps` (32 capabilities per word),
// growing the bitmap when needed.
void RedChannelBase::set_capability(ChannelCaps& caps, uint32_t cap)
{
    uint32_t word_index = cap / 32;

    if (caps.size() < word_index + 1) {
        caps.resize(word_index + 1);
    }
    // Unsigned shift: 1 << 31 on a signed int is undefined behaviour.
    caps[word_index] |= 1u << (cap % 32);
}

void RedChannelBase::set_common_capability(uint32_t cap)
{
    set_capability(_common_caps, cap);
}

void RedChannelBase::set_capability(uint32_t cap)
{
    set_capability(_caps, cap);
}

// Returns true when bit `cap` is set in the capability bitmap `caps`;
// capabilities beyond the end of the bitmap are reported as unset.
bool RedChannelBase::test_capability(const ChannelCaps& caps, uint32_t cap)
{
    uint32_t word_index = cap / 32;
    if (caps.size() < word_index + 1) {
        return false;
    }

    return (caps[word_index] & (1u << (cap % 32))) != 0;
}

// Tests a common capability announced by the server in the link reply.
bool RedChannelBase::test_common_capability(uint32_t cap)
{
    return test_capability(_remote_common_caps, cap);
}

// Tests a channel capability announced by the server in the link reply.
bool RedChannelBase::test_capability(uint32_t cap)
{
    return test_capability(_remote_caps, cap);
}

SendTrigger::SendTrigger(RedChannel& channel)
    : _channel (channel)
{
}

// Fired from RedChannel::post_message(): flushes pending outgoing messages.
void SendTrigger::on_event()
{
    _channel.on_send_trigger();
}

// Fired from RedChannel::abort(): throws to unwind the worker's event loop.
void AbortTrigger::on_event()
{
    THROW("abort");
}

RedChannel::RedChannel(RedClient& client, uint8_t type, uint8_t id,
                       RedChannel::MessageHandler* handler,
                       Platform::ThreadPriority worker_priority)
    : RedChannelBase(type, id, ChannelCaps(), ChannelCaps())
    , _client (client)
    , _state (PASSIVE_STATE)
    , _action (WAIT_ACTION)
    , _error (SPICEC_ERROR_CODE_SUCCESS)
    , _wait_for_threads (true)
    , _socket_in_loop (false)
    , _worker (NULL)
    , _worker_priority (worker_priority)
    , _message_handler (handler)
    , _outgoing_message (NULL)
    , _incomming_header_pos (0)
    , _incomming_message (NULL)
    , _message_ack_count (0)
    , _message_ack_window (0)
    , _loop (this)
    , _send_trigger (*this)
    , _disconnect_stamp (0)
    , _disconnect_reason (RED_ERR_OK)
{
    _loop.add_trigger(_send_trigger);
    _loop.add_trigger(_abort_trigger);
}

RedChannel::~RedChannel()
{
    ASSERT(_state == TERMINATED_STATE || _state == PASSIVE_STATE);
    delete _worker;
}

// Worker thread entry point. Moves the channel out of PASSIVE_STATE (which
// releases RedChannel::start()) and runs the action loop; exceptions are
// logged here so that none escapes the thread.
void* RedChannel::worker_main(void *data)
{
    try {
        RedChannel* channel = static_cast<RedChannel*>(data);
        channel->set_state(DISCONNECTED_STATE);
        Platform::set_thread_priority(NULL, channel->get_worker_priority());
        channel->run();
    } catch (Exception& e) {
        LOG_ERROR("unhandle exception: %s", e.what());
    } catch (std::exception& e) {
        LOG_ERROR("unhandle exception: %s", e.what());
    } catch (...) {
        LOG_ERROR("unhandled exception");
    }
    return NULL;
}

// Queues `message` for sending and wakes the event loop via the send trigger.
void RedChannel::post_message(RedChannel::OutMessage* message)
{
    Lock lock(_outgoing_lock);
    _outgoing_messages.push_back(message);
    lock.unlock();
    _send_trigger.trigger();
}

// Blocking receive of one complete message that also updates ACK accounting.
RedPeer::CompundInMessage *RedChannel::recive()
{
    CompundInMessage *message = RedChannelBase::recive();
    on_message_recived();
    return message;
}

// Pops the next queued outgoing message, or returns NULL when the channel is
// not connected or the queue is empty. Callers hold _outgoing_lock.
RedChannel::OutMessage* RedChannel::get_outgoing_message()
{
    if (_state != CONNECTED_STATE || _outgoing_messages.empty()) {
        return NULL;
    }
    RedChannel::OutMessage* message = _outgoing_messages.front();
    _outgoing_messages.pop_front();
    return message;
}

// Scoped owner of an outgoing message: releases it on destruction unless
// ownership was taken back with release().
class AutoMessage {
public:
    explicit AutoMessage(RedChannel::OutMessage* message) : _message (message) {}
    ~AutoMessage() { if (_message) _message->release(); }
    void set(RedChannel::OutMessage* message) { _message = message; }
    RedChannel::OutMessage* get() { return _message; }
    RedChannel::OutMessage* release();

private:
    // Non-copyable: a copy would release the same message twice.
    AutoMessage(const AutoMessage&);
    AutoMessage& operator=(const AutoMessage&);

private:
    RedChannel::OutMessage* _message;
};

RedChannel::OutMessage* AutoMessage::release()
{
    RedChannel::OutMessage* ret = _message;
    _message = NULL;
    return ret;
}

// Spawns the worker thread and blocks until it has left PASSIVE_STATE.
void RedChannel::start()
{
    ASSERT(!_worker);
    _worker = new Thread(RedChannel::worker_main, this);
    Lock lock(_state_lock);
    while (_state == PASSIVE_STATE) {
        _state_cond.wait(lock);
    }
}

void RedChannel::set_state(int state)
{
    Lock lock(_state_lock);
    _state = state;
    _state_cond.notify_all();
}

// Asks the worker to connect; ignored unless the channel is idle.
void RedChannel::connect()
{
    Lock lock(_action_lock);
    if (_state != DISCONNECTED_STATE && _state != PASSIVE_STATE) {
        return;
    }
    _action = CONNECT_ACTION;
    _action_cond.notify_one();
}

// Drops queued outgoing messages and, when connecting or connected, asks the
// worker to disconnect.
void RedChannel::disconnect()
{
    clear_outgoing_messages();

    Lock lock(_action_lock);
    if (_state != CONNECTING_STATE && _state != CONNECTED_STATE) {
        return;
    }
    _action = DISCONNECT_ACTION;
    RedPeer::disconnect();
    _action_cond.notify_one();
}

void RedChannel::clear_outgoing_messages()
{
    Lock lock(_outgoing_lock);
    while (!_outgoing_messages.empty()) {
        RedChannel::OutMessage* message = _outgoing_messages.front();
        _outgoing_messages.pop_front();
        message->release();
    }
}

// Worker action loop: waits for an action posted by connect(), disconnect()
// or abort() and executes it until QUIT_ACTION is received.
void RedChannel::run()
{
    for (;;) {
        Lock lock(_action_lock);
        // Predicate loop guards against spurious wake-ups.
        while (_action == WAIT_ACTION) {
            _action_cond.wait(lock);
        }
        int action = _action;
        _action = WAIT_ACTION;
        lock.unlock();
        switch (action) {
        case CONNECT_ACTION:
            try {
                get_client().get_sync_info(get_type(), get_id(), _sync_info);
                on_connecting();
                set_state(CONNECTING_STATE);
                ConnectionOptions con_options(_client.get_connection_options(get_type()),
                                              _client.get_port(),
                                              _client.get_sport(),
                                              _client.get_host_auth_options(),
                                              _client.get_connection_ciphers());
                RedChannelBase::connect(con_options, _client.get_connection_id(),
                                        _client.get_host(), _client.get_password());
                on_connect();
                set_state(CONNECTED_STATE);
                _loop.add_socket(*this);
                _socket_in_loop = true;
                on_event();
                _loop.run();
            } catch (RedPeer::DisconnectedException&) {
                _error = SPICEC_ERROR_CODE_SUCCESS;
            } catch (Exception& e) {
                LOG_WARN("%s", e.what());
                _error = e.get_error_code();
            } catch (std::exception& e) {
                LOG_WARN("%s", e.what());
                _error = SPICEC_ERROR_CODE_ERROR;
            }
            if (_socket_in_loop) {
                _socket_in_loop = false;
                _loop.remove_socket(*this);
            }
            if (_outgoing_message) {
                _outgoing_message->release();
                _outgoing_message = NULL;
            }
            _incomming_header_pos = 0;
            delete _incomming_message;
            _incomming_message = NULL;
            // Intentional fall through: once the connection ends (or fails)
            // the same teardown as an explicit disconnect is performed.
        case DISCONNECT_ACTION:
            close();
            on_disconnect();
            set_state(DISCONNECTED_STATE);
            _client.on_channel_disconnected(*this);
            continue;
        case QUIT_ACTION:
            set_state(TERMINATED_STATE);
            return;
        }
    }
}

// Asks the worker to quit and waits for it. Returns false when the worker
// does not reach TERMINATED_STATE within one 100ms wait; returns true once it
// has terminated and the thread has been joined.
bool RedChannel::abort()
{
    clear_outgoing_messages();
    Lock lock(_action_lock);
    if (_state == TERMINATED_STATE) {
        if (_wait_for_threads) {
            _wait_for_threads = false;
            _worker->join();
        }
        return true;
    }
    _action = QUIT_ACTION;
    _action_cond.notify_one();
    lock.unlock();
    RedPeer::disconnect();
    _abort_trigger.trigger();

    for (;;) {
        Lock state_lock(_state_lock);
        if (_state == TERMINATED_STATE) {
            break;
        }
        uint64_t timout = 1000 * 1000 * 100; // 100ms
        if (!_state_cond.timed_wait(state_lock, timout)) {
            return false;
        }
    }

    if (_wait_for_threads) {
        _wait_for_threads = false;
        _worker->join();
    }
    return true;
}

// Sends queued messages until the queue is empty or the socket accepts only
// part of a message; the partial message is parked in _outgoing_message and
// completed later by on_event().
void RedChannel::send_messages()
{
    if (_outgoing_message) {
        return;
    }

    for (;;) {
        Lock lock(_outgoing_lock);
        AutoMessage message(get_outgoing_message());
        if (!message.get()) {
            return;
        }
        RedPeer::OutMessage& peer_message = message.get()->peer_message();
        uint32_t n = send(peer_message);
        if (n != peer_message.message_size()) {
            _outgoing_message = message.release();
            _outgoing_pos = n;
            return;
        }
    }
}

void RedChannel::on_send_trigger()
{
    send_messages();
}

// ACK flow control: once _message_ack_window messages have been received an
// ACK is posted and the countdown restarts. A zero count disables ACKs.
void RedChannel::on_message_recived()
{
    if (_message_ack_count && !--_message_ack_count) {
        post_message(new Message(REDC_ACK, 0));
        _message_ack_count = _message_ack_window;
    }
}

// Publishes the serial of the last handled message and wakes its waiters.
void RedChannel::on_message_complition(uint64_t serial)
{
    Lock lock(*_sync_info.lock);
    *_sync_info.message_serial = serial;
    _sync_info.condition->notify_all();
}

// Receives and dispatches complete messages until a read comes up short; the
// partial header/body position is saved for on_event() to resume.
void RedChannel::recive_messages()
{
    for (;;) {
        uint32_t n = RedPeer::recive((uint8_t*)&_incomming_header, sizeof(RedDataHeader));
        if (n != sizeof(RedDataHeader)) {
            _incomming_header_pos = n;
            return;
        }
        std::auto_ptr<CompundInMessage> message(new CompundInMessage(_incomming_header.serial,
                                                                     _incomming_header.type,
                                                                     _incomming_header.size,
                                                                     _incomming_header.sub_list));
        n = RedPeer::recive(message->data(), message->compund_size());
        if (n != message->compund_size()) {
            _incomming_message = message.release();
            _incomming_message_pos = n;
            return;
        }
        on_message_recived();
        _message_handler->handle_message(*message.get());
        on_message_complition(message->serial());
    }
}

// Socket event handler: finishes a partially sent message, sends queued
// messages, resumes a partially received header/message and then keeps
// receiving.
void RedChannel::on_event()
{
    if (_outgoing_message) {
        RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();

        _outgoing_pos += send(peer_message.base() + _outgoing_pos,
                              peer_message.message_size() - _outgoing_pos);
        if (_outgoing_pos == peer_message.message_size()) {
            _outgoing_message->release();
            _outgoing_message = NULL;
        }
    }
    send_messages();

    if (_incomming_header_pos) {
        _incomming_header_pos += RedPeer::recive(((uint8_t*)&_incomming_header) +
                                                 _incomming_header_pos,
                                                 sizeof(RedDataHeader) - _incomming_header_pos);
        if (_incomming_header_pos != sizeof(RedDataHeader)) {
            return;
        }
        _incomming_header_pos = 0;
        _incomming_message = new CompundInMessage(_incomming_header.serial,
                                                  _incomming_header.type,
                                                  _incomming_header.size,
                                                  _incomming_header.sub_list);
        _incomming_message_pos = 0;
    }

    if (_incomming_message) {
        _incomming_message_pos += RedPeer::recive(_incomming_message->data() +
                                                  _incomming_message_pos,
                                                  _incomming_message->compund_size() -
                                                  _incomming_message_pos);
        if (_incomming_message_pos != _incomming_message->compund_size()) {
            return;
        }
        std::auto_ptr<CompundInMessage> message(_incomming_message);
        _incomming_message = NULL;
        on_message_recived();
        _message_handler->handle_message(*message.get());
        on_message_complition(message->serial());
    }
    recive_messages();
}

// Blocking flush of everything pending (partial message first, then the
// queue) followed by a REDC_MIGRATE_FLUSH_MARK message.
void RedChannel::send_migrate_flush_mark()
{
    if (_outgoing_message) {
        RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();
        send(peer_message.base() + _outgoing_pos, peer_message.message_size() - _outgoing_pos);
        _outgoing_message->release();
        _outgoing_message = NULL;
    }
    Lock lock(_outgoing_lock);
    for (;;) {
        AutoMessage message(get_outgoing_message());
        if (!message.get()) {
            break;
        }
        send(message.get()->peer_message());
    }
    lock.unlock();
    std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_FLUSH_MARK, 0));
    send(*message);
}

// Handles a migrate request: detaches the socket from the loop, optionally
// flushes and/or fetches migration data, lets the client migrate the
// connection, forwards the migration data and re-attaches the socket.
void RedChannel::handle_migrate(RedPeer::InMessage* message)
{
    DBG(0, "channel type %u id %u", get_type(), get_id());
    RedMigrate* migrate = (RedMigrate*)message->data();
    if (message->size() < sizeof(*migrate)) {
        THROW("access violation");
    }
    _socket_in_loop = false;
    _loop.remove_socket(*this);
    if (migrate->flags & RED_MIGRATE_NEED_FLUSH) {
        send_migrate_flush_mark();
    }
    std::auto_ptr<RedPeer::CompundInMessage> data_message;
    if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
        data_message.reset(recive());
    }
    _client.migrate_channel(*this);
    if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
        if (data_message->type() != RED_MIGRATE_DATA) {
            THROW("expect RED_MIGRATE_DATA");
        }
        std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_DATA,
                                                                           data_message->size()));
        memcpy(message->data(), data_message->data(), data_message->size());
        send(*message);
    }
    _loop.add_socket(*this);
    _socket_in_loop = true;
    on_migrate();
    set_state(CONNECTED_STATE);
    on_event();
}

// Starts ACK flow control with the server's window and confirms the
// generation with REDC_ACK_SYNC.
void RedChannel::handle_set_ack(RedPeer::InMessage* message)
{
    RedSetAck* ack = (RedSetAck*)message->data();
    if (message->size() < sizeof(*ack)) {
        THROW("access violation");
    }
    _message_ack_window = _message_ack_count = ack->window;
    Message *responce = new Message(REDC_ACK_SYNC, sizeof(uint32_t));
    *(uint32_t *)responce->data() = ack->generation;
    post_message(responce);
}

// Echoes the RedPing payload back to the server as REDC_PONG.
void RedChannel::handle_ping(RedPeer::InMessage* message)
{
    RedPing *ping = (RedPing *)message->data();
    if (message->size() < sizeof(*ping)) {
        THROW("access violation");
    }
    Message *pong = new Message(REDC_PONG, sizeof(RedPing));
    *(RedPing *)pong->data() = *ping;
    post_message(pong);
}

// Records the server-supplied disconnect time stamp and reason.
void RedChannel::handle_disconnect(RedPeer::InMessage* message)
{
    RedDisconnect *disconnect = (RedDisconnect *)message->data();
    if (message->size() < sizeof(*disconnect)) {
        THROW("access violation");
    }
    _disconnect_stamp = disconnect->time_stamp;
    _disconnect_reason = disconnect->reason;
}

// Logs a server notification. The optional text follows the RedNotify
// struct and must be NUL-terminated within the message.
void RedChannel::handle_notify(RedPeer::InMessage* message)
{
    RedNotify *notify = (RedNotify *)message->data();
    const char *sevirity;
    const char *visibility;
    const char *message_str = "";
    const char *message_prefix = "";

    static const char* const sevirity_strings[] = {"info", "warn", "error"};
    static const char* const visibility_strings[] = {"!", "!!", "!!!"};

    // Validate the fixed part before reading any field; this also keeps the
    // subtraction below from underflowing.
    if (message->size() < sizeof(*notify)) {
        THROW("access violation");
    }

    if (notify->severty > RED_NOTIFY_SEVERITY_ERROR) {
        THROW("bad sevirity");
    }
    sevirity = sevirity_strings[notify->severty];

    if (notify->visibilty > RED_NOTIFY_VISIBILITY_HIGH) {
        THROW("bad visibilty");
    }
    visibility = visibility_strings[notify->visibilty];

    if (notify->message_len) {
        // 64-bit arithmetic so that message_len + 1 cannot wrap to zero.
        if ((uint64_t)message->size() - sizeof(*notify) < (uint64_t)notify->message_len + 1) {
            THROW("access violation");
        }
        message_str = (char *)(notify + 1);
        if (message_str[notify->message_len] != 0) {
            THROW("invalid message");
        }
        message_prefix = ": ";
    }

    LOG_INFO("remote channel %u:%u %s%s #%u%s%s",
             get_type(), get_id(),
             sevirity, visibility,
             notify->what,
             message_prefix, message_str);
}

// Validates the wait list (fixed part first, then the variable-length list)
// and lets the client wait for the listed channels.
void RedChannel::handle_wait_for_channels(RedPeer::InMessage* message)
{
    RedWaitForChannels *wait = (RedWaitForChannels *)message->data();
    if (message->size() < sizeof(*wait) ||
        (uint64_t)message->size() <
            sizeof(*wait) + (uint64_t)wait->wait_count * sizeof(wait->wait_list[0])) {
        THROW("access violation");
    }
    _client.wait_for_channels(wait->wait_count, wait->wait_list);
}