Tweaked a bit to apply to 1.12: * krb5_int64 hadn't been replaced by int64_t yet. commit 346883c48f1b9e09b1af2cf73e3b96ee8f934072 Author: Greg Hudson Date: Wed Mar 26 13:21:45 2014 -0400 Refactor cm functions in sendto_kdc.c Move get_curtime_ms and the cm functions near the top of the file right after structure definitions. Except for cm_select_or_poll, define each cm function separately for poll and for select, since the implementations don't share much in common. Instead of cm_unset_write, define cm_read and cm_write functions to put an fd in read-only or write-only state. Remove the ssflags argument from cm_add_fd and just expect the caller to make a subsequent call to cm_read or cm_write. Always select for exceptions when using select. (Polling for exceptions is implicit with poll). With these changes, we no longer select/poll for reading on a TCP connection until we are done writing to it. So in service_tcp_fd, remove the check for unexpected read events. diff --git a/src/lib/krb5/os/sendto_kdc.c b/src/lib/krb5/os/sendto_kdc.c index e60a375..e773a0a 100644 --- a/src/lib/krb5/os/sendto_kdc.c +++ b/src/lib/krb5/os/sendto_kdc.c @@ -59,8 +59,7 @@ typedef krb5_int64 time_ms; -/* Since fd_set is large on some platforms (8K on AIX 5.2), this probably - * shouldn't be allocated in automatic storage. */ +/* This can be pretty large, so should not be stack-allocated. */ struct select_state { #ifdef USE_POLL struct pollfd fds[MAX_POLLFDS]; @@ -107,6 +106,183 @@ struct conn_state { time_ms endtime; }; +/* Get current time in milliseconds. */ +static krb5_error_code +get_curtime_ms(time_ms *time_out) +{ + struct timeval tv; + + if (gettimeofday(&tv, 0)) + return errno; + *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000; + return 0; +} + +#ifdef USE_POLL + +/* Find a pollfd in selstate by fd, or abort if we can't find it. */ +static inline struct pollfd * +find_pollfd(struct select_state *selstate, int fd) +{ + int i; + + for (i = 0; i < selstate->nfds; i++) { + if (selstate->fds[i].fd == fd) + return &selstate->fds[i]; + } + abort(); +} + +static void +cm_init_selstate(struct select_state *selstate) +{ + selstate->nfds = 0; +} + +static krb5_boolean +cm_add_fd(struct select_state *selstate, int fd) +{ + if (selstate->nfds >= MAX_POLLFDS) + return FALSE; + selstate->fds[selstate->nfds].fd = fd; + selstate->fds[selstate->nfds].events = 0; + selstate->nfds++; + return TRUE; +} + +static void +cm_remove_fd(struct select_state *selstate, int fd) +{ + struct pollfd *pfd = find_pollfd(selstate, fd); + + *pfd = selstate->fds[selstate->nfds - 1]; + selstate->nfds--; +} + +/* Poll for reading (and not writing) on fd the next time we poll. */ +static void +cm_read(struct select_state *selstate, int fd) +{ + find_pollfd(selstate, fd)->events = POLLIN; +} + +/* Poll for writing (and not reading) on fd the next time we poll. */ +static void +cm_write(struct select_state *selstate, int fd) +{ + find_pollfd(selstate, fd)->events = POLLOUT; +} + +/* Get the output events for fd in the form of ssflags. */ +static unsigned int +cm_get_ssflags(struct select_state *selstate, int fd) +{ + struct pollfd *pfd = find_pollfd(selstate, fd); + + return ((pfd->revents & POLLIN) ? SSF_READ : 0) | + ((pfd->revents & POLLOUT) ? SSF_WRITE : 0) | + ((pfd->revents & POLLERR) ? SSF_EXCEPTION : 0); +} + +#else /* not USE_POLL */ + +static void +cm_init_selstate(struct select_state *selstate) +{ + selstate->nfds = 0; + selstate->max = 0; + FD_ZERO(&selstate->rfds); + FD_ZERO(&selstate->wfds); + FD_ZERO(&selstate->xfds); +} + +static krb5_boolean +cm_add_fd(struct select_state *selstate, int fd) +{ +#ifndef _WIN32 /* On Windows FD_SETSIZE is a count, not a max value. */ + if (fd >= FD_SETSIZE) + return FALSE; +#endif + FD_SET(fd, &selstate->xfds); + if (selstate->max <= fd) + selstate->max = fd + 1; + selstate->nfds++; + return TRUE; +} + +static void +cm_remove_fd(struct select_state *selstate, int fd) +{ + FD_CLR(fd, &selstate->rfds); + FD_CLR(fd, &selstate->wfds); + FD_CLR(fd, &selstate->xfds); + if (selstate->max == fd + 1) { + while (selstate->max > 0 && + !FD_ISSET(selstate->max - 1, &selstate->rfds) && + !FD_ISSET(selstate->max - 1, &selstate->wfds) && + !FD_ISSET(selstate->max - 1, &selstate->xfds)) + selstate->max--; + } + selstate->nfds--; +} + +/* Select for reading (and not writing) on fd the next time we select. */ +static void +cm_read(struct select_state *selstate, int fd) +{ + FD_SET(fd, &selstate->rfds); + FD_CLR(fd, &selstate->wfds); +} + +/* Select for writing (and not reading) on fd the next time we select. */ +static void +cm_write(struct select_state *selstate, int fd) +{ + FD_CLR(fd, &selstate->rfds); + FD_SET(fd, &selstate->wfds); +} + +/* Get the events for fd from selstate after a select. */ +static unsigned int +cm_get_ssflags(struct select_state *selstate, int fd) +{ + return (FD_ISSET(fd, &selstate->rfds) ? SSF_READ : 0) | + (FD_ISSET(fd, &selstate->wfds) ? SSF_WRITE : 0) | + (FD_ISSET(fd, &selstate->xfds) ? SSF_EXCEPTION : 0); +} + +#endif /* not USE_POLL */ + +static krb5_error_code +cm_select_or_poll(const struct select_state *in, time_ms endtime, + struct select_state *out, int *sret) +{ +#ifndef USE_POLL + struct timeval tv; +#endif + krb5_error_code retval; + time_ms curtime, interval; + + retval = get_curtime_ms(&curtime); + if (retval != 0) + return retval; + interval = (curtime < endtime) ? endtime - curtime : 0; + + /* We don't need a separate copy of the selstate for poll, but use one for + * consistency with how we use select. */ + *out = *in; + +#ifdef USE_POLL + *sret = poll(out->fds, out->nfds, interval); +#else + tv.tv_sec = interval / 1000; + tv.tv_usec = interval % 1000 * 1000; + *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv); +#endif + + return (*sret < 0) ? SOCKET_ERRNO : 0; +} + static int in_addrlist(struct server_entry *entry, struct serverlist *list) { @@ -251,18 +427,6 @@ cleanup: return retval; } -/* Get current time in milliseconds. */ -static krb5_error_code -get_curtime_ms(time_ms *time_out) -{ - struct timeval tv; - - if (gettimeofday(&tv, 0)) - return errno; - *time_out = (time_ms)tv.tv_sec * 1000 + tv.tv_usec / 1000; - return 0; -} - /* * Notes: * @@ -283,144 +447,6 @@ get_curtime_ms(time_ms *time_out) * connections already in progress */ -static void -cm_init_selstate(struct select_state *selstate) -{ - selstate->nfds = 0; -#ifndef USE_POLL - selstate->max = 0; - FD_ZERO(&selstate->rfds); - FD_ZERO(&selstate->wfds); - FD_ZERO(&selstate->xfds); -#endif -} - -static krb5_boolean -cm_add_fd(struct select_state *selstate, int fd, unsigned int ssflags) -{ -#ifdef USE_POLL - if (selstate->nfds >= MAX_POLLFDS) - return FALSE; - selstate->fds[selstate->nfds].fd = fd; - selstate->fds[selstate->nfds].events = 0; - if (ssflags & SSF_READ) - selstate->fds[selstate->nfds].events |= POLLIN; - if (ssflags & SSF_WRITE) - selstate->fds[selstate->nfds].events |= POLLOUT; -#else -#ifndef _WIN32 /* On Windows FD_SETSIZE is a count, not a max value. */ - if (fd >= FD_SETSIZE) - return FALSE; -#endif - if (ssflags & SSF_READ) - FD_SET(fd, &selstate->rfds); - if (ssflags & SSF_WRITE) - FD_SET(fd, &selstate->wfds); - if (ssflags & SSF_EXCEPTION) - FD_SET(fd, &selstate->xfds); - if (selstate->max <= fd) - selstate->max = fd + 1; -#endif - selstate->nfds++; - return TRUE; -} - -static void -cm_remove_fd(struct select_state *selstate, int fd) -{ -#ifdef USE_POLL - int i; - - /* Find the FD in the array and move the last entry to its place. */ - assert(selstate->nfds > 0); - for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++); - assert(i < selstate->nfds); - selstate->fds[i] = selstate->fds[selstate->nfds - 1]; -#else - FD_CLR(fd, &selstate->rfds); - FD_CLR(fd, &selstate->wfds); - FD_CLR(fd, &selstate->xfds); - if (selstate->max == 1 + fd) { - while (selstate->max > 0 - && ! FD_ISSET(selstate->max-1, &selstate->rfds) - && ! FD_ISSET(selstate->max-1, &selstate->wfds) - && ! FD_ISSET(selstate->max-1, &selstate->xfds)) - selstate->max--; - } -#endif - selstate->nfds--; -} - -static void -cm_unset_write(struct select_state *selstate, int fd) -{ -#ifdef USE_POLL - int i; - - for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++); - assert(i < selstate->nfds); - selstate->fds[i].events &= ~POLLOUT; -#else - FD_CLR(fd, &selstate->wfds); -#endif -} - -static krb5_error_code -cm_select_or_poll(const struct select_state *in, time_ms endtime, - struct select_state *out, int *sret) -{ -#ifndef USE_POLL - struct timeval tv; -#endif - krb5_error_code retval; - time_ms curtime, interval; - - retval = get_curtime_ms(&curtime); - if (retval != 0) - return retval; - interval = (curtime < endtime) ? endtime - curtime : 0; - - /* We don't need a separate copy of the selstate for poll, but use one for - * consistency with how we use select. */ - *out = *in; - -#ifdef USE_POLL - *sret = poll(out->fds, out->nfds, interval); -#else - tv.tv_sec = interval / 1000; - tv.tv_usec = interval % 1000 * 1000; - *sret = select(out->max, &out->rfds, &out->wfds, &out->xfds, &tv); -#endif - - return (*sret < 0) ? SOCKET_ERRNO : 0; -} - -static unsigned int -cm_get_ssflags(struct select_state *selstate, int fd) -{ - unsigned int ssflags = 0; -#ifdef USE_POLL - int i; - - for (i = 0; i < selstate->nfds && selstate->fds[i].fd != fd; i++); - assert(i < selstate->nfds); - if (selstate->fds[i].revents & POLLIN) - ssflags |= SSF_READ; - if (selstate->fds[i].revents & POLLOUT) - ssflags |= SSF_WRITE; - if (selstate->fds[i].revents & POLLERR) - ssflags |= SSF_EXCEPTION; -#else - if (FD_ISSET(fd, &selstate->rfds)) - ssflags |= SSF_READ; - if (FD_ISSET(fd, &selstate->wfds)) - ssflags |= SSF_WRITE; - if (FD_ISSET(fd, &selstate->xfds)) - ssflags |= SSF_EXCEPTION; -#endif - return ssflags; -} - static int service_tcp_fd(krb5_context context, struct conn_state *conn, struct select_state *selstate, int ssflags); static int service_udp_fd(krb5_context context, struct conn_state *conn, @@ -600,7 +626,6 @@ start_connection(krb5_context context, struct conn_state *state, struct sendto_callback_info *callback_info) { int fd, e; - unsigned int ssflags; static const int one = 1; static const struct linger lopt = { 0, 0 }; @@ -676,15 +701,17 @@ start_connection(krb5_context context, struct conn_state *state, state->state = READING; } } - ssflags = SSF_READ | SSF_EXCEPTION; - if (state->state == CONNECTING || state->state == WRITING) - ssflags |= SSF_WRITE; - if (!cm_add_fd(selstate, state->fd, ssflags)) { + + if (!cm_add_fd(selstate, state->fd)) { (void) closesocket(state->fd); state->fd = INVALID_SOCKET; state->state = FAILED; return -1; } + if (state->state == CONNECTING || state->state == WRITING) + cm_write(selstate, state->fd); + else + cm_read(selstate, state->fd); return 0; } @@ -768,9 +795,8 @@ service_tcp_fd(krb5_context context, struct conn_state *conn, ssize_t nwritten, nread; SOCKET_WRITEV_TEMP tmp; - /* Check for a socket exception or readable data before we expect it. */ - if (ssflags & SSF_EXCEPTION || - ((ssflags & SSF_READ) && conn->state != READING)) + /* Check for a socket exception. */ + if (ssflags & SSF_EXCEPTION) goto kill_conn; switch (conn->state) { @@ -810,7 +836,7 @@ service_tcp_fd(krb5_context context, struct conn_state *conn, } if (conn->x.out.sg_count == 0) { /* Done writing, switch to reading. */ - cm_unset_write(selstate, conn->fd); + cm_read(selstate, conn->fd); conn->state = READING; conn->x.in.bufsizebytes_read = 0; conn->x.in.bufsize = 0;