#include <stdio.h>
#include <stdlib.h>
typedef struct Parent {
int node;
int sum;
} Parent;
typedef struct DSU {
Parent* parent;
int* rank;
} DSU;
void create_dsu(DSU* dsu, int n, int* wts)
{
dsu->parent = malloc(sizeof(Parent) * n);
dsu->rank = malloc(sizeof(int) * n);
for (int i = 0; i < n; i++) {
dsu->parent[i].sum = wts[i];
dsu->parent[i].node = i;
dsu->rank[i] = 0;
}
}
int find_parent(DSU* dsu, int n)
{
if (n == dsu->parent[n].node) {
return n;
}
return dsu->parent[n].node = find_parent(dsu, n);
}
void union_by_rank(DSU* dsu, int u, int v)
{
int up = find_parent(dsu, u);
int vp = find_parent(dsu, v);
if (up == vp) {
return;
}
else if (dsu->rank[up] > dsu->rank[vp]) {
dsu->parent[vp].node = up;
dsu->parent[up].sum += dsu->parent[vp].sum;
}
else if (dsu->rank[vp] > dsu->rank[up]) {
dsu->parent[up].node = vp;
dsu->parent[vp].sum += dsu->parent[up].sum;
}
else {
dsu->parent[up].node = vp;
dsu->rank[vp]++;
dsu->parent[vp].sum += dsu->parent[up].sum;
}
}
int find_sum(DSU* dsu, int u)
{
int up = find_parent(dsu, u); // causes a segfault
//
// printf("%d\n", dsu->parent[3].sum); -> 17
// printf("%d\n", dsu->parent[0].sum); -> 11
return (dsu->parent[up].sum);
}
int main()
{
int arr[] = { 11, 13, 1, 3, 5 };
DSU dsu;
create_dsu(&dsu, 5, arr);
union_by_rank(&dsu, 1, 3);
union_by_rank(&dsu, 2, 3);
union_by_rank(&dsu, 0, 4);
//
printf("%d\n", find_sum(&dsu, 2));
}
Here I am trying to implement a Disjoint Set Union in C. In the program, there are 5 nodes which I connect using the union_by_rank function. The function findSum aims to find the sum of the sets with input one of the nodes. Example: nodes have the value 11, 13, 1, 3, 5 if nodes 1-2-3 are connected and nodes 0-4 are connected
then the sum for node 2 will be 13 + 1 + 3 = 17 (summation of node 1, 2 and 3 weights)
For some reason the find_parent int findSum function segfaults when it works just fine in the union function. As I can print the values in the findSum using printf
find_parent should have worked as it does in union function
Ok, so the logic for the find_parent was wrong. The find_parent should have been
int find_parent(DSU* dsu, int n)
{
if (n == dsu->parent[n].node) {
return n;
}
return dsu->parent[n].node = find_parent(dsu, dsu->parent[n].node);
}