Search code examples
Tags: list, multithreading, linked-list, linux-kernel

Understanding list_del_rcu() and list_for_each_entry() working


I am learning the kernel's RCU linked-list APIs. My program has 2 writers and 1 reader. Every second each writer wakes up and either adds a new node or, if the number of nodes is exactly divisible by 5, deletes the entire list. However, despite the spinlock protection, the kernel hits a BUG ("list_del corruption", reported as "invalid opcode ... SMP NOPTI" in the log below). I believe it is a case of accessing an already-deleted node again. But how? I am not releasing the lock before the list_del_rcu() call.

#include <linux/delay.h>
#include <linux/init.h>
#include <linux/kthread.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/rculist.h>
#include <linux/rcupdate.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <linux/string.h>

#define RD_THREAD               1
#define WR_THREAD               2
#define THREAD_NAME             16

static DEFINE_SPINLOCK(spl);
static struct task_struct       *task_read[RD_THREAD], *task_write[WR_THREAD];
typedef struct {
        struct list_head        node;
        int                     value;
} node_t;

static LIST_HEAD(head);
static atomic_t         total;

static int read_func(void *arg)
{
        node_t                  *curNode;

        while (!kthread_should_stop()) {
                ssleep(10);
                rcu_read_lock();
                list_for_each_entry_rcu(curNode, &head, node) {
                        printk(KERN_CONT "-> %d ", curNode->value);
                }
                printk(KERN_INFO "");
                rcu_read_unlock();
        }
        return 0;
}

static int write_func(void *arg)
{
        int                     counter = 0;
        node_t                  *newNode, *oldNode, *temp;
        int                     choice = 0;

        while (!kthread_should_stop()) {
                ssleep(1);
                switch (choice) {

                        case 0:
                                newNode = kmalloc(sizeof (*newNode), GFP_KERNEL);
                                if (!newNode) {
                                        printk(KERN_ERR "No memory left.\n");
                                        break;
                                }
                                newNode->value = counter++;
                                spin_lock(&spl);
                                list_add_rcu(&newNode->node, &head);
                                atomic_inc(&total);
                                spin_unlock(&spl);
                                break;

                        case 1:
                                if (counter % 5)
                                {
                                        break;
                                }
                                spin_lock(&spl);
                                list_for_each_entry_safe(oldNode, temp, &head, node) {

                                        /*
                                         * Print node content.
                                         */

                                        printk(KERN_INFO "Add %p %d %s %d\n",
                                               (void *)oldNode, oldNode->value,
                                               current->comm, atomic_read(&total));

                                        list_del_rcu(&oldNode->node);
                                        atomic_dec(&total);
                                        spin_unlock(&spl);

                                        synchronize_rcu();

                                        kfree(oldNode);
                                        spin_lock(&spl);

                                }
                                spin_unlock(&spl);
                                break;
                        default:
                                choice = -1;
                                break;
                }
                choice++;
        }

        spin_lock(&spl);
        list_for_each_entry_safe(oldNode, temp, &head, node) {
                list_del_rcu(&oldNode->node);
                spin_unlock(&spl);
                synchronize_rcu();
                kfree(oldNode);
                spin_lock(&spl);
        }
        spin_unlock(&spl);
        return 0;
}

/*
 * Stop every worker thread that was successfully created: readers first,
 * then writers. Slots that are NULL or hold an error pointer are skipped,
 * which lets start() call this on a partially initialised module. Also
 * registered as the module exit handler.
 */
static void end(void)
{
        unsigned int            idx;
        int                     rc;

        for (idx = 0; idx < RD_THREAD; ++idx) {
                if (IS_ERR_OR_NULL(task_read[idx]))
                        continue;
                rc = kthread_stop(task_read[idx]);
                printk(KERN_INFO "read_func_%u stopped with rc (%d)\n", idx, rc);
        }

        for (idx = 0; idx < WR_THREAD; ++idx) {
                if (IS_ERR_OR_NULL(task_write[idx]))
                        continue;
                rc = kthread_stop(task_write[idx]);
                printk(KERN_INFO "write_func_%u stopped with rc (%d)\n", idx, rc);
        }

        printk(KERN_INFO "Module unloaded.\n");
}

/*
 * Module entry point: create and start WR_THREAD writer threads followed by
 * RD_THREAD reader threads.
 *
 * If any thread cannot be created, the failure is logged, every thread
 * started so far is stopped through end(), and the error code is returned so
 * that module loading fails.
 *
 * Returns 0 on success or the negative errno from thread creation.
 */
static int __init start(void)
{
        unsigned int            counter;
        char                    thread_name[THREAD_NAME] = { 0 };
        long                    err;

        for (counter = 0; counter < WR_THREAD; ++counter) {
                /* counter is unsigned, hence %u. */
                snprintf(thread_name, sizeof(thread_name), "write_func_%u", counter);
                /*
                 * The name argument is a printf-style format, so pass the
                 * prepared name through "%s" rather than as the format itself.
                 * kthread_run() creates the thread and wakes it on success.
                 */
                task_write[counter] = kthread_run(write_func, NULL, "%s", thread_name);
                if (IS_ERR(task_write[counter])) {
                        err = PTR_ERR(task_write[counter]);
                        printk(KERN_ERR "Failed to create %s (%ld)\n", thread_name, err);
                        end();
                        return err;
                }
        }
        for (counter = 0; counter < RD_THREAD; ++counter) {
                snprintf(thread_name, sizeof(thread_name), "read_func_%u", counter);
                task_read[counter] = kthread_run(read_func, NULL, "%s", thread_name);
                if (IS_ERR(task_read[counter])) {
                        err = PTR_ERR(task_read[counter]);
                        printk(KERN_ERR "Failed to create %s (%ld)\n", thread_name, err);
                        end();
                        return err;
                }
        }
        printk(KERN_INFO "Module started.\n");
        return 0;
}

/* start() runs when the module is loaded, end() when it is unloaded. */
module_init(start);
module_exit(end);
/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("MP");

dmesg:

[ 6124.396574] Module started.
[ 6134.590443] -> 2 -> 2 -> 1 -> 1 -> 0 -> 0
[ 6134.590451]
[ 6138.750431] Add 00000000f8506a7d 4 write_func_1 10
[ 6138.750435] Add 0000000089a4bcde 4 write_func_0 9
[ 6138.756419] Add 0000000089a4bcde 4 write_func_1 8
[ 6138.756423] list_del corruption, ffff905962895a40->prev is LIST_POISON2 (dead000000000200)
[ 6138.756464] ------------[ cut here ]------------
[ 6138.756464] kernel BUG at lib/list_debug.c:50!
[ 6138.756484] invalid opcode: 0000 [#1] SMP NOPTI
[ 6138.756501] CPU: 3 PID: 19796 Comm: write_func_1 Kdump: loaded Tainted: P OE --------- - - 4.18.0-372.9.1.el8.x86_64 #1


Solution

  • When you iterate with list_for_each_entry_safe() over a list that may be modified by other threads, all iterations must be protected by a single critical section.

    But in the loop

    list_for_each_entry_safe(oldNode, temp, &head, node)
    {
        /*
         * Print node content.
         */
        printk(KERN_INFO "Add %p %d %s %d\n",
               (void *)oldNode, oldNode->value,
               current->comm, atomic_read(&total));
    
        list_del_rcu(&oldNode->node);
        atomic_dec(&total);
        spin_unlock(&spl); // <- release the spinlock
    
        synchronize_rcu();
    
        kfree(oldNode);
        spin_lock(&spl); // <- acquire the spinlock again, but the protection is already broken
    }
    

    you release the spin lock between the iterations.

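    The macro list_for_each_entry_safe() keeps a pointer to the next element as its iterator. So, if a concurrent writer removes that next element while the lock is released, the iterator becomes invalid. This is what the log shows: write_func_0 and write_func_1 both print node 0000000089a4bcde, and the second list_del_rcu() on it trips over the LIST_POISON2 left behind by the first.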

    You can avoid releasing the spinlock at every iteration by using the kfree_rcu() function instead of the pair synchronize_rcu() and kfree():

    /*
     * Nothing in this loop body waits for a grace period, so the spinlock
     * can stay held across the whole walk.
     */
    list_for_each_entry_safe(oldNode, temp, &head, node)
    {
        /*
         * Print node content.
         */
        printk(KERN_INFO "Add %p %d %s %d\n",
               (void *)oldNode, oldNode->value,
               current->comm, atomic_read(&total));
    
        list_del_rcu(&oldNode->node);
        atomic_dec(&total);
        // The node is freed with kfree() once a grace period has elapsed.
        kfree_rcu(oldNode, rcu);
    }
    

    The above code requires that you add an rcu field (a struct rcu_head) to your element's structure:

    /* List element extended with the rcu_head that kfree_rcu() needs. */
    typedef struct {
            struct list_head        node;   /* linkage into the list */
            int                     value;  /* payload */
            struct rcu_head         rcu;    /* named as the second argument of kfree_rcu() */
    } node_t;