python matplotlib pygraphviz

How to visualize a Bayesian network model constructed with pomegranate


I want to visualize a Bayesian network that I built with pomegranate using the following code.

import math
from pomegranate import *
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd

Does anyone have an idea of how I can do this using matplotlib or pygraphviz?

df = pd.DataFrame({'A':[0,0,0,1,0], 'B':[0,0,1,0,0], 'C':[1,1,0,0,1], 'D':[0,1,0,1,1]})
print(df)
df.head()

model = BayesianNetwork.from_samples(df.to_numpy(), state_names=df.columns.values, algorithm='exact')

print(model)

Solution

  • I don't know how to do this with pomegranate, but if I may suggest an alternative, here is how to do it with pyAgrum:

    import pyAgrum as gum
    import pandas as pd
    import pyAgrum.lib.notebook as gnb 
    
    df = pd.DataFrame({'A':[0,0,0,1,0], 'B':[0,0,1,0,0], 'C':[1,1,0,0,1], 'D':[0,1,0,1,1]})
    gum.BNLearner(df).useAprioriSmoothing(1e-5).useScoreLog2Likelihood().learnBN()
    

    which, in a Jupyter notebook, renders the learned network as a graph.

    For more information on the learned BN,

    bn = gum.BNLearner(df).useAprioriSmoothing(1e-5).useScoreLog2Likelihood().learnBN()
    gnb.sideBySide(bn,gnb.getInference(bn))
    gnb.sideBySide(*[bn.cpt(i) for i in bn.nodes()])
    

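
  • For pomegranate itself, one possible approach is to extract the edges from the learned model and draw them with networkx and matplotlib. The following is only a minimal sketch. It assumes the model exposes its structure as one tuple of parent indices per node, which is how pre-1.0 versions of pomegranate report `model.structure` as far as I know. The helper names `parents_to_edges` and `draw_network` are made up for this illustration, and the example structure is hypothetical, not the result of the `from_samples` call in the question.

```python
def parents_to_edges(structure, names):
    """Convert a per-node tuple of parent indices into (parent, child) name pairs."""
    edges = []
    for child_idx, parent_idxs in enumerate(structure):
        for parent_idx in parent_idxs:
            edges.append((names[parent_idx], names[child_idx]))
    return edges


def draw_network(names, edges, filename=None):
    """Draw the directed graph; save it to filename if given, else show it."""
    # Imported lazily so the edge helper above works without plotting libraries.
    import matplotlib.pyplot as plt
    import networkx as nx

    graph = nx.DiGraph()
    graph.add_nodes_from(names)   # keeps isolated variables visible
    graph.add_edges_from(edges)
    pos = nx.spring_layout(graph, seed=0)  # fixed seed for a reproducible layout
    nx.draw_networkx(graph, pos, node_color="lightblue", node_size=1500, arrows=True)
    plt.axis("off")
    if filename:
        plt.savefig(filename)
    else:
        plt.show()


# Hypothetical structure for illustration only: node i lists the indices of its parents.
example_structure = ((), (0,), (0, 1), (2,))
example_names = ["A", "B", "C", "D"]
example_edges = parents_to_edges(example_structure, example_names)
print(example_edges)
```

    With a learned model, and under the assumption stated above, the call would look like `draw_network(names, parents_to_edges(model.structure, names))`, where `names` is the list of column names, for example `list(df.columns)`.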