Search code examples
cpointersdata-structuressegmentation-faultdisjoint-union

segfault with no reason in sight while implementing a dsu in c (probably a silly mistake)


#include <stdio.h>
#include <stdlib.h>

typedef struct Parent {
    int node;
    int sum;
} Parent;

typedef struct DSU {
    Parent* parent;
    int* rank;
} DSU;

void create_dsu(DSU* dsu, int n, int* wts)
{
    dsu->parent = malloc(sizeof(Parent) * n);
    dsu->rank = malloc(sizeof(int) * n);

    for (int i = 0; i < n; i++) {
        dsu->parent[i].sum = wts[i];
        dsu->parent[i].node = i;
        dsu->rank[i] = 0;
    }
}

int find_parent(DSU* dsu, int n)
{
    if (n == dsu->parent[n].node) {
        return n;
    }

    return dsu->parent[n].node = find_parent(dsu, n);
}

void union_by_rank(DSU* dsu, int u, int v)
{
    int up = find_parent(dsu, u);
    int vp = find_parent(dsu, v);
    if (up == vp) {
        return;
    }

    else if (dsu->rank[up] > dsu->rank[vp]) {
        dsu->parent[vp].node = up;
        dsu->parent[up].sum += dsu->parent[vp].sum;
    }

    else if (dsu->rank[vp] > dsu->rank[up]) {
        dsu->parent[up].node = vp;
        dsu->parent[vp].sum += dsu->parent[up].sum;
    }

    else {
        dsu->parent[up].node = vp;
        dsu->rank[vp]++;
        dsu->parent[vp].sum += dsu->parent[up].sum;
    }
}

int find_sum(DSU* dsu, int u)
{
    int up = find_parent(dsu, u); // causes a segfault
                                  //
    // printf("%d\n", dsu->parent[3].sum); -> 17
    // printf("%d\n", dsu->parent[0].sum); -> 11
    return (dsu->parent[up].sum);
}

int main()
{
    int arr[] = { 11, 13, 1, 3, 5 };
    DSU dsu;
    create_dsu(&dsu, 5, arr);

    union_by_rank(&dsu, 1, 3);
    union_by_rank(&dsu, 2, 3);
    union_by_rank(&dsu, 0, 4);
    //
    printf("%d\n", find_sum(&dsu, 2));
}

Here I am trying to implement a Disjoint Set Union in C. In the program, there are 5 nodes which I connect using the union_by_rank function. The function findSum aims to find the sum of the sets with input one of the nodes. Example: nodes have the value 11, 13, 1, 3, 5 if nodes 1-2-3 are connected and nodes 0-4 are connected

then the sum for node 2 will be 13 + 1 + 3 = 17 (summation of node 1, 2 and 3 weights)

For some reason the find_parent int findSum function segfaults when it works just fine in the union function. As I can print the values in the findSum using printf

find_parent should have worked as it does in union function


Solution

  • Ok, so the logic for the find_parent was wrong. The find_parent should have been

    int find_parent(DSU* dsu, int n)
    {
        if (n == dsu->parent[n].node) {
            return n;
        }
    
        return dsu->parent[n].node = find_parent(dsu, dsu->parent[n].node);
    }