/*
Copyright (C) 2009 Red Hat, Inc.
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License as
published by the Free Software Foundation; either version 2 of
the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "common.h"
#include "red_channel.h"
#include "red_client.h"
#include "application.h"
#include "debug.h"
#include "utils.h"
#include "openssl/rsa.h"
#include "openssl/evp.h"
#include "openssl/x509.h"
// Builds a channel endpoint with the given type and id and initialises the
// local common/channel capability sets from the arguments.
RedChannelBase::RedChannelBase(uint8_t type,
                               uint8_t id,
                               const ChannelCaps& common_caps,
                               const ChannelCaps& caps)
    : RedPeer()
    , _type(type)
    , _id(id)
    , _common_caps(common_caps)
    , _caps(caps)
{
    // Member initialisation is all that is needed.
}
// Intentionally empty: this destructor performs no work of its own.
RedChannelBase::~RedChannelBase() {}
void RedChannelBase::link(uint32_t connection_id, const std::string& password)
{
RedLinkHeader header;
RedLinkMess link_mess;
RedLinkReply* reply;
uint32_t link_res;
uint32_t i;
EVP_PKEY *pubkey;
int nRSASize;
BIO *bioKey;
RSA *rsa;
header.magic = RED_MAGIC;
header.size = sizeof(link_mess);
header.major_version = RED_VERSION_MAJOR;
header.minor_version = RED_VERSION_MINOR;
link_mess.connection_id = connection_id;
link_mess.channel_type = _type;
link_mess.channel_id = _id;
link_mess.num_common_caps = get_common_caps().size();
link_mess.num_channel_caps = get_caps().size();
link_mess.caps_offset = sizeof(link_mess);
header.size += (link_mess.num_common_caps + link_mess.num_channel_caps) * sizeof(uint32_t);
send((uint8_t*)&header, sizeof(header));
send((uint8_t*)&link_mess, sizeof(link_mess));
for (i = 0; i < _common_caps.size(); i++) {
send((uint8_t*)&_common_caps[i], sizeof(uint32_t));
}
for (i = 0; i < _caps.size(); i++) {
send((uint8_t*)&_caps[i], sizeof(uint32_t));
}
recive((uint8_t*)&header, sizeof(header));
if (header.magic != RED_MAGIC) {
THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "bad magic");
}
if (header.major_version != RED_VERSION_MAJOR) {
THROW_ERR(SPICEC_ERROR_CODE_VERSION_MISMATCH,
"version mismatch: expect %u got %u",
RED_VERSION_MAJOR,
header.major_version);
}
_remote_minor = header.minor_version;
AutoArray reply_buf(new uint8_t[header.size]);
recive(reply_buf.get(), header.size);
reply = (RedLinkReply *)reply_buf.get();
if (reply->error != RED_ERR_OK) {
THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "connect error %u", reply->error);
}
uint32_t num_caps = reply->num_channel_caps + reply->num_common_caps;
if ((uint8_t *)(reply + 1) > reply_buf.get() + header.size ||
(uint8_t *)reply + reply->caps_offset + num_caps * sizeof(uint32_t) >
reply_buf.get() + header.size) {
THROW_ERR(SPICEC_ERROR_CODE_CONNECT_FAILED, "access violation");
}
uint32_t *caps = (uint32_t *)((uint8_t *)reply + reply->caps_offset);
_remote_common_caps.clear();
for (i = 0; i < reply->num_common_caps; i++, caps++) {
_remote_common_caps.resize(i + 1);
_remote_common_caps[i] = *caps;
}
_remote_caps.clear();
for (i = 0; i < reply->num_channel_caps; i++, caps++) {
_remote_caps.resize(i + 1);
_remote_caps[i] = *caps;
}
bioKey = BIO_new(BIO_s_mem());
if (bioKey != NULL) {
BIO_write(bioKey, reply->pub_key, RED_TICKET_PUBKEY_BYTES);
pubkey = d2i_PUBKEY_bio(bioKey, NULL);
rsa = pubkey->pkey.rsa;
nRSASize = RSA_size(rsa);
AutoArray bufEncrypted(new unsigned char[nRSASize]);
/*
The use of RSA encryption limit the potential maximum password length.
for RSA_PKCS1_OAEP_PADDING it is RSA_size(rsa) - 41.
*/
if (RSA_public_encrypt(password.length() + 1, (unsigned char *)password.c_str(),
(uint8_t *)bufEncrypted.get(),
rsa, RSA_PKCS1_OAEP_PADDING) > 0) {
send((uint8_t*)bufEncrypted.get(), nRSASize);
} else {
THROW("could not encrypt password");
}
memset(bufEncrypted.get(), 0, nRSASize);
} else {
THROW("Could not initiate BIO");
}
BIO_free(bioKey);
recive((uint8_t*)&link_res, sizeof(link_res));
if (link_res != RED_ERR_OK) {
int error_code = (link_res == RED_ERR_PERMISSION_DENIED) ?
SPICEC_ERROR_CODE_CONNECT_FAILED : SPICEC_ERROR_CODE_CONNECT_FAILED;
THROW_ERR(error_code, "connect failed %u", link_res);
}
}
// Connects to `host` and runs link().
//
// When the options allow an unsecure connection it is tried first; if that
// attempt (connect or link) throws and a secure connection is also allowed,
// the peer is closed and the secure path is used instead. If secure is not
// allowed the original exception is rethrown.
void RedChannelBase::connect(const ConnectionOptions& options, uint32_t connection_id,
                             const char* host, std::string password)
{
    const bool plain_allowed = options.allow_unsecure();
    if (plain_allowed) {
        try {
            RedPeer::connect_unsecure(host, options.unsecure_port);
            link(connection_id, password);
            return;
        } catch (...) {
            const bool may_fall_back = options.allow_secure();
            if (!may_fall_back) {
                throw;
            }
            // Drop the failed unsecure attempt before retrying securely.
            RedPeer::close();
        }
    }

    ASSERT(options.allow_secure());
    RedPeer::connect_secure(options, host);
    link(connection_id, password);
}
// Sets bit `cap` in `caps`, growing the set so that the 32-bit word holding
// that bit exists.
void RedChannelBase::set_capability(ChannelCaps& caps, uint32_t cap)
{
    const uint32_t word_index = cap / 32;
    if (caps.size() < word_index + 1) {
        caps.resize(word_index + 1);
    }
    // Shift an unsigned one: "1 << 31" would overflow a signed int.
    caps[word_index] |= uint32_t(1) << (cap % 32);
}
void RedChannelBase::set_common_capability(uint32_t cap)
{
set_capability(_common_caps, cap);
}
void RedChannelBase::set_capability(uint32_t cap)
{
set_capability(_caps, cap);
}
bool RedChannelBase::test_capability(const ChannelCaps& caps, uint32_t cap)
{
uint32_t word_index = cap / 32;
if (caps.size() < word_index + 1) {
return false;
}
return (caps[word_index] & (1 << (cap % 32))) != 0;
}
bool RedChannelBase::test_common_capability(uint32_t cap)
{
return test_capability(_remote_common_caps, cap);
}
bool RedChannelBase::test_capability(uint32_t cap)
{
return test_capability(_remote_caps, cap);
}
// Binds the trigger to the channel whose on_send_trigger() it invokes.
SendTrigger::SendTrigger(RedChannel& channel)
    : _channel(channel)
{
}
void SendTrigger::on_event()
{
_channel.on_send_trigger();
}
// Trigger callback: always throws, with the message "abort".
void AbortTrigger::on_event()
{
    THROW("abort");
}
// Creates a channel bound to `client`.
//
// The channel starts in PASSIVE_STATE with WAIT_ACTION pending, no worker
// thread, empty capability sets and no message in flight. `handler` is
// stored as the message handler and `worker_priority` as the priority for
// the worker thread. The send and abort triggers are registered with the
// channel's loop.
RedChannel::RedChannel(RedClient& client, uint8_t type, uint8_t id,
                       RedChannel::MessageHandler* handler,
                       Platform::ThreadPriority worker_priority)
    : RedChannelBase(type, id, ChannelCaps(), ChannelCaps())
    , _client(client)
    , _state(PASSIVE_STATE)
    , _action(WAIT_ACTION)
    , _error(SPICEC_ERROR_CODE_SUCCESS)
    , _wait_for_threads(true)
    , _socket_in_loop(false)
    , _worker(NULL)
    , _worker_priority(worker_priority)
    , _message_handler(handler)
    , _outgoing_message(NULL)
    , _incomming_header_pos(0)
    , _incomming_message(NULL)
    , _message_ack_count(0)
    , _message_ack_window(0)
    , _loop(this)
    , _send_trigger(*this)
    , _disconnect_stamp(0)
    , _disconnect_reason(RED_ERR_OK)
{
    _loop.add_trigger(_send_trigger);
    _loop.add_trigger(_abort_trigger);
}
// Deletes the worker thread object. The channel must still be passive or
// already terminated when it is destroyed.
RedChannel::~RedChannel()
{
    ASSERT(_state == PASSIVE_STATE || _state == TERMINATED_STATE);
    delete _worker;
}
void* RedChannel::worker_main(void *data)
{
try {
RedChannel* channel = static_cast(data);
channel->set_state(DISCONNECTED_STATE);
Platform::set_thread_priority(NULL, channel->get_worker_priority());
channel->run();
} catch (Exception& e) {
LOG_ERROR("unhandle exception: %s", e.what());
} catch (std::exception& e) {
LOG_ERROR("unhandle exception: %s", e.what());
} catch (...) {
LOG_ERROR("unhandled exception");
}
return NULL;
}
// Appends `message` to the outgoing queue and fires the send trigger.
void RedChannel::post_message(RedChannel::OutMessage* message)
{
    {
        // The queue lock is dropped before the trigger is fired.
        Lock queue_lock(_outgoing_lock);
        _outgoing_messages.push_back(message);
    }
    _send_trigger.trigger();
}
// Receives one message through RedChannelBase::recive(), runs the
// per-message bookkeeping in on_message_recived(), and returns the message.
RedPeer::CompundInMessage *RedChannel::recive()
{
    CompundInMessage* const received = RedChannelBase::recive();
    on_message_recived();
    return received;
}
// Pops and returns the message at the front of the outgoing queue, or NULL
// when the channel is not in CONNECTED_STATE or the queue is empty.
// Takes no lock itself; the callers in this file hold _outgoing_lock.
RedChannel::OutMessage* RedChannel::get_outgoing_message()
{
    const bool can_send = (_state == CONNECTED_STATE) && !_outgoing_messages.empty();
    if (!can_send) {
        return NULL;
    }

    RedChannel::OutMessage* next = _outgoing_messages.front();
    _outgoing_messages.pop_front();
    return next;
}
// Scoped holder for an outgoing message: unless ownership is taken back with
// release(), the held message's release() is called when the holder is
// destroyed.
class AutoMessage {
public:
    explicit AutoMessage(RedChannel::OutMessage* message) : _message (message) {}
    ~AutoMessage() {if (_message) _message->release();}

    // Replaces the held pointer; the previously held message is NOT released.
    void set(RedChannel::OutMessage* message) { _message = message;}
    // Returns the held message (may be NULL) without giving up ownership.
    RedChannel::OutMessage* get() { return _message;}
    // Hands the held message to the caller and leaves the holder empty.
    RedChannel::OutMessage* release();

private:
    // Not copyable: a copy would release the same message twice.
    AutoMessage(const AutoMessage&);
    AutoMessage& operator=(const AutoMessage&);

    RedChannel::OutMessage* _message;
};

RedChannel::OutMessage* AutoMessage::release()
{
    RedChannel::OutMessage* ret = _message;
    _message = NULL;
    return ret;
}
void RedChannel::start()
{
ASSERT(!_worker);
_worker = new Thread(RedChannel::worker_main, this);
Lock lock(_state_lock);
while (_state == PASSIVE_STATE) {
_state_cond.wait(lock);
}
}
// Stores the new state under _state_lock and wakes every waiter on
// _state_cond.
void RedChannel::set_state(int state)
{
    Lock state_guard(_state_lock);
    _state = state;
    _state_cond.notify_all();
}
// Requests a connect from the worker. Does nothing unless the channel is
// currently in DISCONNECTED_STATE or PASSIVE_STATE.
void RedChannel::connect()
{
    Lock action_guard(_action_lock);
    const bool can_connect = (_state == DISCONNECTED_STATE || _state == PASSIVE_STATE);
    if (can_connect) {
        _action = CONNECT_ACTION;
        _action_cond.notify_one();
    }
}
// Drops all queued outgoing messages and, if the channel is connecting or
// connected, requests a disconnect: the pending action is set, the peer is
// disconnected and the worker is notified.
void RedChannel::disconnect()
{
    clear_outgoing_messages();

    Lock action_guard(_action_lock);
    const bool link_active = (_state == CONNECTING_STATE || _state == CONNECTED_STATE);
    if (!link_active) {
        return;
    }
    _action = DISCONNECT_ACTION;
    RedPeer::disconnect();
    _action_cond.notify_one();
}
void RedChannel::clear_outgoing_messages()
{
Lock lock(_outgoing_lock);
while (!_outgoing_messages.empty()) {
RedChannel::OutMessage* message = _outgoing_messages.front();
_outgoing_messages.pop_front();
message->release();
}
}
void RedChannel::run()
{
for (;;) {
Lock lock(_action_lock);
if (_action == WAIT_ACTION) {
_action_cond.wait(lock);
}
int action = _action;
_action = WAIT_ACTION;
lock.unlock();
switch (action) {
case CONNECT_ACTION:
try {
get_client().get_sync_info(get_type(), get_id(), _sync_info);
on_connecting();
set_state(CONNECTING_STATE);
ConnectionOptions con_options(_client.get_connection_options(get_type()),
_client.get_port(),
_client.get_sport(),
_client.get_host_auth_options());
RedChannelBase::connect(con_options, _client.get_connection_id(),
_client.get_host().c_str(),
_client.get_password().c_str());
on_connect();
set_state(CONNECTED_STATE);
_loop.add_socket(*this);
_socket_in_loop = true;
on_event();
_loop.run();
} catch (RedPeer::DisconnectedException&) {
_error = SPICEC_ERROR_CODE_SUCCESS;
} catch (Exception& e) {
LOG_WARN("%s", e.what());
_error = e.get_error_code();
} catch (std::exception& e) {
LOG_WARN("%s", e.what());
_error = SPICEC_ERROR_CODE_ERROR;
}
if (_socket_in_loop) {
_socket_in_loop = false;
_loop.remove_socket(*this);
}
if (_outgoing_message) {
_outgoing_message->release();
_outgoing_message = NULL;
}
_incomming_header_pos = 0;
if (_incomming_message) {
_incomming_message->unref();
_incomming_message = NULL;
}
case DISCONNECT_ACTION:
close();
on_disconnect();
set_state(DISCONNECTED_STATE);
_client.on_channel_disconnected(*this);
continue;
case QUIT_ACTION:
set_state(TERMINATED_STATE);
return;
}
}
}
bool RedChannel::abort()
{
clear_outgoing_messages();
Lock lock(_action_lock);
if (_state == TERMINATED_STATE) {
if (_wait_for_threads) {
_wait_for_threads = false;
_worker->join();
}
return true;
}
_action = QUIT_ACTION;
_action_cond.notify_one();
lock.unlock();
RedPeer::disconnect();
_abort_trigger.trigger();
for (;;) {
Lock state_lock(_state_lock);
if (_state == TERMINATED_STATE) {
break;
}
uint64_t timout = 1000 * 1000 * 100; // 100ms
if (!_state_cond.timed_wait(state_lock, timout)) {
return false;
}
}
if (_wait_for_threads) {
_wait_for_threads = false;
_worker->join();
}
return true;
}
void RedChannel::send_messages()
{
if (_outgoing_message) {
return;
}
for (;;) {
Lock lock(_outgoing_lock);
AutoMessage message(get_outgoing_message());
if (!message.get()) {
return;
}
RedPeer::OutMessage& peer_message = message.get()->peer_message();
uint32_t n = send(peer_message);
if (n != peer_message.message_size()) {
_outgoing_message = message.release();
_outgoing_pos = n;
return;
}
}
}
void RedChannel::on_send_trigger()
{
send_messages();
}
// Per-message acknowledgement bookkeeping. While the ack counter is non-zero
// it is decremented for every message; when it reaches zero a REDC_ACK is
// posted and the counter is reloaded from the ack window.
void RedChannel::on_message_recived()
{
    if (_message_ack_count == 0) {
        return;
    }
    --_message_ack_count;
    if (_message_ack_count != 0) {
        return;
    }
    post_message(new Message(REDC_ACK, 0));
    _message_ack_count = _message_ack_window;
}
// Publishes `serial` through _sync_info: the serial is stored under the
// sync lock and all waiters on the sync condition are woken.
void RedChannel::on_message_complition(uint64_t serial)
{
    Lock sync_guard(*_sync_info.lock);
    *_sync_info.message_serial = serial;
    _sync_info.condition->notify_all();
}
void RedChannel::recive_messages()
{
for (;;) {
uint32_t n = RedPeer::recive((uint8_t*)&_incomming_header, sizeof(RedDataHeader));
if (n != sizeof(RedDataHeader)) {
_incomming_header_pos = n;
return;
}
AutoRef message(new CompundInMessage(_incomming_header.serial,
_incomming_header.type,
_incomming_header.size,
_incomming_header.sub_list));
n = RedPeer::recive((*message)->data(), (*message)->compund_size());
if (n != (*message)->compund_size()) {
_incomming_message = message.release();
_incomming_message_pos = n;
return;
}
on_message_recived();
_message_handler->handle_message(*(*message));
on_message_complition((*message)->serial());
}
}
void RedChannel::on_event()
{
if (_outgoing_message) {
RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();
_outgoing_pos += send(peer_message.base() + _outgoing_pos,
peer_message.message_size() - _outgoing_pos);
if (_outgoing_pos == peer_message.message_size()) {
_outgoing_message->release();
_outgoing_message = NULL;
}
}
send_messages();
if (_incomming_header_pos) {
_incomming_header_pos += RedPeer::recive(((uint8_t*)&_incomming_header) +
_incomming_header_pos,
sizeof(RedDataHeader) - _incomming_header_pos);
if (_incomming_header_pos != sizeof(RedDataHeader)) {
return;
}
_incomming_header_pos = 0;
_incomming_message = new CompundInMessage(_incomming_header.serial,
_incomming_header.type,
_incomming_header.size,
_incomming_header.sub_list);
_incomming_message_pos = 0;
}
if (_incomming_message) {
_incomming_message_pos += RedPeer::recive(_incomming_message->data() +
_incomming_message_pos,
_incomming_message->compund_size() -
_incomming_message_pos);
if (_incomming_message_pos != _incomming_message->compund_size()) {
return;
}
AutoRef message(_incomming_message);
_incomming_message = NULL;
on_message_recived();
_message_handler->handle_message(*(*message));
on_message_complition((*message)->serial());
}
recive_messages();
}
void RedChannel::send_migrate_flush_mark()
{
if (_outgoing_message) {
RedPeer::OutMessage& peer_message = _outgoing_message->peer_message();
send(peer_message.base() + _outgoing_pos, peer_message.message_size() - _outgoing_pos);
_outgoing_message->release();
_outgoing_message = NULL;
}
Lock lock(_outgoing_lock);
for (;;) {
AutoMessage message(get_outgoing_message());
if (!message.get()) {
break;
}
send(message.get()->peer_message());
}
lock.unlock();
std::auto_ptr message(new RedPeer::OutMessage(REDC_MIGRATE_FLUSH_MARK, 0));
send(*message);
}
// Handles a migrate message.
//
// The socket is taken out of the event loop, the flush mark is sent and the
// migration data message is received when the corresponding flags are set,
// the client migrates the channel, the migration data (if any) is forwarded
// as REDC_MIGRATE_DATA, and the socket is put back into the loop.
//
// Throws if the message is shorter than RedMigrate or if the message
// received for the data transfer is not RED_MIGRATE_DATA.
void RedChannel::handle_migrate(RedPeer::InMessage* message)
{
    DBG(0, "channel type %u id %u", get_type(), get_id());
    // The payload comes from the peer: verify it holds a RedMigrate.
    if (message->size() < sizeof(RedMigrate)) {
        THROW("access violation");
    }
    _socket_in_loop = false;
    _loop.remove_socket(*this);
    RedMigrate* migrate = (RedMigrate*)message->data();
    if (migrate->flags & RED_MIGRATE_NEED_FLUSH) {
        send_migrate_flush_mark();
    }
    AutoRef<CompundInMessage> data_message;
    if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
        data_message.reset(recive());
    }
    _client.migrate_channel(*this);
    if (migrate->flags & RED_MIGRATE_NEED_DATA_TRANSFER) {
        if ((*data_message)->type() != RED_MIGRATE_DATA) {
            THROW("expect RED_MIGRATE_DATA");
        }
        std::auto_ptr<RedPeer::OutMessage> out_message(
            new RedPeer::OutMessage(REDC_MIGRATE_DATA, (*data_message)->size()));
        memcpy(out_message->data(), (*data_message)->data(), (*data_message)->size());
        send(*out_message);
    }
    _loop.add_socket(*this);
    _socket_in_loop = true;
    on_migrate();
    set_state(CONNECTED_STATE);
    on_event();
}
// Handles a set-ack message: loads the ack window and the ack counter from
// the message and answers with REDC_ACK_SYNC carrying the received
// generation. Throws if the message is shorter than RedSetAck.
void RedChannel::handle_set_ack(RedPeer::InMessage* message)
{
    // The payload comes from the peer: verify it holds a RedSetAck.
    if (message->size() < sizeof(RedSetAck)) {
        THROW("access violation");
    }
    RedSetAck* ack = (RedSetAck*)message->data();
    _message_ack_window = _message_ack_count = ack->window;
    Message *responce = new Message(REDC_ACK_SYNC, sizeof(uint32_t));
    *(uint32_t *)responce->data() = ack->generation;
    post_message(responce);
}
// Handles a ping: posts a REDC_PONG whose payload is a copy of the received
// RedPing. Throws if the message is shorter than RedPing.
void RedChannel::handle_ping(RedPeer::InMessage* message)
{
    // The payload comes from the peer: verify it holds a RedPing.
    if (message->size() < sizeof(RedPing)) {
        THROW("access violation");
    }
    RedPing *ping = (RedPing *)message->data();
    Message *pong = new Message(REDC_PONG, sizeof(RedPing));
    *(RedPing *)pong->data() = *ping;
    post_message(pong);
}
// Handles a disconnect notice: records the time stamp and the reason sent by
// the peer. Throws if the message is shorter than RedDisconnect.
void RedChannel::handle_disconnect(RedPeer::InMessage* message)
{
    // The payload comes from the peer: verify it holds a RedDisconnect.
    if (message->size() < sizeof(RedDisconnect)) {
        THROW("access violation");
    }
    RedDisconnect *disconnect = (RedDisconnect *)message->data();
    _disconnect_stamp = disconnect->time_stamp;
    _disconnect_reason = disconnect->reason;
}
// Handles a notify message by logging it.
//
// Severity and visibility are validated and mapped to short strings. When
// the message carries text (message_len != 0), the text must lie completely
// inside the message and be NUL terminated at index message_len.
//
// Throws on a truncated message, an out-of-range severity or visibility, or
// text that is out of bounds or not NUL terminated.
void RedChannel::handle_notify(RedPeer::InMessage* message)
{
    static const char* sevirity_strings[] = {"info", "warn", "error"};
    static const char* visibility_strings[] = {"!", "!!", "!!!"};

    // The fixed part must be present before any field is read; this also
    // keeps the subtraction below from wrapping around.
    if (message->size() < sizeof(RedNotify)) {
        THROW("access violation");
    }
    RedNotify *notify = (RedNotify *)message->data();

    if (notify->severty > RED_NOTIFY_SEVERITY_ERROR) {
        THROW("bad sevirity");
    }
    const char *sevirity = sevirity_strings[notify->severty];

    if (notify->visibilty > RED_NOTIFY_VISIBILITY_HIGH) {
        THROW("bad visibilty");
    }
    const char *visibility = visibility_strings[notify->visibilty];

    const char *message_str = "";
    const char *message_prefix = "";
    if (notify->message_len) {
        // Room is needed for message_len characters plus the terminator.
        // "available <= len" equals "available < len + 1" but cannot
        // overflow for a maximal message_len.
        if (message->size() - sizeof(*notify) <= notify->message_len) {
            THROW("access violation");
        }
        message_str = (char *)(notify + 1);
        if (message_str[notify->message_len] != 0) {
            THROW("invalid message");
        }
        message_prefix = ": ";
    }

    LOG_INFO("remote channel %u:%u %s%s #%u%s%s",
             get_type(), get_id(),
             sevirity, visibility,
             notify->what,
             message_prefix, message_str);
}
// Handles a wait-for-channels message: validates that the announced wait
// list fits inside the message and passes it to the client.
// Throws if the message is truncated.
void RedChannel::handle_wait_for_channels(RedPeer::InMessage* message)
{
    // The fixed part must be present before wait_count is read.
    if (message->size() < sizeof(RedWaitForChannels)) {
        THROW("access violation");
    }
    RedWaitForChannels *wait = (RedWaitForChannels *)message->data();

    // Compare entry counts rather than byte sizes so that no multiplication
    // of a peer supplied value is needed.
    if ((message->size() - sizeof(*wait)) / sizeof(wait->wait_list[0]) < wait->wait_count) {
        THROW("access violation");
    }
    _client.wait_for_channels(wait->wait_count, wait->wait_list);
}