diff options
-rw-r--r-- | server/reds.c | 68 |
1 files changed, 64 insertions, 4 deletions
diff --git a/server/reds.c b/server/reds.c index c6c8a66a..4c36f92a 100644 --- a/server/reds.c +++ b/server/reds.c @@ -239,6 +239,7 @@ typedef struct RedLinkInfo { SpiceLinkMess *link_mess; int mess_pos; TicketInfo tiTicketing; + SpiceLinkAuthMechanism auth_mechanism; } RedLinkInfo; typedef struct VDIPortBuf VDIPortBuf; @@ -1338,6 +1339,28 @@ static int sync_write(RedsStream *stream, const void *in_buf, size_t n) return TRUE; } +static void reds_channel_set_common_caps(Channel *channel, int cap, int active) +{ + int nbefore, n; + + nbefore = channel->num_common_caps; + n = cap / 32; + channel->num_common_caps = MAX(channel->num_common_caps, n + 1); + channel->common_caps = spice_renew(uint32_t, channel->common_caps, channel->num_common_caps); + memset(channel->common_caps + nbefore, 0, + (channel->num_common_caps - nbefore) * sizeof(uint32_t)); + if (active) + channel->common_caps[n] |= (1 << cap); + else + channel->common_caps[n] &= ~(1 << cap); +} + +void reds_channel_init_auth_caps(Channel *channel) +{ + reds_channel_set_common_caps(channel, SPICE_COMMON_CAP_AUTH_SPICE, TRUE); + reds_channel_set_common_caps(channel, SPICE_COMMON_CAP_PROTOCOL_AUTH_SELECTION, TRUE); +} + void reds_channel_dispose(Channel *channel) { free(channel->caps); @@ -1371,6 +1394,8 @@ static int reds_send_link_ack(RedLinkInfo *link) channel = ∩︀ } + reds_channel_init_auth_caps(channel); /* make sure common caps are set */ + ack.num_common_caps = channel->num_common_caps; ack.num_channel_caps = channel->num_caps; header.size += (ack.num_common_caps + ack.num_channel_caps) * sizeof(uint32_t); @@ -1687,6 +1712,31 @@ static void async_read_handler(int fd, int event, void *data) } } +static void reds_get_spice_ticket(RedLinkInfo *link) +{ + AsyncRead *obj = &link->asyc_read; + + obj->now = (uint8_t *)&link->tiTicketing.encrypted_ticket.encrypted_data; + obj->end = obj->now + link->tiTicketing.rsa_size; + obj->done = reds_handle_ticket; + async_read_handler(0, 0, &link->asyc_read); +} + +static void reds_handle_auth_mechanism(void *opaque) +{ + RedLinkInfo *link = (RedLinkInfo *)opaque; + + red_printf("Auth method: %d", link->auth_mechanism.auth_mechanism); + + if (link->auth_mechanism.auth_mechanism == SPICE_COMMON_CAP_AUTH_SPICE) { + reds_get_spice_ticket(link); + } else { + red_printf("Unknown auth method, disconnecting"); + reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA); + reds_link_free(link); + } +} + static int reds_security_check(RedLinkInfo *link) { ChannelSecurityOptions *security_option = find_channel_security(link->link_mess->channel_type); @@ -1701,6 +1751,8 @@ static void reds_handle_read_link_done(void *opaque) SpiceLinkMess *link_mess = link->link_mess; AsyncRead *obj = &link->asyc_read; uint32_t num_caps = link_mess->num_common_caps + link_mess->num_channel_caps; + uint32_t *caps = (uint32_t *)((uint8_t *)link_mess + link_mess->caps_offset); + int auth_selection; if (num_caps && (num_caps * sizeof(uint32_t) + link_mess->caps_offset > link->link_header.size || @@ -1710,6 +1762,9 @@ static void reds_handle_read_link_done(void *opaque) return; } + auth_selection = link_mess->num_common_caps > 0 && + (caps[0] & (1 << SPICE_COMMON_CAP_PROTOCOL_AUTH_SELECTION));; + if (!reds_security_check(link)) { if (link->stream->ssl) { red_printf("spice channels %d should not be encrypted", link_mess->channel_type); @@ -1727,10 +1782,15 @@ static void reds_handle_read_link_done(void *opaque) return; } - obj->now = (uint8_t *)&link->tiTicketing.encrypted_ticket.encrypted_data; - obj->end = obj->now + link->tiTicketing.rsa_size; - obj->done = reds_handle_ticket; - async_read_handler(0, 0, &link->asyc_read); + if (!auth_selection) { + red_printf("Peer doesn't support AUTH selection"); + reds_get_spice_ticket(link); + } else { + obj->now = (uint8_t *)&link->auth_mechanism; + obj->end = obj->now + sizeof(SpiceLinkAuthMechanism); + obj->done = reds_handle_auth_mechanism; + async_read_handler(0, 0, &link->asyc_read); + } } static void reds_handle_link_error(void *opaque, int err) |