Tags: c, multithreading, pthreads, thread-local-storage, futex

Is there a way to call library thread-local init/cleanup on thread creation/destruction?


This question is similar to "How to call a function on a thread's creation and exit?" but more specific. In another multi-process shared-memory project I used a combination of an __attribute__((constructor)) labeled library init routine, lazy initialisation for each thread, and robust futexes to make sure resources weren't leaked in the shared memory even if a sys admin chose to SIGKILL one of the processes using it. However, futexes within the APIs are way too heavyweight for my current project, and even the few instructions needed to deke around some lazy initialisation are something I'd rather avoid. The library APIs will literally be called several trillion times over a few hundred threads across several processes (each API is only a couple hundred instructions).

I am guessing the answer is no, but since I spent a couple hours looking for and not finding a definitive answer I thought I'd ask it here, then the next person looking for a simple answer will be able to find it more quickly.

My goal is pretty simple: perform some per-thread initialisation as threads are created in multiple processes asynchronously, and robustly perform some cleanup at some point when threads are destroyed asynchronously. Doesn't have to be immediately, it just has to happen eventually.

Some hypothetical ideas to engage critical thinking: a hypothetical pthread_atclone() called from an __attribute__((constructor)) labeled library init func would satisfy the first condition. An extension to futex() that adds a semop-like operation with a per-thread futex_adj value which, if non-zero in do_exit(), causes FUTEX_OWNER_DIED to be set for the futex "semaphore", allowing cleanup the next time the futex is touched, would satisfy the second.


Solution

  • After research and experimentation I've come up with what seems to be current "best practice" as far as I can tell. If anyone knows any better, please comment!

    For the first part, per-thread initialisation, I was not able to come up with any alternative to straightforward lazy initialisation. However, I did decide that it's slightly more efficient to move the branch to the caller, so that pipelining in the new stack frame isn't immediately confronted with an effectively unnecessary branch. So instead of this:

    __thread int tInf = 0;
    
    void
    threadDoSomething(void *data)
    {
    
        if (!tInf) {
    
            _threadInitInfo(&tInf);
        }
    
        /*l
         * do Something.
         */
    }
    

    This:

    __thread int tInf = 0;
    
    #define threadDoSomething(data) (((!tInf)?_threadInitInfo(&tInf):0), \
                                     _threadDoSomething((data)))
    void
    _threadDoSomething(void *data)
    {
    
        /*l
         * do Something.
         */
    }
    

    Comments on the (admittedly slight) usefulness of this welcome!
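
    To make the caller-side branch concrete, here is a minimal compilable sketch of the macro variant. The counter, the lock, the worker() function and countInitCalls() are hypothetical additions for the demo only, and _threadInitInfo() is assumed to return int so that it can appear inside the ?: expression:

```c
#include <pthread.h>

#define MAX_DEMO_THREADS 16             /* hypothetical cap for the demo */

static __thread int tInf = 0;           /* zero until this thread is initialised */

static int             initCalls = 0;   /* demo-only count of init runs */
static pthread_mutex_t initLock  = PTHREAD_MUTEX_INITIALIZER;

/* Stand-in init routine; returns int so it can sit inside ?: below. */
static int
_threadInitInfo(int *inf)
{
    pthread_mutex_lock(&initLock);
    ++initCalls;
    pthread_mutex_unlock(&initLock);

    *inf = 1;
    return 0;
}

/* Stand-in API body; there is no initialisation branch in here. */
static void
_threadDoSomething(void *data)
{
    (void)data;
}

/* The lazy-init test is evaluated in the caller's frame. */
#define threadDoSomething(data) ((void)((!tInf) ? _threadInitInfo(&tInf) : 0), \
                                 _threadDoSomething((data)))

static void *
worker(void *arg)
{
    int i;

    for (i = 0; i < 1000; ++i)
        threadDoSomething(arg);

    return NULL;
}

/* Run nThreads workers; return how many times the init routine ran. */
int
countInitCalls(int nThreads)
{
    pthread_t tids[MAX_DEMO_THREADS];
    int       i, started = 0;

    if (nThreads > MAX_DEMO_THREADS)
        nThreads = MAX_DEMO_THREADS;

    initCalls = 0;

    for (i = 0; i < nThreads; ++i)
        if (0 == pthread_create(&tids[started], NULL, worker, NULL))
            ++started;

    for (i = 0; i < started; ++i)
        pthread_join(tids[i], NULL);

    return initCalls;
}
```

    Each worker calls the API a thousand times, but the init routine should run only once per thread, so countInitCalls(4) is expected to return 4. Whether the caller-side branch is measurably faster is something only profiling on the target hardware can settle.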

    For the second part, robustly performing some cleanup when threads die no matter how asynchronously, I was not able to find any solution better than to have a reaping process epoll_wait() on a file descriptor for the read end of an open pipe, passed to it via an SCM_RIGHTS control message in a sendmsg() call on an abstract UNIX domain socket address. It sounds complex, but it's not that bad. Here's the client side:

    /*m
     * Client that registers a thread with a server who will do cleanup of a
     * shared interprocess object even if the thread dies asynchronously.
     */
    #include <sys/socket.h>     // socket(), sendmsg()
    #include <sys/syscall.h>    // SYS_gettid
    #include <sys/un.h>         // sockaddr_un
    #include <stdint.h>         // uint64_t
    #include <string.h>         // strcpy(), memmove()
    #include <fcntl.h>          // O_CLOEXEC
    #include <malloc.h>         // malloc()
    #include <stdlib.h>         // random()
    #include <unistd.h>         // syscall(), pipe(), write(), close()
    #include <pthread.h>        // pthread_create()
    #include <tsteplsrv.h>      // Our API.
    
    char iovBuf[] = "SP1";      // 3 char buf to send client type
    
    __thread pid_t cliTid = 0; // per-thread copy of self's Thread ID
    
    
    /*f
     * initClient() is called when we realise we need to lazily initialise
     * our thread based on cliTid being zero.
     */
    void *
    initClient(void *ptr)
    {
        struct sockaddr_un  svAddr;
        struct msghdr       msg;
        struct iovec        io;
        struct cmsghdr     *ctrMsg;
    
        uint64_t ltid;    // local 8-byte copy of the tid
        int      pfds[2], // two fds of our pipe
                 sfd;     // socket fd
    
        /*s
         * This union is necessary to ensure that the buffer is aligned such that
         * we can read cmsg_{len,level,type} from the cmsghdr without causing an
         * alignment fault (SIGBUS.)
         */
        union {
    
            struct cmsghdr hdr;
            char           buf[CMSG_SPACE(sizeof(int))];
    
        } ctrBuf;
    
        pfds[0] = pfds[1] = sfd = -1;
    
        /*l
         * Get our Thread ID.
         */
        ltid = (uint64_t)(cliTid = syscall(SYS_gettid));
    
        /*l
         * Set up an abstract unix domain socket address.
         */
        svAddr.sun_family  = AF_UNIX;
        svAddr.sun_path[0] = '\0';
    
        strcpy(&svAddr.sun_path[1], EPLS_SRV_ADDR);
    
        /*l
         * Set up a socket datagram send buffer.
         */
        io.iov_base = iovBuf;
        io.iov_len  = sizeof(iovBuf) - 1;   // 3 chars; the NUL is not sent
    
        msg.msg_iov        = &io;
        msg.msg_iovlen     = 1;
        msg.msg_control    = ctrBuf.buf;
        msg.msg_controllen = sizeof(ctrBuf);
        msg.msg_name       = (struct sockaddr *)&svAddr;
        msg.msg_namelen    =   (&svAddr.sun_path[0] - (char *)&svAddr)
                             + 1
                             + sizeof(EPLS_SRV_ADDR);
    
        /*l
         * Set up the control message header to indicate we are sharing a file
         * descriptor.
         */
        ctrMsg = CMSG_FIRSTHDR(&msg);
    
        ctrMsg->cmsg_len   = CMSG_LEN(sizeof(int));
        ctrMsg->cmsg_level = SOL_SOCKET;
        ctrMsg->cmsg_type  = SCM_RIGHTS;
    
        /*l
         * Create file descriptors with pipe().
         */
        if (-1 == pipe(pfds)) {
    
            printErrMsg("TID: %d pipe() failed", cliTid);
    
        } else {
    
            /*l
             * Pack the read end of the pipe into the control message, then
             * write our tid to the pipe for the server to read.
             */
            memmove(CMSG_DATA(ctrMsg), &pfds[0], sizeof(int));
    
            if (-1 == write(pfds[1], &ltid, sizeof(uint64_t))) {
    
                printErrMsg("TID: %d write() failed", cliTid);
    
            } else if (-1 == (sfd = socket(AF_UNIX, SOCK_DGRAM, 0))) {
    
                printErrMsg("TID: %d socket() failed", cliTid);
    
            } else if (-1 == sendmsg(sfd, &msg, 0)) {
    
                printErrMsg("TID: %d sendmsg() failed", cliTid);
    
            } else {
    
                printVerbMsg("TID: %d sent read fd %d to server kept write fd %d",
                             cliTid,
                             pfds[0],
                             pfds[1]);
    
                /*l
                 * Close the read end of the pipe, the server has it now.
                 */
                close(pfds[0]);
    
                pfds[0] = -1;
            }
        }
    
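        /*l
         * Closing the write end is what the reaper sees as EPOLLHUP.  This
         * client closes it on the way out of the thread function; a
         * longer-lived thread would presumably keep pfds[1] open until it
         * is done with the shared object.
         */
        if (-1 != pfds[1]) close(pfds[1]);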
        if (-1 != pfds[0]) close(pfds[0]);
        if (-1 != sfd) close(sfd);
    
        return (void *)0;
    }
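
    The descriptor hand-off can also be tried inside a single process. The following is a minimal sketch under stated assumptions: a socketpair() stands in for the abstract datagram socket, and sendFd(), recvFd(), roundTripTid() and fdCtrBuf_t are hypothetical names that are not part of the code above:

```c
#define _GNU_SOURCE
#include <stdint.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <unistd.h>

/* Aligned control buffer big enough for one file descriptor. */
typedef union {
    struct cmsghdr hdr;
    char           buf[CMSG_SPACE(sizeof(int))];
} fdCtrBuf_t;

/* Send fd plus a 3-byte client tag over sock; returns 0 or -1. */
static int
sendFd(int sock, int fd)
{
    char            tag[] = "SP1";
    struct iovec    io    = { .iov_base = tag, .iov_len = 3 };
    struct msghdr   msg   = { 0 };
    struct cmsghdr *ctrMsg;
    fdCtrBuf_t      ctrBuf;

    memset(&ctrBuf, 0, sizeof(ctrBuf));

    msg.msg_iov        = &io;
    msg.msg_iovlen     = 1;
    msg.msg_control    = ctrBuf.buf;
    msg.msg_controllen = sizeof(ctrBuf);

    ctrMsg             = CMSG_FIRSTHDR(&msg);
    ctrMsg->cmsg_len   = CMSG_LEN(sizeof(int));
    ctrMsg->cmsg_level = SOL_SOCKET;
    ctrMsg->cmsg_type  = SCM_RIGHTS;

    memcpy(CMSG_DATA(ctrMsg), &fd, sizeof(fd));

    return (-1 == sendmsg(sock, &msg, 0)) ? -1 : 0;
}

/* Receive one fd from sock; returns the new descriptor or -1. */
static int
recvFd(int sock)
{
    char            tag[3];
    struct iovec    io  = { .iov_base = tag, .iov_len = sizeof(tag) };
    struct msghdr   msg = { 0 };
    struct cmsghdr *ctrMsg;
    fdCtrBuf_t      ctrBuf;
    int             fd  = -1;

    msg.msg_iov        = &io;
    msg.msg_iovlen     = 1;
    msg.msg_control    = ctrBuf.buf;
    msg.msg_controllen = sizeof(ctrBuf);

    if (-1 == recvmsg(sock, &msg, 0))
        return -1;

    ctrMsg = CMSG_FIRSTHDR(&msg);

    if (   NULL == ctrMsg
        || CMSG_LEN(sizeof(int)) != ctrMsg->cmsg_len
        || SOL_SOCKET != ctrMsg->cmsg_level
        || SCM_RIGHTS != ctrMsg->cmsg_type)
        return -1;

    memcpy(&fd, CMSG_DATA(ctrMsg), sizeof(fd));
    return fd;
}

/*
 * Write tid into a pipe, pass the pipe's read end across a socketpair, and
 * read the tid back through the received descriptor.  Returns the value
 * read, or 0 on any failure.
 */
uint64_t
roundTripTid(uint64_t tid)
{
    uint64_t out = 0;
    int      sv[2] = { -1, -1 }, pfds[2] = { -1, -1 }, rfd = -1, i;

    if (   0 == socketpair(AF_UNIX, SOCK_DGRAM, 0, sv)
        && 0 == pipe(pfds)
        && (ssize_t)sizeof(tid) == write(pfds[1], &tid, sizeof(tid))
        && 0 == sendFd(sv[0], pfds[0]))
        rfd = recvFd(sv[1]);

    if (-1 != rfd && (ssize_t)sizeof(out) != read(rfd, &out, sizeof(out)))
        out = 0;

    for (i = 0; i < 2; ++i) {
        if (-1 != sv[i])   close(sv[i]);
        if (-1 != pfds[i]) close(pfds[i]);
    }
    if (-1 != rfd) close(rfd);

    return out;
}
```

    The descriptor returned by recvFd() is a new fd number referring to the same pipe, which is why the receiver can read the tid that the sender wrote before the hand-off.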
    

    And the reaper's code:

    /*m
     * Abstract datagram socket listening for FD's from clients.
     */
    
    #include <sys/socket.h> // socket(), bind(), recvmsg()
    #include <sys/epoll.h>  // epoll_{create1,ctl,wait}()
    #include <sys/un.h>     // sockaddr_un
    #include <stdint.h>     // uint64_t
    #include <stdlib.h>     // malloc(), free()
    #include <string.h>     // strcpy(), memmove()
    #include <unistd.h>     // read(), close()
    #include <tsteplsrv.h>  // Our API.
    
    /*s
     * socket datagram structs for receiving structured messages used to transfer
     * fds from our clients.
     */
    struct msghdr   msg  = { 0 };
    struct iovec    io   = { 0 };
    
    char iovBuf[EPLS_MSG_LEN];   // 3 char buf to receive client type
    
    /*s
     * This union is necessary to ensure that the buffer is aligned such that
     * we can read cmsg_{len,level,type} from the cmsghdr without causing an
     * alignment fault (SIGBUS.)
     */
    union {
    
        struct cmsghdr hdr;
        char           buf[CMSG_SPACE(sizeof(int))];
    
    } ctrBuf;
    
    typedef struct _tidFd_t {
    
        struct _tidFd_t *next;
    
        pid_t tid;
        int   fd;
    
    } tidFd_t;
    
    tidFd_t *tidFdLst = (tidFd_t *)0;
    
    /*f
     * Perform some handshaking with a new client and add the file descriptor
     * it shared with us to the epoll set.
     */
    static void
    welcomeClient(int efd, int cfd)
    {
        uint64_t     tid;
        tidFd_t *tfd;
    
        struct epoll_event epEv;
    
        tfd = (tidFd_t *)-1;
    
        /*l
         * The fd is a pipe and should be readable, and should contain the
         * tid of the client.
         */
        if (-1 != read(cfd, &tid, sizeof(tid)) && (tfd = malloc(sizeof(*tfd)))) {
    
            tfd->fd   = cfd;
            tfd->tid  = (pid_t)tid;
            tfd->next = tidFdLst;
    
            /*l
             * Single threaded process, no race condition here.
             */
            tidFdLst = tfd;
    
            /*l
             * Add the fd to the epoll() set so that we will be woken up with
             * an error if the thread dies.
             */
            epEv.events  = EPOLLIN;
            epEv.data.fd = cfd;
    
            if (-1 == epoll_ctl(efd, EPOLL_CTL_ADD, cfd, &epEv)) {
    
                printErrMsg("TID: %ld Could not register fd %d with epoll set",
                            tid,
                            cfd);
    
            } else {
    
                printVerbMsg("TID: %ld Registered fd %d with epoll set", tid, cfd);
            }
    
        /*l
         * Couldn't allocate memory for the new client.
         */
        } else if (!tfd) {
    
            printErrMsg("Could not allocate memory for new client");
    
        /*l
         * Could not read from the client's pipe file descriptor.
         */
        } else {
    
            printErrMsg("Could not read from client file descriptor");
        }
    }
    
    
    /*f
     * Handle an epoll event on a file descriptor shared by a client: on
     * EPOLLHUP the client's end of the pipe is gone, so clean up after it.
     */
    static void
    processClientEvent(int efd, struct epoll_event *epEv)
    {
        tidFd_t *tfd, **bLnk;
    
        /*l
         * Walk the list of per-tid fd structs.
         */
        for (bLnk = &tidFdLst; (tfd = *bLnk); bLnk = &tfd->next)
    
            if (tfd->fd == epEv->data.fd)
    
                break;
    
        if (!tfd) {
    
            printErrMsg("client file descriptor %d not found on the tfd list!",
                        epEv->data.fd);
    
    
        /*l
         * If we received an EPOLLHUP on the fd, cleanup.
         */
        } else if (epEv->events & EPOLLHUP) {
    
            /*l
             * Try to remove the tid's pipe fd from the epoll set.
             */
            if (-1 == epoll_ctl(efd, EPOLL_CTL_DEL, epEv->data.fd, epEv)) {
    
                printErrMsg("couldn't delete epoll for tid %d", tfd->tid);
    
            /*l
             * Do tid cleanup here.
             */
            } else {
    
                printVerbMsg("TID: %d closing fd: %d", tfd->tid, epEv->data.fd);
    
                close(epEv->data.fd);
    
                /*l
                 * Remove the per-tid struct from the list and free it.
                 */
                *bLnk = tfd->next;
                free(tfd);
            }
    
        } else {
    
            printVerbMsg("TID: %d Received unexpected epoll event %d",
                          tfd->tid,
                          epEv->events);
        }
    }
    
    
    /*f
     * Create and listen on a datagram socket for pipe file descriptors
     * from clients.
     */
    int
    main(int argc, char *argv[])
    {
        struct sockaddr_un  svAddr;
        struct cmsghdr     *ctrMsg;
        struct epoll_event *epEv,
                            epEvs[EPLS_MAX_EPEVS];
    
    
        int        sfd, efd, cfd, nfds;
    
        sfd = efd = -1;
    
        /*l
         * Set up an abstract unix domain socket address.
         */
        svAddr.sun_family  = AF_UNIX;
        svAddr.sun_path[0] = '\0';
    
        strcpy(&svAddr.sun_path[1], EPLS_SRV_ADDR);
    
        /*l
         * Set up a socket datagram receive buffer.
         */
        io.iov_base = iovBuf;               // 3-char buffer to ID client type
        io.iov_len  = sizeof(iovBuf);
    
        msg.msg_name       = (char *)0;     // No need for the client addr
        msg.msg_namelen    = 0;
        msg.msg_iov        = &io;           // single IO vector in the S/G array
        msg.msg_iovlen     = 1;
        msg.msg_control    = ctrBuf.buf;    // Control message buffer
        msg.msg_controllen = sizeof(ctrBuf);
    
        /*l
         * Set up an epoll event.
         */
        epEv         = &epEvs[0];
        epEv->events = EPOLLIN;
    
        /*l
         * Create a socket to receive datagrams on and register the socket
         * with our epoll event.
         */
        if (-1 == (epEv->data.fd = sfd = socket(AF_UNIX, SOCK_DGRAM, 0))) {
    
            printErrMsg("socket creation failed");
    
        /*l
         * Bind to the abstract address.  The pointer math is to portably
         * handle weird structure packing _just_in_case_.
         */
        } else if (-1 == bind(sfd,
                             (struct sockaddr *)&svAddr,
                               (&svAddr.sun_path[0] - (char *)&svAddr)
                             + 1
                             + sizeof(EPLS_SRV_ADDR))) {
    
            printErrMsg("could not bind address: %s", &svAddr.sun_path[1]);
    
        /*l
         * Create an epoll interface. Set CLOEXEC for tidiness in case a thread 
         * in the server fork()s and exec()s.
         */
        } else if (-1 == (efd = epoll_create1(EPOLL_CLOEXEC))) {
    
            printErrMsg("could not create epoll instance");
    
        /*l
         * Add our socket fd to the epoll instance.
         */
        } else if (-1 == epoll_ctl(efd, EPOLL_CTL_ADD, sfd, epEv)) {
    
            printErrMsg("could not add socket to epoll instance");
    
        /*l
         * Loop receiving events on our epoll instance.
         */
        } else {
    
            printVerbMsg("server listening on abstract address: %s",
                         &svAddr.sun_path[1]);
    
            /*l
             * Loop forever listening for events on the fds we are interested
             * in.
             */
            while (-1 != (nfds = epoll_wait(efd,  epEvs, EPLS_MAX_EPEVS, -1))) {
    
                /*l
                 * For each fd with an event, figure out what's up!
                 */
                do {
    
                    /*l
                     * Transform nfds from a count to an index.
                     */
                    --nfds;
    
                    /*l
                     * If the fd with an event is the listening socket, a
                     * client is trying to send us its pipe file descriptor.
                     */
                    if (sfd == epEvs[nfds].data.fd) {
    
                        if (EPOLLIN != epEvs[nfds].events) {
    
                            printErrMsg("unexpected condition on socket: %d",
                                        epEvs[nfds].events);
    
                            nfds = -1;
                            break;
                        }
    
                        /*l
                         * Reset the sizes of the receive buffers to their
                         * actual value; on return they will be set to the
                         * read value.
                         */
                        io.iov_len         = sizeof(iovBuf);
                        msg.msg_controllen = sizeof(ctrBuf);
    
                        /*l
                         * Receive the waiting message.
                         */
                        if (-1 == recvmsg(sfd, &msg, MSG_CMSG_CLOEXEC)) {
    
                            printVerbMsg("failed datagram read on socket");
    
                        /*l
                         * Verify that the message's control buffer contains
                         * a file descriptor.
                         */
                        } else if (   NULL != (ctrMsg = CMSG_FIRSTHDR(&msg))
                                   && CMSG_LEN(sizeof(int)) == ctrMsg->cmsg_len
                                   && SOL_SOCKET == ctrMsg->cmsg_level
                                   && SCM_RIGHTS == ctrMsg->cmsg_type) {
    
                            /*l
                             * Unpack the file descriptor.
                             */
                            memmove(&cfd, CMSG_DATA(ctrMsg), sizeof(cfd));
    
                            printVerbMsg("Received fd %d from client type %c%c%c",
                                         cfd,
                                         ((char *)msg.msg_iov->iov_base)[0],
                                         ((char *)msg.msg_iov->iov_base)[1],
                                         ((char *)msg.msg_iov->iov_base)[2]);
    
                            /*l
                             * Process the incoming file descriptor and add
                             * it to the epoll() list.
                             */
                            welcomeClient(efd, cfd);
    
                        /*l
                         * Note but ignore incorrectly formed datagrams.
                         */
                        } else {
    
                            printVerbMsg("could not extract file descriptor "
                                         "from client's datagram");
                        }
    
                    /*l
                     * The epoll() event is on one of the file descriptors
                     * shared with a client, process it.
                     */
                    } else {
    
                        processClientEvent(efd, &epEvs[nfds]);
                    }
    
                } while (nfds);
    
                /*l
                 * If something happened to our socket break the epoll_wait()
                 * loop.
                 */
                if (nfds)
    
                    break;
            }
        }
    
        /*l
         * An error occurred, cleanup.
         */
        if (-1 != efd)
    
            close(efd);
    
        if (-1 != sfd)
    
            close(sfd);
    
        return -1;
    }
    

    At first I tried using eventfd() rather than pipe() but eventfd file descriptors represent objects not connections, so closing the fd in the client code did not produce an EPOLLHUP in the reaper. If anyone knows of a better alternative to pipe() for this, let me know!
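
    The difference between the two descriptor types can be reproduced in one process. This is a rough sketch, assuming that a dup()ed copy of the eventfd behaves like the copy delivered to the reaper via SCM_RIGHTS; hupAfterClose(), pipeCloseRaisesHup() and eventfdCloseRaisesHup() are hypothetical names, and the 200 ms timeout is arbitrary:

```c
#define _GNU_SOURCE
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <unistd.h>

/*
 * Watch watchFd with epoll, close closeFd, then wait up to 200 ms.
 * Returns 1 if EPOLLHUP was reported, 0 if not, -1 on a setup error.
 * closeFd is always closed on return; watchFd is left open.
 */
static int
hupAfterClose(int watchFd, int closeFd)
{
    struct epoll_event ev = { 0 }, out = { 0 };
    int                efd, n, result = -1;

    if (-1 != (efd = epoll_create1(EPOLL_CLOEXEC))) {

        ev.events  = EPOLLIN;
        ev.data.fd = watchFd;

        if (-1 != epoll_ctl(efd, EPOLL_CTL_ADD, watchFd, &ev)) {

            close(closeFd);
            closeFd = -1;

            n = epoll_wait(efd, &out, 1, 200);

            if (-1 != n)
                result = (1 == n && (out.events & EPOLLHUP)) ? 1 : 0;
        }

        close(efd);
    }

    if (-1 != closeFd)
        close(closeFd);

    return result;
}

/* pipe(): the watcher of the read end sees the last writer go away. */
int
pipeCloseRaisesHup(void)
{
    int pfds[2], result;

    if (-1 == pipe(pfds))
        return -1;

    result = hupAfterClose(pfds[0], pfds[1]);
    close(pfds[0]);

    return result;
}

/* eventfd(): a dup()ed copy stands in for the copy sent to the reaper. */
int
eventfdCloseRaisesHup(void)
{
    int cliFd, srvFd, result;

    if (-1 == (cliFd = eventfd(0, EFD_CLOEXEC)))
        return -1;

    if (-1 == (srvFd = dup(cliFd))) {
        close(cliFd);
        return -1;
    }

    result = hupAfterClose(srvFd, cliFd);
    close(srvFd);

    return result;
}
```

    On Linux the pipe variant should report EPOLLHUP as soon as the write end is closed, while the eventfd variant should simply time out, because the remaining copy still refers to the same live counter object.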

    For completeness, here are the #defines used to construct the abstract address:

    /*d
     * server abstract address.
     */
    #define EPLS_SRV_NAM    "_abssSrv"
    #define EPLS_SRV_VER    "0.0.1"
    #define EPLS_SRV_ADDR   EPLS_SRV_NAM "." EPLS_SRV_VER
    #define EPLS_MSG_LEN    3
    #define EPLS_MAX_EPEVS  32
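
    The address length handed to bind() and sendmsg() can equally be written with offsetof() instead of pointer subtraction. A minimal sketch, assuming the same convention as the code above (a leading NUL, the name, and the name's trailing NUL all count toward the length); abstractAddr() and abstractRoundTrip() are hypothetical helpers:

```c
#define _GNU_SOURCE
#include <stddef.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

/*
 * Fill addr with the abstract address "\0" + name + "\0" and return the
 * length to hand to bind()/sendto(), or 0 if name does not fit.  Keeping
 * the trailing NUL mirrors the "+ 1 + sizeof(EPLS_SRV_ADDR)" arithmetic.
 */
socklen_t
abstractAddr(struct sockaddr_un *addr, const char *name)
{
    size_t len = strlen(name);

    if (len + 2 > sizeof(addr->sun_path))
        return 0;

    memset(addr, 0, sizeof(*addr));
    addr->sun_family = AF_UNIX;
    memcpy(&addr->sun_path[1], name, len + 1);

    return (socklen_t)(offsetof(struct sockaddr_un, sun_path) + 1 + len + 1);
}

/*
 * Bind one datagram socket to the abstract name, send it a 3-byte tag from
 * a second socket, and return the number of bytes received (-1 on error).
 */
int
abstractRoundTrip(const char *name)
{
    struct sockaddr_un addr;
    socklen_t          alen = abstractAddr(&addr, name);
    char               buf[3];
    int                srv = -1, cli = -1, got = -1;

    if (   0 != alen
        && -1 != (srv = socket(AF_UNIX, SOCK_DGRAM, 0))
        && -1 != (cli = socket(AF_UNIX, SOCK_DGRAM, 0))
        && -1 != bind(srv, (struct sockaddr *)&addr, alen)
        && 3 == sendto(cli, "SP1", 3, 0, (struct sockaddr *)&addr, alen))
        got = (int)recv(srv, buf, sizeof(buf), 0);

    if (-1 != cli) close(cli);
    if (-1 != srv) close(srv);

    return got;
}
```

    Because every byte counted in the length is part of an abstract name, both sides have to compute the length the same way, which is one reason to keep the arithmetic in a single helper.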
    

    That's it, hope this is useful for someone.