diff options
-rw-r--r-- | include/libssh/libssh.h | 163 | ||||
-rw-r--r-- | include/libssh/priv.h | 8 | ||||
-rw-r--r-- | libssh/poll.c | 226 |
3 files changed, 389 insertions, 8 deletions
diff --git a/include/libssh/libssh.h b/include/libssh/libssh.h index 18d42c73..4f9ffd10 100644 --- a/include/libssh/libssh.h +++ b/include/libssh/libssh.h @@ -375,6 +375,169 @@ char *ssh_userauth_kbdint_getprompt(SSH_SESSION *session, unsigned int i, char * int ssh_userauth_kbdint_setanswer(SSH_SESSION *session, unsigned int i, const char *answer); +/* poll.c */ +#define POLLIN 0x001 /* There is data to read. */ +#define POLLPRI 0x002 /* There is urgent data to read. */ +#define POLLOUT 0x004 /* Writing now will not block. */ + +#define POLLERR 0x008 /* Error condition. */ +#define POLLHUP 0x010 /* Hung up. */ +#define POLLNVAL 0x020 /* Invalid polling request. */ + +typedef struct ssh_poll_ctx SSH_POLL_CTX; +typedef struct ssh_poll SSH_POLL; + +/** + * @brief SSH poll callback. + * + * @param p Poll object this callback belongs to. + * @param fd The raw socket. + * @param revents The current poll events on the socket. + * @param userdata Userdata to be passed to the callback function. + * + * @return 0 on success, < 0 if you removed the poll object from + * it's poll context. + */ +typedef int (*ssh_poll_callback)(SSH_POLL *p, int fd, int revents, + void *userdata); + +/** + * @brief Allocate a new poll object, which could be used within a poll context. + * + * @param fd Socket that will be polled. + * @param events Poll events that will be monitored for the socket. i.e. + * POLLIN, POLLPRI, POLLOUT, POLLERR, POLLHUP, POLLNVAL + * @param cb Function to be called if any of the events are set. + * @param userdata Userdata to be passed to the callback function. NULL if + * not needed. + * + * @return A new poll object, NULL on error + */ +SSH_POLL *ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, + void *userdata); + +/** + * @brief Free a poll object. + * + * @param p Pointer to an already allocated poll object. + */ +void ssh_poll_free(SSH_POLL *p); + +/** + * @brief Get the poll context of a poll object. + * + * @param p Pointer to an already allocated poll object. + * + * @return Poll context or NULL if the poll object isn't attached. + */ +SSH_POLL_CTX *ssh_poll_get_ctx(SSH_POLL *p); + +/** + * @brief Get the events of a poll object. + * + * @param p Pointer to an already allocated poll object. + * + * @return Poll events. + */ +short ssh_poll_get_events(SSH_POLL *p); + +/** + * @brief Set the events of a poll object. The events will also be propagated + * to an associated poll context. + * + * @param p Pointer to an already allocated poll object. + * @param events Poll events. + */ +void ssh_poll_set_events(SSH_POLL *p, short events); + +/** + * @brief Add extra events to a poll object. Duplicates are ignored. + * The events will also be propagated to an associated poll context. + * + * @param p Pointer to an already allocated poll object. + * @param events Poll events. + */ +void ssh_poll_add_events(SSH_POLL *p, short events); + +/** + * @brief Remove events from a poll object. Non-existent are ignored. + * The events will also be propagated to an associated poll context. + * + * @param p Pointer to an already allocated poll object. + * @param events Poll events. + */ +void ssh_poll_remove_events(SSH_POLL *p, short events); + +/** + * @brief Get the raw socket of a poll object. + * + * @param p Pointer to an already allocated poll object. + * + * @return Raw socket. + */ +int ssh_poll_get_fd(SSH_POLL *p); + +/** + * @brief Set the callback of a poll object. + * + * @param p Pointer to an already allocated poll object. + * @param cb Function to be called if any of the events are set. + * @param userdata Userdata to be passed to the callback function. NULL if + * not needed. + */ +void ssh_poll_set_callback(SSH_POLL *p, ssh_poll_callback cb, void *userdata); + +/** + * @brief Create a new poll context. It could be associated with many poll object + * which are going to be polled at the same time as the poll context. You + * would need a single poll context per thread. + * + * @param chunk_size The size of the memory chunk that will be allocated, when + * more memory is needed. This is for efficiency reasons, + * i.e. don't allocate memory for each new poll object, but + * for the next 5. Set it to 0 if you want to use the + * library's default value. + */ +SSH_POLL_CTX *ssh_poll_ctx_new(size_t chunk_size); + +/** + * @brief Free a poll context. + * + * @param ctx Pointer to an already allocated poll context. + */ +void ssh_poll_ctx_free(SSH_POLL_CTX *ctx); + +/** + * @brief Add a poll object to a poll context. + * + * @param ctx Pointer to an already allocated poll context. + * @param p Pointer to an already allocated poll object. + * + * @return 0 on success, < 0 on error + */ +int ssh_poll_ctx_add(SSH_POLL_CTX *ctx, SSH_POLL *p); + +/** + * @brief Remove a poll object from a poll context. + * + * @param ctx Pointer to an already allocated poll context. + * @param p Pointer to an already allocated poll object. + */ +void ssh_poll_ctx_remove(SSH_POLL_CTX *ctx, SSH_POLL *p); + +/** + * @brief Poll all the sockets associated through a poll object with a + * poll context. If any of the events are set after the poll, the + * call back function of the socket will be called. + * This function should be called once within the programs main loop. + * + * @param ctx Pointer to an already allocated poll context. + * @param timeout An upper limit on the time for which ssh_poll_ctx() will + * block, in milliseconds. Specifying a negative value + * means an infinite timeout. This parameter is passed to + * the poll() function. + */ +int ssh_poll_ctx(SSH_POLL_CTX *ctx, int timeout); /* init.c */ int ssh_init(void); diff --git a/include/libssh/priv.h b/include/libssh/priv.h index 4c380deb..41ee58b6 100644 --- a/include/libssh/priv.h +++ b/include/libssh/priv.h @@ -144,14 +144,6 @@ typedef struct pollfd_s { short revents; /* returned events */ } pollfd_t; -#define POLLIN 0x001 /* There is data to read. */ -#define POLLPRI 0x002 /* There is urgent data to read. */ -#define POLLOUT 0x004 /* Writing now will not block. */ - -#define POLLERR 0x008 /* Error condition. */ -#define POLLHUP 0x010 /* Hung up. */ -#define POLLNVAL 0x020 /* Invalid polling request. */ - typedef unsigned long int nfds_t; #endif /* HAVE_POLL */ diff --git a/libssh/poll.c b/libssh/poll.c index c20300dd..f2fea520 100644 --- a/libssh/poll.c +++ b/libssh/poll.c @@ -27,6 +27,30 @@ #include "config.h" #include "libssh/priv.h" +#include "libssh/libssh.h" + +#ifndef SSH_POLL_CTX_CHUNK +#define SSH_POLL_CTX_CHUNK 5 +#endif + +struct ssh_poll { + SSH_POLL_CTX *ctx; + union { + socket_t fd; + size_t idx; + }; + short events; + ssh_poll_callback cb; + void *cb_data; +}; + +struct ssh_poll_ctx { + SSH_POLL **pollptrs; + pollfd_t *pollfds; + size_t polls_allocated; + size_t polls_used; + size_t chunk_size; +}; #ifdef HAVE_POLL #include <poll.h> @@ -202,3 +226,205 @@ int ssh_poll(pollfd_t *fds, nfds_t nfds, int timeout) { #endif /* HAVE_POLL */ +SSH_POLL *ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, + void *userdata) { + SSH_POLL *p; + + p = malloc(sizeof(SSH_POLL)); + if (p != NULL) { + p->ctx = NULL; + p->fd = fd; + p->events = events; + p->cb = cb; + p->cb_data = userdata; + } + + return p; +} + +void ssh_poll_free(SSH_POLL *p) { + SAFE_FREE(p); +} + +SSH_POLL_CTX *ssh_poll_get_ctx(SSH_POLL *p) { + return p->ctx; +} + +short ssh_poll_get_events(SSH_POLL *p) { + return p->events; +} + +void ssh_poll_set_events(SSH_POLL *p, short events) { + p->events = events; + if (p->ctx != NULL) { + p->ctx->pollfds[p->idx].events = events; + } +} + +void ssh_poll_add_events(SSH_POLL *p, short events) { + ssh_poll_set_events(p, ssh_poll_get_events(p) | events); +} + +void ssh_poll_remove_events(SSH_POLL *p, short events) { + ssh_poll_set_events(p, ssh_poll_get_events(p) & ~events); +} + +int ssh_poll_get_fd(SSH_POLL *p) { + if (p->ctx != NULL) { + return p->ctx->pollfds[p->idx].fd; + } + + return p->fd; +} + +void ssh_poll_set_callback(SSH_POLL *p, ssh_poll_callback cb, void *userdata) { + if (cb != NULL) { + p->cb = cb; + p->cb_data = userdata; + } +} + +SSH_POLL_CTX *ssh_poll_ctx_new(size_t chunk_size) { + SSH_POLL_CTX *ctx; + + ctx = malloc(sizeof(SSH_POLL_CTX)); + if (ctx != NULL) { + if (!chunk_size) { + chunk_size = SSH_POLL_CTX_CHUNK; + } + + ctx->chunk_size = chunk_size; + ctx->pollptrs = NULL; + ctx->pollfds = NULL; + ctx->polls_allocated = 0; + ctx->polls_used = 0; + } + + return ctx; +} + +void ssh_poll_ctx_free(SSH_POLL_CTX *ctx) { + if (ctx->polls_allocated > 0) { + register size_t i, used; + + used = ctx->polls_used; + for (i = 0; i < used; ) { + SSH_POLL *p = ctx->pollptrs[i]; + int fd = ctx->pollfds[i].fd; + + /* force poll object removal */ + if (p->cb(p, fd, POLLERR, p->cb_data) < 0) { + used = ctx->polls_used; + } else { + i++; + } + } + + SAFE_FREE(ctx->pollptrs); + SAFE_FREE(ctx->pollfds); + } + + SAFE_FREE(ctx); +} + +static int ssh_poll_ctx_resize(SSH_POLL_CTX *ctx, size_t new_size) { + SSH_POLL **pollptrs; + pollfd_t *pollfds; + + pollptrs = realloc(ctx->pollptrs, sizeof(SSH_POLL *) * new_size); + if (pollptrs == NULL) { + return -1; + } + + pollfds = realloc(ctx->pollfds, sizeof(pollfd_t) * new_size); + if (pollfds == NULL) { + ctx->pollptrs = realloc(pollptrs, sizeof(SSH_POLL *) * ctx->polls_allocated); + return -1; + } + + ctx->pollptrs = pollptrs; + ctx->pollfds = pollfds; + ctx->polls_allocated = new_size; + + return 0; +} + +int ssh_poll_ctx_add(SSH_POLL_CTX *ctx, SSH_POLL *p) { + int fd; + + if (p->ctx != NULL) { + /* already attached to a context */ + return -1; + } + + if (ctx->polls_used == ctx->polls_allocated && + ssh_poll_ctx_resize(ctx, ctx->polls_allocated + ctx->chunk_size) < 0) { + return -1; + } + + fd = p->fd; + p->idx = ctx->polls_used++; + ctx->pollptrs[p->idx] = p; + ctx->pollfds[p->idx].fd = fd; + ctx->pollfds[p->idx].events = p->events; + ctx->pollfds[p->idx].revents = 0; + p->ctx = ctx; + + return 0; +} + +void ssh_poll_ctx_remove(SSH_POLL_CTX *ctx, SSH_POLL *p) { + size_t i; + + i = p->idx; + p->fd = ctx->pollfds[i].fd; + p->ctx = NULL; + + ctx->polls_used--; + + /* fill the empty poll slot with the last one */ + if (ctx->polls_used > 0 && ctx->polls_used != i) { + ctx->pollfds[i] = ctx->pollfds[ctx->polls_used]; + ctx->pollptrs[i] = ctx->pollptrs[ctx->polls_used]; + } + + /* this will always leave at least chunk_size polls allocated */ + if (ctx->polls_allocated - ctx->polls_used > ctx->chunk_size) { + ssh_poll_ctx_resize(ctx, ctx->polls_allocated - ctx->chunk_size); + } +} + +int ssh_poll_ctx(SSH_POLL_CTX *ctx, int timeout) { + int rc; + + if (!ctx->polls_used) + return 0; + + rc = ssh_poll(ctx->pollfds, ctx->polls_used, timeout); + if (rc > 0) { + register size_t i, used; + + used = ctx->polls_used; + for (i = 0; i < used && rc > 0; ) { + if (!ctx->pollfds[i].revents) { + i++; + } else { + SSH_POLL *p = ctx->pollptrs[i]; + int fd = ctx->pollfds[i].fd; + int revents = ctx->pollfds[i].revents; + + if (p->cb(p, fd, revents, p->cb_data) < 0) { + /* the poll was removed, reload the used counter and stall the loop */ + used = ctx->polls_used; + } else { + ctx->pollfds[i].revents = 0; + i++; + } + + rc--; + } + } + } + + return rc; +} |