diff options
-rw-r--r-- | src/socket_wrapper.c | 314 |
1 files changed, 122 insertions, 192 deletions
diff --git a/src/socket_wrapper.c b/src/socket_wrapper.c index 6c9ec51..a22f7dc 100644 --- a/src/socket_wrapper.c +++ b/src/socket_wrapper.c @@ -142,6 +142,10 @@ enum swrap_dbglvl_e { } while(0) #endif +#ifndef SAFE_FREE +#define SAFE_FREE(x) do { if ((x) != NULL) {free(x); (x)=NULL;} } while(0) +#endif + #ifndef discard_const #define discard_const(ptr) ((void *)((uintptr_t)(ptr))) #endif @@ -198,68 +202,6 @@ enum swrap_dbglvl_e { pthread_mutex_unlock(&sic->meta.mutex); \ } while(0) -#define DLIST_ADD(list, item) do { \ - if (!(list)) { \ - (item)->prev = NULL; \ - (item)->next = NULL; \ - (list) = (item); \ - } else { \ - (item)->prev = NULL; \ - (item)->next = (list); \ - (list)->prev = (item); \ - (list) = (item); \ - } \ -} while (0) - -#define SWRAP_DLIST_ADD(list, item) do { \ - SWRAP_LOCK(list); \ - DLIST_ADD(list, item); \ - SWRAP_UNLOCK(list); \ -} while (0) - -#define DLIST_REMOVE(list, item) do { \ - if ((list) == (item)) { \ - (list) = (item)->next; \ - if (list) { \ - (list)->prev = NULL; \ - } \ - } else { \ - if ((item)->prev) { \ - (item)->prev->next = (item)->next; \ - } \ - if ((item)->next) { \ - (item)->next->prev = (item)->prev; \ - } \ - } \ - (item)->prev = NULL; \ - (item)->next = NULL; \ -} while (0) - -#define SWRAP_DLIST_REMOVE(list,item) do { \ - SWRAP_LOCK(list); \ - DLIST_REMOVE(list, item); \ - SWRAP_UNLOCK(list); \ -} while (0) - -#define DLIST_ADD_AFTER(list, item, el) do { \ - if ((list) == NULL || (el) == NULL) { \ - DLIST_ADD(list, item); \ - } else { \ - (item)->prev = (el); \ - (item)->next = (el)->next; \ - (el)->next = (item); \ - if ((item)->next != NULL) { \ - (item)->next->prev = (item); \ - } \ - } \ -} while (0) - -#define SWRAP_DLIST_ADD_AFTER(list, item, el) do { \ - SWRAP_LOCK(list); \ - DLIST_ADD_AFTER(list, item, el); \ - SWRAP_UNLOCK(list); \ -} while (0) - #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID) #define swrapGetTimeOfDay(tval) gettimeofday(tval,NULL) #else @@ -288,7 +230,6 @@ enum swrap_dbglvl_e { #define SOCKET_MAX_SOCKETS 1024 - /* * Maximum number of socket_info structures that can * be used. Can be overriden by the environment variable @@ -316,17 +257,6 @@ struct swrap_address { } sa; }; -struct socket_info_fd { - struct socket_info_fd *prev, *next; - int fd; - - /* - * Points to corresponding index in array of - * socket_info structures - */ - int si_index; -}; - int first_free; struct socket_info @@ -376,7 +306,7 @@ static size_t max_sockets = 0; * numerical value gets changed. So its better to store it locally to each * process rather than including it within socket_info which will be shared. */ -static struct socket_info_fd *socket_fds; +static int *socket_fds_idx; /* The mutex for accessing the global libc.symbols */ static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; @@ -390,12 +320,6 @@ static pthread_mutex_t autobind_start_mutex = PTHREAD_MUTEX_INITIALIZER; static pthread_mutex_t sockets_mutex = PTHREAD_MUTEX_INITIALIZER; /* - * Global mutex to protect modification of the socket_fds linked - * list structure by different threads within a process. - */ -static pthread_mutex_t socket_fds_mutex = PTHREAD_MUTEX_INITIALIZER; - -/* * Global mutex to synchronize the query for first free index in array of * socket_info structures by different threads within a process. */ @@ -1371,6 +1295,30 @@ done: return max_sockets; } +static void socket_wrapper_init_fds_idx(void) +{ + int *tmp = NULL; + size_t i; + + if (socket_fds_idx != NULL) { + return; + } + + tmp = (int *)calloc(SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT, sizeof(int)); + if (tmp == NULL) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "Failed to allocate socket fds index array: %s", + strerror(errno)); + exit(-1); + } + + for (i = 0; i < SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT; i++) { + tmp[i] = -1; + } + + socket_fds_idx = tmp; +} + static void socket_wrapper_init_sockets(void) { size_t i; @@ -1382,6 +1330,8 @@ static void socket_wrapper_init_sockets(void) return; } + socket_wrapper_init_fds_idx(); + max_sockets = socket_wrapper_max_sockets(); sockets = (struct socket_info_container *)calloc(max_sockets, @@ -1438,6 +1388,42 @@ static unsigned int socket_wrapper_default_iface(void) return 1;/* 127.0.0.1 */ } +static void set_socket_info_index(int fd, int idx) +{ + socket_fds_idx[fd] = idx; + /* This builtin issues a full memory barrier. */ + __sync_synchronize(); +} + +static void reset_socket_info_index(int fd) +{ + set_socket_info_index(fd, -1); +} + +static int find_socket_info_index(int fd) +{ + if (fd < 0) { + return -1; + } + + if (socket_fds_idx == NULL) { + return -1; + } + + if (fd >= SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %u has been reached, " + "trying to add %d", + SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT, + fd); + return -1; + } + + /* This builtin issues a full memory barrier. */ + __sync_synchronize(); + return socket_fds_idx[fd]; +} + static int swrap_add_socket_info(struct socket_info *si_input) { struct socket_info *si = NULL; @@ -1473,25 +1459,23 @@ out: static int swrap_create_socket(struct socket_info *si, int fd) { - struct socket_info_fd *fi = NULL; int idx; - fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); - if (fi == NULL) { - errno = ENOMEM; + if (fd >= SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %u has been reached, " + "trying to add %d", + SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT, + fd); return -1; } idx = swrap_add_socket_info(si); if (idx == -1) { - free(fi); return -1; } - fi->fd = fd; - fi->si_index = idx; - - SWRAP_DLIST_ADD(socket_fds, fi); + set_socket_info_index(fd, idx); return idx; } @@ -1863,34 +1847,6 @@ static int convert_in_un_alloc(struct socket_info *si, const struct sockaddr *in return 0; } -static struct socket_info_fd *find_socket_info_fd(int fd) -{ - struct socket_info_fd *f; - - SWRAP_LOCK(socket_fds); - - for (f = socket_fds; f; f = f->next) { - if (f->fd == fd) { - break; - } - } - - SWRAP_UNLOCK(socket_fds); - - return f; -} - -static int find_socket_info_index(int fd) -{ - struct socket_info_fd *fi = find_socket_info_fd(fd); - - if (fi == NULL) { - return -1; - } - - return fi->si_index; -} - static struct socket_info *find_socket_info(int fd) { int idx = find_socket_info_index(fd); @@ -1996,29 +1952,25 @@ static bool check_addr_port_in_use(const struct sockaddr *sa, socklen_t len) static void swrap_remove_stale(int fd) { - struct socket_info_fd *fi = find_socket_info_fd(fd); struct socket_info *si; int si_index; - if (fi == NULL) { - return; - } - SWRAP_LOG(SWRAP_LOG_TRACE, "remove stale wrapper for %d", fd); - si_index = fi->si_index; + si_index = find_socket_info_index(fd); + if (si_index == -1) { + return; + } si = swrap_get_socket_info(si_index); + reset_socket_info_index(fd); + SWRAP_LOCK(first_free); SWRAP_LOCK_SI(si); - SWRAP_DLIST_REMOVE(socket_fds, fi); - swrap_dec_refcount(si); - free(fi); - if (swrap_get_refcount(si) > 0) { goto out; } @@ -5869,29 +5821,26 @@ ssize_t writev(int s, const struct iovec *vector, int count) static int swrap_close(int fd) { - struct socket_info_fd *fi = find_socket_info_fd(fd); struct socket_info *si = NULL; int si_index; int ret; - if (fi == NULL) { + si_index = find_socket_info_index(fd); + if (si_index == -1) { return libc_close(fd); } - si_index = fi->si_index; + reset_socket_info_index(fd); + si = swrap_get_socket_info(si_index); SWRAP_LOCK(first_free); SWRAP_LOCK_SI(si); - SWRAP_DLIST_REMOVE(socket_fds, fi); - ret = libc_close(fd); swrap_dec_refcount(si); - free(fi); - if (swrap_get_refcount(si) > 0) { /* there are still references left */ goto out; @@ -5932,25 +5881,18 @@ int close(int fd) static int swrap_dup(int fd) { struct socket_info *si; - struct socket_info_fd *src_fi, *fi; + int dup_fd, idx; - src_fi = find_socket_info_fd(fd); - if (src_fi == NULL) { + idx = find_socket_info_index(fd); + if (idx == -1) { return libc_dup(fd); } - si = swrap_get_socket_info(src_fi->si_index); - - fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); - if (fi == NULL) { - errno = ENOMEM; - return -1; - } + si = swrap_get_socket_info(idx); - fi->fd = libc_dup(fd); - if (fi->fd == -1) { + dup_fd = libc_dup(fd); + if (dup_fd == -1) { int saved_errno = errno; - free(fi); errno = saved_errno; return -1; } @@ -5958,15 +5900,15 @@ static int swrap_dup(int fd) SWRAP_LOCK_SI(si); swrap_inc_refcount(si); - fi->si_index = src_fi->si_index; SWRAP_UNLOCK_SI(si); /* Make sure we don't have an entry for the fd */ - swrap_remove_stale(fi->fd); + swrap_remove_stale(dup_fd); + + set_socket_info_index(dup_fd, idx); - SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); - return fi->fd; + return dup_fd; } int dup(int fd) @@ -5981,14 +5923,14 @@ int dup(int fd) static int swrap_dup2(int fd, int newfd) { struct socket_info *si; - struct socket_info_fd *src_fi, *fi; + int dup_fd, idx; - src_fi = find_socket_info_fd(fd); - if (src_fi == NULL) { + idx = find_socket_info_index(fd); + if (idx == -1) { return libc_dup2(fd, newfd); } - si = swrap_get_socket_info(src_fi->si_index); + si = swrap_get_socket_info(idx); if (fd == newfd) { /* @@ -6006,16 +5948,9 @@ static int swrap_dup2(int fd, int newfd) swrap_close(newfd); } - fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); - if (fi == NULL) { - errno = ENOMEM; - return -1; - } - - fi->fd = libc_dup2(fd, newfd); - if (fi->fd == -1) { + dup_fd = libc_dup2(fd, newfd); + if (dup_fd == -1) { int saved_errno = errno; - free(fi); errno = saved_errno; return -1; } @@ -6023,15 +5958,15 @@ static int swrap_dup2(int fd, int newfd) SWRAP_LOCK_SI(si); swrap_inc_refcount(si); - fi->si_index = src_fi->si_index; SWRAP_UNLOCK_SI(si); /* Make sure we don't have an entry for the fd */ - swrap_remove_stale(fi->fd); + swrap_remove_stale(dup_fd); - SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); - return fi->fd; + set_socket_info_index(dup_fd, idx); + + return dup_fd; } int dup2(int fd, int newfd) @@ -6045,29 +5980,21 @@ int dup2(int fd, int newfd) static int swrap_vfcntl(int fd, int cmd, va_list va) { - struct socket_info_fd *src_fi, *fi; struct socket_info *si; - int rc; + int rc, dup_fd, idx; - src_fi = find_socket_info_fd(fd); - if (src_fi == NULL) { + idx = find_socket_info_index(fd); + if (idx == -1) { return libc_vfcntl(fd, cmd, va); } - si = swrap_get_socket_info(src_fi->si_index); + si = swrap_get_socket_info(idx); switch (cmd) { case F_DUPFD: - fi = (struct socket_info_fd *)calloc(1, sizeof(struct socket_info_fd)); - if (fi == NULL) { - errno = ENOMEM; - return -1; - } - - fi->fd = libc_vfcntl(fd, cmd, va); - if (fi->fd == -1) { + dup_fd = libc_vfcntl(fd, cmd, va); + if (dup_fd == -1) { int saved_errno = errno; - free(fi); errno = saved_errno; return -1; } @@ -6075,16 +6002,15 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) SWRAP_LOCK_SI(si); swrap_inc_refcount(si); - fi->si_index = src_fi->si_index; SWRAP_UNLOCK_SI(si); /* Make sure we don't have an entry for the fd */ - swrap_remove_stale(fi->fd); + swrap_remove_stale(dup_fd); - SWRAP_DLIST_ADD_AFTER(socket_fds, fi, src_fi); + set_socket_info_index(dup_fd, idx); - rc = fi->fd; + rc = dup_fd; break; default: rc = libc_vfcntl(fd, cmd, va); @@ -6194,14 +6120,18 @@ void swrap_constructor(void) */ void swrap_destructor(void) { - struct socket_info_fd *s = socket_fds; + size_t i; - while (s != NULL) { - swrap_close(s->fd); - s = socket_fds; + if (socket_fds_idx != NULL) { + for (i = 0; i < SOCKET_WRAPPER_MAX_SOCKETS_DEFAULT; ++i) { + if (socket_fds_idx[i] != -1) { + swrap_close(i); + } + } + SAFE_FREE(socket_fds_idx); } - free(sockets); + SAFE_FREE(sockets); if (swrap.libc.handle != NULL) { dlclose(swrap.libc.handle); |