diff options
author | Yaniv Kamay <ykamay@redhat.com> | 2009-09-19 21:25:46 +0300 |
---|---|---|
committer | Yaniv Kamay <ykamay@redhat.com> | 2009-10-14 15:06:41 +0200 |
commit | c1b79eb035fa158fb2ac3bc8e559809611070016 (patch) | |
tree | 3348dd749a700dedf87c9b16fe8be77c62928df8 /client/red_channel.cpp | |
download | spice-c1b79eb035fa158fb2ac3bc8e559809611070016.tar.gz spice-c1b79eb035fa158fb2ac3bc8e559809611070016.tar.xz spice-c1b79eb035fa158fb2ac3bc8e559809611070016.zip |
fresh start
Diffstat (limited to 'client/red_channel.cpp')
-rw-r--r-- | client/red_channel.cpp | 714 |
1 files changed, 714 insertions, 0 deletions
diff --git a/client/red_channel.cpp b/client/red_channel.cpp new file mode 100644 index 00000000..4c6f1f8f --- /dev/null +++ b/client/red_channel.cpp @@ -0,0 +1,714 @@ +/* + Copyright (C) 2009 Red Hat, Inc. + + This program is free software; you can redistribute it and/or + modify it under the terms of the GNU General Public License as + published by the Free Software Foundation; either version 2 of + the License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +#include "common.h" +#include "red_channel.h" +#include "red_client.h" +#include "application.h" +#include "debug.h" +#include "utils.h" + +#include "openssl/rsa.h" +#include "openssl/evp.h" +#include "openssl/x509.h" + + +RedChannelBase::RedChannelBase(uint8_t type, uint8_t id, const ChannelCaps& common_caps, + const ChannelCaps& caps) + : RedPeer() + , _type (type) + , _id (id) + , _common_caps (common_caps) + , _caps (caps) +{ +} + +RedChannelBase::~RedChannelBase() +{ +} + +void RedChannelBase::link(uint32_t connection_id, const std::string& password) +{ + RedLinkHeader header; + RedLinkMess link_mess; + RedLinkReply* reply; + uint32_t link_res; + uint32_t i; + + EVP_PKEY *pubkey; + int nRSASize; + BIO *bioKey; + RSA *rsa; + + header.magic = RED_MAGIC; + header.size = sizeof(link_mess); + header.major_version = RED_VERSION_MAJOR; + header.minor_version = RED_VERSION_MINOR; + link_mess.connection_id = connection_id; + link_mess.channel_type = _type; + link_mess.channel_id = _id; + link_mess.num_common_caps = get_common_caps().size(); + link_mess.num_channel_caps = get_caps().size(); + link_mess.caps_offset = sizeof(link_mess); + header.size += (link_mess.num_common_caps + link_mess.num_channel_caps) * sizeof(uint32_t); + send((uint8_t*)&header, sizeof(header)); + send((uint8_t*)&link_mess, sizeof(link_mess)); + + for (i = 0; i < _common_caps.size(); i++) { + send((uint8_t*)&_common_caps[i], sizeof(uint32_t)); + } + + for (i = 0; i < _caps.size(); i++) { + send((uint8_t*)&_caps[i], sizeof(uint32_t)); + } + + recive((uint8_t*)&header, sizeof(header)); + + if (header.magic != RED_MAGIC) { + THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "bad magic"); + } + + if (header.major_version != RED_VERSION_MAJOR) { + THROW_ERR(SPICEC_ERROR_CODE_VERSION_MISMATCH, + "version mismatch: expect %u got %u", + RED_VERSION_MAJOR, + header.major_version); + } + + AutoArray<uint8_t> reply_buf(new uint8_t[header.size]); + recive(reply_buf.get(), header.size); + + reply = (RedLinkReply *)reply_buf.get(); + + if (reply->error != RED_ERR_OK) { + THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "connect error %u", reply->error); + } + + uint32_t num_caps = reply->num_channel_caps + reply->num_common_caps; + if ((uint8_t *)(reply + 1) > reply_buf.get() + header.size || + (uint8_t *)reply + reply->caps_offset + num_caps * sizeof(uint32_t) > + reply_buf.get() + header.size) { + THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "access violation"); + } + + uint32_t *caps = (uint32_t *)((uint8_t *)reply + reply->caps_offset); + + _remote_common_caps.clear(); + for (i = 0; i < reply->num_common_caps; i++, caps++) { + _remote_common_caps.resize(i + 1); + _remote_common_caps[i] = *caps; + } + + _remote_caps.clear(); + for (i = 0; i < reply->num_channel_caps; i++, caps++) { + _remote_caps.resize(i + 1); + _remote_caps[i] = *caps; + } + + bioKey = BIO_new(BIO_s_mem()); + if (bioKey != NULL) { + BIO_write(bioKey, reply->pub_key, RED_TICKET_PUBKEY_BYTES); + pubkey = d2i_PUBKEY_bio(bioKey, NULL); + rsa = pubkey->pkey.rsa; + nRSASize = RSA_size(rsa); + AutoArray<unsigned char> bufEncrypted(new unsigned char[nRSASize]); + + /* + The use of RSA encryption limit the potential maximum password length. + for RSA_PKCS1_OAEP_PADDING it is RSA_size(rsa) - 41. + */ + if (RSA_public_encrypt(password.length() + 1, (unsigned char *)password.c_str(), + (uint8_t *)bufEncrypted.get(), + rsa, RSA_PKCS1_OAEP_PADDING) > 0 ) { + send((uint8_t*)bufEncrypted.get(), nRSASize); + } else { + THROW("could not encrypt password"); + } + + memset(bufEncrypted.get(), 0, nRSASize); + } else { + THROW("Could not initiate BIO"); + } + + BIO_free(bioKey); + + recive((uint8_t*)&link_res, sizeof(link_res)); + if (link_res != RED_ERR_OK) { + int error_code = (link_res == RED_ERR_PERMISSION_DENIED) ? + SPICEC_ERROR_CODE_CONNECT_FAILED : SPICEC_ERROR_CODE_CONNECT_FAILED; + THROW_ERR(error_code, "connect failed %u", link_res); + } +} + +void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id, + uint32_t ip, std::string password) +{ + if (options.allow_unsecure()) { + try { + RedPeer::connect_unsecure(ip, options.unsecure_port); + link(connection_id, password); + return; + } catch (...) { + if (!options.allow_secure()) { + throw; + } + RedPeer::close(); + } + } + ASSERT(options.allow_secure()); + RedPeer::connect_secure(options, ip); + link(connection_id, password); +} + +void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id, + const char* host, std::string password) +{ + connect(options, connection_id, host_by_name(host), password); +} + +void RedChannelBase::set_capability(ChannelCaps& caps, uint32_t cap) +{ + uint32_t word_index = cap / 32; + + if (caps.size() < word_index + 1) { + caps.resize(word_index + 1); + } + caps[word_index] |= 1 << (cap % 32); +} + +void RedChannelBase::set_common_capability(uint32_t cap) +{ + set_capability(_common_caps, cap); +} + +void RedChannelBase::set_capability(uint32_t cap) +{ + set_capability(_caps, cap); +} + +bool RedChannelBase::test_capability(const ChannelCaps& caps, uint32_t cap) +{ + uint32_t word_index = cap / 32; + + if (caps.size() < word_index + 1) { + return false; + } + + return (caps[word_index] & (1 << (cap % 32))) != 0; +} + +bool RedChannelBase::test_common_capability(uint32_t cap) +{ + return test_capability(_remote_common_caps, cap); +} + +bool RedChannelBase::test_capability(uint32_t cap) +{ + return test_capability(_remote_caps, cap); +} + +SendTrigger::SendTrigger(RedChannel& channel) + : _channel (channel) +{ +} + +void SendTrigger::on_event() +{ + _channel.on_send_trigger(); +} + +void AbortTrigger::on_event() +{ + THROW("abort"); +} + +RedChannel::RedChannel(RedClient& client, uint8_t type, uint8_t id, + RedChannel::MessageHandler* handler, + Platform::ThreadPriority worker_priority) + : RedChannelBase(type, id, ChannelCaps(), ChannelCaps()) + , _client (client) + , _state (PASSIVE_STATE) + , _action (WAIT_ACTION) + , _error (SPICEC_ERROR_CODE_SUCCESS) + , _wait_for_threads (true) + , _socket_in_loop (false) + , _worker (NULL) + , _worker_priority (worker_priority) + , _message_handler (handler) + , _outgoing_message (NULL) + , _incomming_header_pos (0) + , _incomming_message (NULL) + , _message_ack_count (0) + , _message_ack_window (0) + , _send_trigger (*this) + , _disconnect_stamp (0) + , _disconnect_reason (RED_ERR_OK) +{ + _loop.add_trigger(_send_trigger); + _loop.add_trigger(_abort_trigger); +} + +RedChannel::~RedChannel() +{ + ASSERT(_state == TERMINATED_STATE || _state == PASSIVE_STATE); + delete _worker; +} + +void* RedChannel::worker_main(void *data) +{ + try { + RedChannel* channel = static_cast<RedChannel*>(data); + channel->set_state(DISCONNECTED_STATE); + Platform::set_thread_priority(NULL, channel->get_worker_priority()); + channel->run(); + } catch (Exception& e) { + LOG_ERROR("unhandle exception: %s", e.what()); + } catch (std::exception& e) { + LOG_ERROR("unhandle exception: %s", e.what()); + } catch (...) { + LOG_ERROR("unhandled exception"); + } + return NULL; +} + +void RedChannel::post_message(RedChannel::OutMessage* message) +{ + Lock lock(_outgoing_lock); + _outgoing_messages.push_back(message); + lock.unlock(); + _send_trigger.trigger(); +} + +RedPeer::CompundInMessage *RedChannel::recive() +{ + CompundInMessage *message = RedChannelBase::recive(); + on_message_recived(); + return message; +} + +RedChannel::OutMessage* RedChannel::get_outgoing_message() +{ + if (_state != CONNECTED_STATE || _outgoing_messages.empty()) { + return NULL; + } + RedChannel::OutMessage* message = _outgoing_messages.front(); + _outgoing_messages.pop_front(); + return message; +} + +class AutoMessage { +public: + AutoMessage(RedChannel::OutMessage* message) : _message (message) {} + ~AutoMessage() {if (_message) _message->release();} + void set(RedChannel::OutMessage* message) { _message = message;} + RedChannel::OutMessage* get() { return _message;} + RedChannel::OutMessage* release(); + +private: + RedChannel::OutMessage* _message; +}; + +RedChannel::OutMessage* AutoMessage::release() +{ + RedChannel::OutMessage* ret = _message; + _message = NULL; + return ret; +} + +void RedChannel::start() +{ + ASSERT(!_worker); + _worker = new Thread(RedChannel::worker_main, this); + Lock lock(_state_lock); + while (_state == PASSIVE_STATE) { + _state_cond.wait(lock); + } +} + +void RedChannel::set_state(int state) +{ + Lock lock(_state_lock); + _state = state; + _state_cond.notify_all(); +} + +void RedChannel::connect() +{ + Lock lock(_action_lock); + + if (_state != DISCONNECTED_STATE && _state != PASSIVE_STATE) { + return; + } + _action = CONNECT_ACTION; + _action_cond.notify_one(); +} + +void RedChannel::disconnect() +{ + clear_outgoing_messages(); + + Lock lock(_action_lock); + if (_state != CONNECTING_STATE && _state != CONNECTED_STATE) { + return; + } + _action = DISCONNECT_ACTION; + RedPeer::disconnect(); + _action_cond.notify_one(); +} + +void RedChannel::clear_outgoing_messages() +{ + Lock lock(_outgoing_lock); + while (!_outgoing_messages.empty()) { + RedChannel::OutMessage* message = _outgoing_messages.front(); + _outgoing_messages.pop_front(); + message->release(); + } +} + +void RedChannel::run() +{ + for (;;) { + Lock lock(_action_lock); + if (_action == WAIT_ACTION) { + _action_cond.wait(lock); + } + int action = _action; + _action = WAIT_ACTION; + lock.unlock(); + switch (action) { + case CONNECT_ACTION: + try { + get_client().get_sync_info(get_type(), get_id(), _sync_info); + on_connecting(); + set_state(CONNECTING_STATE); + ConnectionOptions con_options(_client.get_connection_options(get_type()), + _client.get_port(), + _client.get_sport()); + RedChannelBase::connect(con_options, _client.get_connection_id(), + _client.get_host(), _client.get_password()); + on_connect(); + set_state(CONNECTED_STATE); + _loop.add_socket(*this); + _socket_in_loop = true; + on_event(); + _loop.run(); + } catch (RedPeer::DisconnectedException&) { + _error = SPICEC_ERROR_CODE_SUCCESS; + } catch (Exception& e) { + LOG_WARN("%s", e.what()); + _error = e.get_error_code(); + } catch (std::exception& e) { + LOG_WARN("%s", e.what()); + _error = SPICEC_ERROR_CODE_ERROR; + } + if (_socket_in_loop) { + _socket_in_loop = false; + _loop.remove_socket(*this); + } + if (_outgoing_message) { + _outgoing_message->release(); + _outgoing_message = NULL; + } + _incomming_header_pos = 0; + delete _incomming_message; + _incomming_message = NULL; + case DISCONNECT_ACTION: + close(); + on_disconnect(); + set_state(DISCONNECTED_STATE); + _client.on_channel_disconnected(*this); + continue; + case QUIT_ACTION: + set_state(TERMINATED_STATE); + return; + } + } +} + +bool RedChannel::abort() +{ + clear_outgoing_messages(); + Lock lock(_action_lock); + if (_state == TERMINATED_STATE) { + if (_wait_for_threads) { + _wait_for_threads = false; + _worker->join(); + } + return true; + } + + _action = QUIT_ACTION; + _action_cond.notify_one(); + lock.unlock(); + RedPeer::disconnect(); + _abort_trigger.trigger(); + + for (;;) { + Lock state_lock(_state_lock); + if (_state == TERMINATED_STATE) { + break; + } + uint64_t timout = 1000 * 1000 * 100; // 100ms + if (!_state_cond.timed_wait(state_lock, timout)) { + return false; + } + } + if (_wait_for_threads) { + _wait_for_threads = false; + _worker->join(); + } + return true; +} + +void RedChannel::send_messages() +{ + if (_outgoing_message) { + return; + } + + for (;;) { + Lock lock(_outgoing_lock); + AutoMessage message(get_outgoing_message()); + if (!message.get()) { + return; + } + RedPeer::OutMessage& peer_message = message.get()->peer_message(); + uint32_t n = send(peer_message); + if (n != peer_message.message_size()) { + _outgoing_message = message.release(); + _outgoing_pos = n; + return; + } + } +} + +void RedChannel::on_send_trigger() +{ + send_messages(); +} + +void RedChannel::on_message_recived() +{ + if (_message_ack_count && !--_message_ack_count) { + post_message(new Message(REDC_ACK, 0)); + _message_ack_count = _message_ack_window; + } +} + +void RedChannel::on_message_complition(uint64_t serial) +{ + Lock lock(*_sync_info.lock); + *_sync_info.message_serial = serial; + _sync_info.condition->notify_all(); +} + +void RedChannel::recive_messages() +{ + for (;;) { + uint32_t n = RedPeer::recive((uint8_t*)&_incomming_header, sizeof(RedDataHeader)); + if (n != sizeof(RedDataHeader)) { + _incomming_header_pos = n; + return; + } + std::auto_ptr<CompundInMessage> message(new CompundInMessage(_incomming_header.serial, + _incomming_header.type, + _incomming_header.size, + _incomming_header.sub_list)); + n = RedPeer::recive(message->data(), message->compund_size()); + if (n != message->compund_size()) { + _incomming_message = message.release(); + _incomming_message_pos = n; + return; + } + on_message_recived(); + _message_handler->handle_message(*message.get()); + on_message_complition(message->serial()); + } +} + +void RedChannel::on_event() +{ + if (_outgoing_message) { + RedPeer::OutMessage& peer_message = _outgoing_message->peer_message(); + _outgoing_pos += send(peer_message.base() + _outgoing_pos, + peer_message.message_size() - _outgoing_pos); + if (_outgoing_pos == peer_message.message_size()) { + _outgoing_message->release(); + _outgoing_message = NULL; + } + } + send_messages(); + + if (_incomming_header_pos) { + _incomming_header_pos += RedPeer::recive(((uint8_t*)&_incomming_header) + + _incomming_header_pos, + sizeof(RedDataHeader) - _incomming_header_pos); + if (_incomming_header_pos != sizeof(RedDataHeader)) { + return; + } + _incomming_header_pos = 0; + _incomming_message = new CompundInMessage(_incomming_header.serial, + _incomming_header.type, + _incomming_header.size, + _incomming_header.sub_list); + _incomming_message_pos = 0; + } + + if (_incomming_message) { + _incomming_message_pos += RedPeer::recive(_incomming_message->data() + + _incomming_message_pos, + _incomming_message->compund_size() - + _incomming_message_pos); + if (_incomming_message_pos != _incomming_message->compund_size()) { + return; + } + std::auto_ptr<CompundInMessage> message(_incomming_message); + _incomming_message = NULL; + on_message_recived(); + _message_handler->handle_message(*message.get()); + on_message_complition(message->serial()); + } + recive_messages(); +} + +void RedChannel::send_migrate_flush_mark() +{ + if (_outgoing_message) { + RedPeer::OutMessage& peer_message = _outgoing_message->peer_message(); + send(peer_message.base() + _outgoing_pos, peer_message.message_size() - _outgoing_pos); + _outgoing_message->release(); + _outgoing_message = NULL; + } + Lock lock(_outgoing_lock); + for (;;) { + AutoMessage message(get_outgoing_message()); + if (!message.get()) { + break; + } + send(message.get()->peer_message()); + } + lock.unlock(); + std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_FLUSH_MARK, 0)); + send(*message); +} + +void RedChannel::handle_migrate(RedPeer::InMessage* message) +{ + DBG(0, "channel type %u id %u", get_type(), get_id()); + _socket_in_loop = false; + _loop.remove_socket(*this); + RedMigrate* migrate = (RedMigrate*)message->data(); + if (migrate->flags & RED_MIGRATE_NEED_FLUSH) { + send_migrate_flush_mark(); + } + std::auto_ptr<RedPeer::CompundInMessage> data_message; + if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) { + data_message.reset(recive()); + } + _client.migrate_channel(*this); + if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) { + if (data_message->type() != RED_MIGRATE_DATA) { + THROW("expect RED_MIGRATE_DATA"); + } + std::auto_ptr<RedPeer::OutMessage> message(new RedPeer::OutMessage(REDC_MIGRATE_DATA, + data_message->size())); + memcpy(message->data(), data_message->data(), data_message->size()); + send(*message); + } + _loop.add_socket(*this); + _socket_in_loop = true; + on_migrate(); + set_state(CONNECTED_STATE); + on_event(); +} + +void RedChannel::handle_set_ack(RedPeer::InMessage* message) +{ + RedSetAck* ack = (RedSetAck*)message->data(); + _message_ack_window = _message_ack_count = ack->window; + Message *responce = new Message(REDC_ACK_SYNC, sizeof(uint32_t)); + *(uint32_t *)responce->data() = ack->generation; + post_message(responce); +} + +void RedChannel::handle_ping(RedPeer::InMessage* message) +{ + RedPing *ping = (RedPing *)message->data(); + Message *pong = new Message(REDC_PONG, sizeof(RedPing)); + *(RedPing *)pong->data() = *ping; + post_message(pong); +} + +void RedChannel::handle_disconnect(RedPeer::InMessage* message) +{ + RedDisconnect *disconnect = (RedDisconnect *)message->data(); + _disconnect_stamp = disconnect->time_stamp; + _disconnect_reason = disconnect->reason; +} + +void RedChannel::handle_notify(RedPeer::InMessage* message) +{ + RedNotify *notify = (RedNotify *)message->data(); + const char *sevirity; + const char *visibility; + const char *message_str = ""; + const char *message_prefix = ""; + + static const char* sevirity_strings[] = {"info", "warn", "error"}; + static const char* visibility_strings[] = {"!", "!!", "!!!"}; + + + if (notify->severty > RED_NOTIFY_SEVERITY_ERROR) { + THROW("bad sevirity"); + } + sevirity = sevirity_strings[notify->severty]; + + if (notify->visibilty > RED_NOTIFY_VISIBILITY_HIGH) { + THROW("bad visibilty"); + } + visibility = visibility_strings[notify->visibilty]; + + + if (notify->message_len) { + if ((message->size() - sizeof(*notify) < notify->message_len + 1)) { + THROW("access violation"); + } + message_str = (char *)(notify + 1); + if (message_str[notify->message_len] != 0) { + THROW("invalid message"); + } + message_prefix = ": "; + } + + + LOG_INFO("remote channel %u:%u %s%s #%u%s%s", + get_type(), get_id(), + sevirity, visibility, + notify->what, + message_prefix, message_str); +} + +void RedChannel::handle_wait_for_channels(RedPeer::InMessage* message) +{ + RedWaitForChannels *wait = (RedWaitForChannels *)message->data(); + if (message->size() < sizeof(*wait) + wait->wait_count * sizeof(wait->wait_list[0])) { + THROW("access violation"); + } + _client.wait_for_channels(wait->wait_count, wait->wait_list); +} + |