diff options
author | Aris Adamantiadis <aris@0xbadc0de.be> | 2011-09-02 13:58:37 +0300 |
---|---|---|
committer | Aris Adamantiadis <aris@0xbadc0de.be> | 2011-09-02 13:58:37 +0300 |
commit | 20f8e73e3eaf8d7f69786db3f6385095e244e85c (patch) | |
tree | 1fba85badc6b95baceda20f4871735eac6a2404b /src | |
parent | ef5701a5357e3f5b71aa5c387e4f976fe5df0ab7 (diff) | |
download | libssh-20f8e73e3eaf8d7f69786db3f6385095e244e85c.tar.gz libssh-20f8e73e3eaf8d7f69786db3f6385095e244e85c.tar.xz libssh-20f8e73e3eaf8d7f69786db3f6385095e244e85c.zip |
Update libssh to ssh_handle_packets_termination
cherry-picked from 0cb5248
Should resolve all timeout problems
Conflicts:
src/auth.c
src/channels.c
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.c | 11 | ||||
-rw-r--r-- | src/channels.c | 243 | ||||
-rw-r--r-- | src/client.c | 93 | ||||
-rw-r--r-- | src/messages.c | 28 | ||||
-rw-r--r-- | src/misc.c | 5 | ||||
-rw-r--r-- | src/server.c | 36 | ||||
-rw-r--r-- | src/session.c | 94 |
7 files changed, 293 insertions, 217 deletions
@@ -317,6 +317,17 @@ SSH_PACKET_CALLBACK(ssh_packet_userauth_pk_ok){ return rc; } +static int auth_status_termination(void *user){ + ssh_session session=(ssh_session)user; + switch(session->auth_state){ + case SSH_AUTH_STATE_NONE: + case SSH_AUTH_STATE_KBDINT_SENT: + return 0; + default: + return 1; + } +} + /** * @brief Get available authentication methods from the server. * diff --git a/src/channels.c b/src/channels.c index 6dde48b4..6925978b 100644 --- a/src/channels.c +++ b/src/channels.c @@ -220,6 +220,15 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_open_fail){ return SSH_PACKET_USED; } +static int ssh_channel_open_termination(void *c){ + ssh_channel channel = (ssh_channel) c; + if (channel->state != SSH_CHANNEL_STATE_OPENING || + channel->session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + /** * @internal * @@ -242,7 +251,6 @@ static int channel_open(ssh_channel channel, const char *type_c, int window, int maxpacket, ssh_buffer payload) { ssh_session session = channel->session; ssh_string type = NULL; - int timeout; int err=SSH_ERROR; enter_function(); @@ -304,20 +312,9 @@ static int channel_open(ssh_channel channel, const char *type_c, int window, type_c, channel->local_channel); pending: /* wait until channel is opened by server */ - if(ssh_is_blocking(session)) - timeout=-2; - else - timeout=0; - while(channel->state == SSH_CHANNEL_STATE_OPENING){ - err = ssh_handle_packets(session, timeout); - if (err != SSH_OK) { - break; - } - if (session->session_state == SSH_SESSION_STATE_ERROR) { - err = SSH_ERROR; - break; - } - } + err = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, ssh_channel_open_termination, channel); + if (err != SSH_OK || session->session_state == SSH_SESSION_STATE_ERROR) + err = SSH_ERROR; end: if(channel->state == SSH_CHANNEL_STATE_OPEN) err=SSH_OK; @@ -1173,13 +1170,35 @@ error: return rc; } +/* this termination function waits for a window growing condition */ +static int ssh_channel_waitwindow_termination(void *c){ + ssh_channel channel = (ssh_channel) c; + if (channel->remote_window > 0 || + channel->session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + +/** + * @internal + * @brief Flushes a channel (and its session) until the output buffer + * is empty, or timeout elapsed. + * @param channel SSH channel + * @returns SSH_OK On success, + * SSH_ERROR on error + * SSH_AGAIN Timeout elapsed (or in nonblocking mode) + */ +int ssh_channel_flush(ssh_channel channel){ + return ssh_blocking_flush(channel->session, SSH_TIMEOUT_USER); +} + int channel_write_common(ssh_channel channel, const void *data, uint32_t len, int is_stderr) { ssh_session session; uint32_t origlen = len; size_t effectivelen; size_t maxpacketlen; - int timeout; int rc; if(channel == NULL) { @@ -1198,10 +1217,7 @@ int channel_write_common(ssh_channel channel, const void *data, } enter_function(); - if(ssh_is_blocking(session)) - timeout = -2; - else - timeout = 0; + /* * Handle the max packet len from remote side, be nice * 10 bytes for the headers @@ -1242,8 +1258,9 @@ int channel_write_common(ssh_channel channel, const void *data, /* nothing can be written */ ssh_log(session, SSH_LOG_PROTOCOL, "Wait for a growing window message..."); - rc = ssh_handle_packets(session, timeout); - if (rc == SSH_ERROR || (channel->remote_window == 0 && timeout==0)) + rc = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, + ssh_channel_waitwindow_termination,channel); + if (rc == SSH_ERROR || !ssh_channel_waitwindow_termination(channel)) goto out; continue; } @@ -1285,9 +1302,9 @@ int channel_write_common(ssh_channel channel, const void *data, data = ((uint8_t*)data + effectivelen); } /* it's a good idea to flush the socket now */ - do { - rc = ssh_handle_packets(session, timeout); - } while(ssh_socket_buffered_write_bytes(session->socket) > 0 && timeout != 0); + rc = ssh_channel_flush(channel); + if(rc == SSH_ERROR) + goto error; out: leave_function(); return (int)(origlen - len); @@ -1458,11 +1475,19 @@ SSH_PACKET_CALLBACK(ssh_packet_channel_failure){ return SSH_PACKET_USED; } +static int ssh_channel_request_termination(void *c){ + ssh_channel channel = (ssh_channel)c; + if(channel->request_state != SSH_CHANNEL_REQ_STATE_PENDING || + channel->session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + static int channel_request(ssh_channel channel, const char *request, ssh_buffer buffer, int reply) { ssh_session session = channel->session; ssh_string req = NULL; - int timeout; int rc = SSH_ERROR; enter_function(); @@ -1509,20 +1534,9 @@ static int channel_request(ssh_channel channel, const char *request, return SSH_OK; } pending: - if(ssh_is_blocking(session)) - timeout=-2; - else - timeout=0; - while(channel->request_state == SSH_CHANNEL_REQ_STATE_PENDING){ - ssh_handle_packets(session, timeout); - if(channel->request_state == SSH_CHANNEL_REQ_STATE_PENDING && timeout==0){ - leave_function(); - return SSH_AGAIN; - } - if(session->session_state == SSH_SESSION_STATE_ERROR) { + rc = ssh_handle_packets_termination(session,SSH_TIMEOUT_USER, ssh_channel_request_termination, channel); + if(session->session_state == SSH_SESSION_STATE_ERROR) { channel->request_state = SSH_CHANNEL_REQ_STATE_ERROR; - break; - } } /* we received something */ switch (channel->request_state){ @@ -1539,8 +1553,11 @@ pending: "Channel request %s success",request); rc=SSH_OK; break; - case SSH_CHANNEL_REQ_STATE_NONE: case SSH_CHANNEL_REQ_STATE_PENDING: + rc = SSH_AGAIN; + leave_function(); + return rc; + case SSH_CHANNEL_REQ_STATE_NONE: /* Never reached */ ssh_set_error(session, SSH_FATAL, "Invalid state in channel_request()"); rc=SSH_ERROR; @@ -1995,6 +2012,15 @@ SSH_PACKET_CALLBACK(ssh_request_denied){ } +static int ssh_global_request_termination(void *s){ + ssh_session session = (ssh_session) s; + if (session->global_req_state != SSH_CHANNEL_REQ_STATE_PENDING || + session->session_state != SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + /** * @internal * @@ -2018,7 +2044,6 @@ static int global_request(ssh_session session, const char *request, ssh_buffer buffer, int reply) { ssh_string req = NULL; int rc = SSH_ERROR; - int timeout; enter_function(); if(session->global_req_state != SSH_CHANNEL_REQ_STATE_NONE) @@ -2059,19 +2084,10 @@ static int global_request(ssh_session session, const char *request, return SSH_OK; } pending: - if(ssh_is_blocking(session)) - timeout=-2; - else - timeout=0; - while(session->global_req_state == SSH_CHANNEL_REQ_STATE_PENDING){ - rc=ssh_handle_packets(session, timeout); - if(rc==SSH_ERROR){ - session->global_req_state = SSH_CHANNEL_REQ_STATE_ERROR; - break; - } - if(session->global_req_state == SSH_CHANNEL_REQ_STATE_PENDING - && timeout == 0) - break; + rc = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, + ssh_global_request_termination, session); + if(rc==SSH_ERROR || session->session_state == SSH_SESSION_STATE_ERROR){ + session->global_req_state = SSH_CHANNEL_REQ_STATE_ERROR; } switch(session->global_req_state){ case SSH_CHANNEL_REQ_STATE_ACCEPTED: @@ -2469,6 +2485,7 @@ error: * @return The number of bytes read, 0 on end of file or SSH_ERROR * on error. * @deprecated Please use ssh_channel_read instead + * @warning This function doesn't work in nonblocking/timeout mode * @see ssh_channel_read */ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, @@ -2502,9 +2519,9 @@ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, return r; } if(buffer_add_data(buffer,buffer_tmp,r) < 0){ - ssh_set_error_oom(session); - r = SSH_ERROR; - } + ssh_set_error_oom(session); + r = SSH_ERROR; + } leave_function(); return r; } @@ -2512,7 +2529,7 @@ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, leave_function(); return 0; } - ssh_handle_packets(channel->session, -2); + ssh_handle_packets(channel->session, SSH_TIMEOUT_INFINITE); } while (r == 0); } while(total < count){ @@ -2536,6 +2553,22 @@ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, return total; } +struct ssh_channel_read_termination_struct { + ssh_channel channel; + uint32_t count; + ssh_buffer buffer; +}; + +static int ssh_channel_read_termination(void *s){ + struct ssh_channel_read_termination_struct *ctx = s; + if (buffer_get_rest_len(ctx->buffer) >= ctx->count || + ctx->channel->remote_eof || + ctx->channel->session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + /* TODO FIXME Fix the delayed close thing */ /* TODO FIXME Fix the blocking behaviours */ @@ -2563,11 +2596,8 @@ int ssh_channel_read(ssh_channel channel, void *dest, uint32_t count, int is_std ssh_session session; ssh_buffer stdbuf; uint32_t len; -<<<<<<< HEAD - int rc; -======= - int ret; ->>>>>>> 6091147... channel: ssh_channel_read is nonblocking, + docfixes + struct ssh_channel_read_termination_struct ctx; + int ret, rc; if(channel == NULL) { return SSH_ERROR; @@ -2612,31 +2642,22 @@ int ssh_channel_read(ssh_channel channel, void *dest, uint32_t count, int is_std } } - /* block reading until at least one byte is read + /* block reading until all bytes are read * and ignore the trivial case count=0 */ - while (buffer_get_rest_len(stdbuf) == 0 && count > 0) { - if (channel->remote_eof && buffer_get_rest_len(stdbuf) == 0) { - leave_function(); - return 0; - } - - if (channel->remote_eof) { - /* Return the resting bytes in buffer */ - break; - } - - if (buffer_get_rest_len(stdbuf) >= count) { - /* Stop reading when buffer is full enough */ - break; - } - - rc = ssh_handle_packets(session, -2); - if (rc != SSH_OK) { - return rc; - } + ctx.channel = channel; + ctx.buffer = stdbuf; + ctx.count = count; + rc = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, + ssh_channel_read_termination, &ctx); + if (rc == SSH_ERROR){ + leave_function(); + return rc; + } + if (channel->remote_eof && buffer_get_rest_len(stdbuf) == 0) { + leave_function(); + return 0; } - len = buffer_get_rest_len(stdbuf); /* Read count bytes if len is greater, everything otherwise */ len = (len > count ? count : len); @@ -2739,7 +2760,7 @@ int ssh_channel_poll(ssh_channel channel, int is_stderr){ } if (buffer_get_rest_len(stdbuf) == 0 && channel->remote_eof == 0) { - if (ssh_handle_packets(channel->session, 0)==SSH_ERROR) { + if (ssh_handle_packets(channel->session, SSH_TIMEOUT_NONBLOCKING)==SSH_ERROR) { leave_function(); return SSH_ERROR; } @@ -2829,6 +2850,18 @@ ssh_session ssh_channel_get_session(ssh_channel channel) { return channel->session; } +static int ssh_channel_exit_status_termination(void *c){ + ssh_channel channel = c; + if(channel->exit_status != -1 || + /* When a channel is closed, no exit status message can + * come anymore */ + (channel->flags & SSH_CHANNEL_FLAG_CLOSED_REMOTE) || + channel->session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + /** * @brief Get the exit status of the channel (error code from the executed * instruction). @@ -2836,38 +2869,20 @@ ssh_session ssh_channel_get_session(ssh_channel channel) { * @param[in] channel The channel to get the status from. * * @returns The exit status, -1 if no exit status has been returned - * or eof not sent. + * (yet). + * @warning This function may block until a timeout (or never) + * if the other side is not willing to close the channel. */ int ssh_channel_get_exit_status(ssh_channel channel) { - int timeout; + int rc; if(channel == NULL) { return SSH_ERROR; } - - if (channel->local_eof == 0) { - return channel->exit_status; - } - if(ssh_is_blocking(channel->session)) - timeout = -2; - else - timeout = 0; - while ((channel->remote_eof == 0 || channel->exit_status == -1) && channel->session->alive) { - /* Parse every incoming packet */ - if (ssh_handle_packets(channel->session, timeout) != SSH_OK) { - return -1; - } - /* XXX We should actually wait for a close packet and not a close - * we issued ourselves - */ - if (channel->state != SSH_CHANNEL_STATE_OPEN) { - /* When a channel is closed, no exit status message can - * come anymore */ - break; - } - if (timeout == 0) - break; - } - + rc = ssh_handle_packets_termination(channel->session, SSH_TIMEOUT_USER, + ssh_channel_exit_status_termination, channel); + if (rc == SSH_ERROR || channel->session->session_state == + SSH_SESSION_STATE_ERROR) + return SSH_ERROR; return channel->exit_status; } @@ -2890,7 +2905,7 @@ static int channel_protocol_select(ssh_channel *rchans, ssh_channel *wchans, chan = rchans[i]; while (ssh_channel_is_open(chan) && ssh_socket_data_available(chan->session->socket)) { - ssh_handle_packets(chan->session, -2); + ssh_handle_packets(chan->session, SSH_TIMEOUT_NONBLOCKING); } if ((chan->stdout_buffer && buffer_get_rest_len(chan->stdout_buffer) > 0) || diff --git a/src/client.c b/src/client.c index 209524c3..4a7dea75 100644 --- a/src/client.c +++ b/src/client.c @@ -386,6 +386,15 @@ SSH_PACKET_CALLBACK(ssh_packet_service_accept){ return SSH_PACKET_USED; } +static int ssh_service_request_termination(void *s){ + ssh_session session = (ssh_session)s; + if(session->session_state == SSH_SESSION_STATE_ERROR || + session->auth_service_state != SSH_AUTH_SERVICE_SENT) + return 1; + else + return 0; +} + /** * @internal * @@ -405,48 +414,52 @@ int ssh_service_request(ssh_session session, const char *service) { ssh_string service_s = NULL; int rc=SSH_ERROR; enter_function(); - switch(session->auth_service_state){ - case SSH_AUTH_SERVICE_NONE: - if (buffer_add_u8(session->out_buffer, SSH2_MSG_SERVICE_REQUEST) < 0) { - break; - } - service_s = ssh_string_from_char(service); - if (service_s == NULL) { - break; - } - - if (buffer_add_ssh_string(session->out_buffer,service_s) < 0) { - ssh_string_free(service_s); - break; - } - ssh_string_free(service_s); - - if (packet_send(session) == SSH_ERROR) { - ssh_set_error(session, SSH_FATAL, - "Sending SSH2_MSG_SERVICE_REQUEST failed."); - break; - } - - ssh_log(session, SSH_LOG_PACKET, - "Sent SSH_MSG_SERVICE_REQUEST (service %s)", service); - session->auth_service_state=SSH_AUTH_SERVICE_SENT; - rc=SSH_AGAIN; - break; - case SSH_AUTH_SERVICE_DENIED: - ssh_set_error(session,SSH_FATAL,"ssh_auth_service request denied"); - break; - case SSH_AUTH_SERVICE_ACCEPTED: - rc=SSH_OK; - break; - case SSH_AUTH_SERVICE_SENT: - rc=SSH_AGAIN; - break; - case SSH_AUTH_SERVICE_USER_SENT: - /* Invalid state, SSH1 specific */ - rc=SSH_ERROR; - break; + if(session->auth_service_state != SSH_AUTH_SERVICE_NONE) + goto pending; + if (buffer_add_u8(session->out_buffer, SSH2_MSG_SERVICE_REQUEST) < 0) { + goto error; + } + service_s = ssh_string_from_char(service); + if (service_s == NULL) { + goto error; } + if (buffer_add_ssh_string(session->out_buffer,service_s) < 0) { + ssh_string_free(service_s); + goto error; + } + ssh_string_free(service_s); + session->auth_service_state=SSH_AUTH_SERVICE_SENT; + if (packet_send(session) == SSH_ERROR) { + ssh_set_error(session, SSH_FATAL, + "Sending SSH2_MSG_SERVICE_REQUEST failed."); + goto error; + } + + ssh_log(session, SSH_LOG_PACKET, + "Sent SSH_MSG_SERVICE_REQUEST (service %s)", service); +pending: + rc=ssh_handle_packets_termination(session,SSH_TIMEOUT_USER, + ssh_service_request_termination, session); + if(rc == SSH_ERROR) + goto error; + switch(session->auth_service_state){ + case SSH_AUTH_SERVICE_DENIED: + ssh_set_error(session,SSH_FATAL,"ssh_auth_service request denied"); + break; + case SSH_AUTH_SERVICE_ACCEPTED: + rc=SSH_OK; + break; + case SSH_AUTH_SERVICE_SENT: + rc=SSH_AGAIN; + break; + case SSH_AUTH_SERVICE_NONE: + case SSH_AUTH_SERVICE_USER_SENT: + /* Invalid state, SSH1 specific */ + rc=SSH_ERROR; + break; + } +error: leave_function(); return rc; } diff --git a/src/messages.c b/src/messages.c index a9398bba..217d9101 100644 --- a/src/messages.c +++ b/src/messages.c @@ -160,18 +160,32 @@ ssh_message ssh_message_pop_head(ssh_session session){ return msg; } +/* Returns 1 if there is a message available */ +static int ssh_message_termination(void *s){ + ssh_session session = s; + struct ssh_iterator *it; + if(session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + it = ssh_list_get_iterator(session->ssh_message_list); + if(!it) + return 0; + else + return 1; +} /** * @brief Retrieve a SSH message from a SSH session. * * @param[in] session The SSH session to get the message. * - * @returns The SSH message received, NULL in case of error. + * @returns The SSH message received, NULL in case of error, or timeout + * elapsed. * * @warning This function blocks until a message has been received. Betterset up * a callback if this behavior is unwanted. */ ssh_message ssh_message_get(ssh_session session) { ssh_message msg = NULL; + int rc; enter_function(); msg=ssh_message_pop_head(session); @@ -182,13 +196,11 @@ ssh_message ssh_message_get(ssh_session session) { if(session->ssh_message_list == NULL) { session->ssh_message_list = ssh_list_new(); } - do { - if (ssh_handle_packets(session, -2) == SSH_ERROR) { - leave_function(); - return NULL; - } - msg=ssh_list_pop_head(ssh_message, session->ssh_message_list); - } while(msg==NULL); + rc = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, + ssh_message_termination, session); + if(rc || session->session_state == SSH_SESSION_STATE_ERROR) + return NULL; + msg=ssh_list_pop_head(ssh_message, session->ssh_message_list); leave_function(); return msg; } @@ -933,7 +933,7 @@ int ssh_timeout_elapsed(struct ssh_timestamp *ts, int timeout) { struct ssh_timestamp now; if(timeout < 0) return 0; // -1 means infinite timeout - if(timeout == 0) + if(timeout == SSH_TIMEOUT_NONBLOCKING) return 1; // 0 means no timeout ssh_timestamp_init(&now); @@ -948,8 +948,7 @@ int ssh_timeout_elapsed(struct ssh_timestamp *ts, int timeout) { * @param[in] ts pointer to an existing timestamp * @param[in] timeout timeout in milliseconds. Negative values mean infinite * timeout - * @returns remaining time in milliseconds, 0 if elapsed, -1 if never, - * -2 if option-set-timeout. + * @returns remaining time in milliseconds, 0 if elapsed, -1 if never. */ int ssh_timeout_update(struct ssh_timestamp *ts, int timeout){ struct ssh_timestamp now; diff --git a/src/server.c b/src/server.c index 5b2ccee8..5ca036f3 100644 --- a/src/server.c +++ b/src/server.c @@ -452,10 +452,22 @@ static int callback_receive_banner(const void *data, size_t len, void *user) { return ret; } +/* returns 0 until the key exchange is not finished */ +static int ssh_server_kex_termination(void *s){ + ssh_session session = s; + if (session->session_state != SSH_SESSION_STATE_ERROR && + session->session_state != SSH_SESSION_STATE_AUTHENTICATING && + session->session_state != SSH_SESSION_STATE_DISCONNECTED) + return 0; + else + return 1; +} + /* Do the banner and key exchange */ int ssh_handle_key_exchange(ssh_session session) { int rc; - + if (session->session_state != SSH_SESSION_STATE_NONE) + goto pending; rc = ssh_send_banner(session, 1); if (rc < 0) { return SSH_ERROR; @@ -474,22 +486,16 @@ int ssh_handle_key_exchange(ssh_session session) { if (rc < 0) { return SSH_ERROR; } - - while (session->session_state != SSH_SESSION_STATE_ERROR && - session->session_state != SSH_SESSION_STATE_AUTHENTICATING && - session->session_state != SSH_SESSION_STATE_DISCONNECTED) { - /* - * loop until SSH_SESSION_STATE_BANNER_RECEIVED or - * SSH_SESSION_STATE_ERROR - */ - ssh_handle_packets(session, -2); - ssh_log(session,SSH_LOG_PACKET, "ssh_handle_key_exchange: Actual state : %d", - session->session_state); - } - + pending: + rc = ssh_handle_packets_termination(session, SSH_TIMEOUT_USER, + ssh_server_kex_termination,session); + ssh_log(session,SSH_LOG_PACKET, "ssh_handle_key_exchange: Actual state : %d", + session->session_state); + if (rc != SSH_OK) + return rc; if (session->session_state == SSH_SESSION_STATE_ERROR || session->session_state == SSH_SESSION_STATE_DISCONNECTED) { - return SSH_ERROR; + return SSH_ERROR; } return SSH_OK; diff --git a/src/session.c b/src/session.c index 751aa537..55ff22a7 100644 --- a/src/session.c +++ b/src/session.c @@ -302,6 +302,16 @@ int ssh_is_blocking(ssh_session session){ return (session->flags&SSH_SESSION_FLAG_BLOCKING) ? 1 : 0; } +/* Waits until the output socket is empty */ +static int ssh_flush_termination(void *c){ + ssh_session session = c; + if (ssh_socket_buffered_write_bytes(session->socket) == 0 || + session->session_state == SSH_SESSION_STATE_ERROR) + return 1; + else + return 0; +} + /** * @brief Blocking flush of the outgoing buffer * @param[in] session The SSH session @@ -314,26 +324,20 @@ int ssh_is_blocking(ssh_session session){ */ int ssh_blocking_flush(ssh_session session, int timeout){ - ssh_socket s; - struct ssh_timestamp ts; - int rc = SSH_OK; - if(session==NULL) - return SSH_ERROR; - - enter_function(); - s=session->socket; - ssh_timestamp_init(&ts); - while (ssh_socket_buffered_write_bytes(s) > 0 && session->alive) { - rc=ssh_handle_packets(session, timeout); - if(ssh_timeout_elapsed(&ts,timeout)){ - rc=SSH_AGAIN; - break; - } - timeout = ssh_timeout_update(&ts, timeout); - } + int rc; + if(!session) + return SSH_ERROR; + enter_function(); - leave_function(); - return rc; + rc = ssh_handle_packets_termination(session, timeout, + ssh_flush_termination, session); + if (rc == SSH_ERROR) + goto end; + if (!ssh_flush_termination(session)) + rc = SSH_AGAIN; +end: + leave_function(); + return rc; } /** @@ -424,18 +428,18 @@ static int ssh_make_milliseconds(long sec, long usec) { * @internal * * @brief Poll the current session for an event and call the appropriate - * callbacks. + * callbacks. This function will not loop until the timeout is expired. * * This will block until one event happens. * * @param[in] session The session handle to use. * * @param[in] timeout Set an upper limit on the time for which this function - * will block, in milliseconds. Specifying -1 - * means an infinite timeout. - * Specifying -2 means to use the timeout specified in - * options. 0 means poll will return immediately. This - * parameter is passed to the poll() function. + * will block, in milliseconds. Specifying SSH_TIMEOUT_INFINITE + * (-1) means an infinite timeout. + * Specifying SSH_TIMEOUT_USER means to use the timeout + * specified in options. 0 means poll will return immediately. + * This parameter is passed to the poll() function. * * @return SSH_OK on success, SSH_ERROR otherwise. */ @@ -465,8 +469,11 @@ int ssh_handle_packets(ssh_session session, int timeout) { } } - if (timeout == -2) { - tm = ssh_make_milliseconds(session->timeout, session->timeout_usec); + if (timeout == SSH_TIMEOUT_USER) { + if (ssh_is_blocking(session)) + tm = ssh_make_milliseconds(session->timeout, session->timeout_usec); + else + tm = 0; } rc = ssh_poll_ctx_dopoll(ctx, tm); if (rc == SSH_ERROR) { @@ -483,14 +490,17 @@ int ssh_handle_packets(ssh_session session, int timeout) { * @brief Poll the current session for an event and call the appropriate * callbacks. * - * This will block until termination fuction returns true, or timeout expired. + * This will block until termination function returns true, or timeout expired. * * @param[in] session The session handle to use. * * @param[in] timeout Set an upper limit on the time for which this function - * will block, in milliseconds. Specifying a negative value - * means an infinite timeout. This parameter is passed to - * the poll() function. + * will block, in milliseconds. Specifying SSH_TIMEOUT_INFINITE + * (-1) means an infinite timeout. + * Specifying SSH_TIMEOUT_USER means to use the timeout + * specified in options. 0 means poll will return immediately. + * This parameter is passed to the poll() function. + * * @param[in] fct Termination function to be used to determine if it is * possible to stop polling. * @param[in] user User parameter to be passed to fct termination function. @@ -499,13 +509,23 @@ int ssh_handle_packets(ssh_session session, int timeout) { int ssh_handle_packets_termination(ssh_session session, int timeout, ssh_termination_function fct, void *user){ int ret = SSH_OK; - + struct ssh_timestamp ts; + int tm; + if (timeout == SSH_TIMEOUT_USER) { + if (ssh_is_blocking(session)) + timeout = ssh_make_milliseconds(session->timeout, session->timeout_usec); + else + timeout = SSH_TIMEOUT_NONBLOCKING; + } + ssh_timestamp_init(&ts); + tm = timeout; while(!fct(user)){ - ret = ssh_handle_packets(session, timeout); - if(ret == SSH_ERROR || ret == SSH_AGAIN) - return ret; - if(fct(user)) - return SSH_OK; + ret = ssh_handle_packets(session, tm); + if(ret == SSH_ERROR) + break; + if(ssh_timeout_elapsed(&ts,timeout)) + break; + tm = ssh_timeout_update(&ts, timeout); } return ret; } |