Tags: neo4j, neo4j-java-api, neo4j-traversal-api

Coding a type of random walk in Neo4j using the Traversal Framework


I'm currently working on a graph where nodes are connected by probabilistic edges. The weight on each edge gives the probability that the edge exists.

Here is an example graph to get you started:

(A)-[0.5]->(B)
(A)-[0.5]->(C)
(B)-[0.5]->(C)
(B)-[0.3]->(D)
(C)-[1.0]->(E)
(C)-[0.3]->(D)
(E)-[0.3]->(D)

I would like to use the Neo4j Traversal Framework to traverse this graph starting from (A) and return the number of nodes reached, where each edge along the way is only traversed with the probability given by its weight.

Important:

  • Each node that is reached can only be counted once: if (A) reaches (B) and (C), then (C) need not reach (B). On the other hand, if (A) fails to reach (B) but reaches (C), then (C) will attempt to reach (B).
  • Likewise, if (B) reaches (C), (C) will not try to reach (B) again.
  • This is a discrete time-step process: a node attempts to reach each neighboring node only once.
  • To test whether an edge exists (i.e. whether we traverse it), we can generate a random number in [0, 1) and check whether it is smaller than the edge weight.
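The rules above describe what is essentially the independent cascade model: every newly reached node gets one coin flip per outgoing edge, and an edge into an already-reached node is not retried. As a reference point outside Neo4j, here is a minimal plain-Java sketch of one random run, assuming the graph is held in an in-memory adjacency list. The class `CascadeSketch`, the `Edge` type and the `simulate`/`exampleGraph`/`addEdge` methods are hypothetical names for illustration, not Neo4j API.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

public class CascadeSketch {

    /** Hypothetical outgoing edge: target node name and probability that the edge exists. */
    static final class Edge {
        final String to;
        final double p;

        Edge(String to, double p) {
            this.to = to;
            this.p = p;
        }
    }

    static void addEdge(Map<String, List<Edge>> g, String from, String to, double p) {
        g.computeIfAbsent(from, k -> new ArrayList<>()).add(new Edge(to, p));
    }

    /** The example graph from the question as an in-memory adjacency list. */
    static Map<String, List<Edge>> exampleGraph() {
        Map<String, List<Edge>> g = new HashMap<>();
        addEdge(g, "A", "B", 0.5);
        addEdge(g, "A", "C", 0.5);
        addEdge(g, "B", "C", 0.5);
        addEdge(g, "B", "D", 0.3);
        addEdge(g, "C", "E", 1.0);
        addEdge(g, "C", "D", 0.3);
        addEdge(g, "E", "D", 0.3);
        return g;
    }

    /**
     * One random run. A node enters 'reached' (and the queue) the first time an
     * edge into it succeeds, so it is counted once and tries its own edges once.
     * Edges into nodes that are already reached are skipped, not retried.
     */
    static Set<String> simulate(Map<String, List<Edge>> g, Set<String> start, Random rnd) {
        Set<String> reached = new LinkedHashSet<>(start);
        Deque<String> queue = new ArrayDeque<>(start); // FIFO queue = breadth-first order
        while (!queue.isEmpty()) {
            String node = queue.poll();
            for (Edge e : g.getOrDefault(node, Collections.emptyList())) {
                if (reached.contains(e.to)) {
                    continue; // already counted: no second attempt
                }
                if (rnd.nextDouble() < e.p) { // nextDouble() is in [0, 1)
                    reached.add(e.to);
                    queue.add(e.to);
                }
            }
        }
        return reached;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed only to make the run repeatable
        Set<String> reached = simulate(exampleGraph(), Collections.singleton("A"), rnd);
        System.out.println(reached + " -> score " + reached.size());
    }
}
```

Averaging `simulate(...).size()` over many runs with different seeds gives an estimate of the expected number of reached nodes; the set plays the role that node uniqueness is meant to play in the traversal.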

I have already written part of the traversal description, as follows. (It can start from multiple nodes, but that is not necessary to solve the problem.)

import java.util.concurrent.ThreadLocalRandom;
import org.neo4j.graphdb.Direction;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.traversal.Evaluation;
import org.neo4j.graphdb.traversal.Evaluator;
import org.neo4j.graphdb.traversal.TraversalDescription;
import org.neo4j.graphdb.traversal.Uniqueness;

// db is the GraphDatabaseService; Rels.INFLUENCES is a RelationshipType enum defined elsewhere
TraversalDescription traversal = db.traversalDescription()
            .breadthFirst()
            .relationships( Rels.INFLUENCES, Direction.OUTGOING )
            .uniqueness( Uniqueness.NODE_PATH )
            // Note: uniqueness(...) sets the rule rather than adding one,
            // so this call replaces NODE_PATH above
            .uniqueness( Uniqueness.RELATIONSHIP_GLOBAL )
            .evaluator(new Evaluator() {

              @Override
              public Evaluation evaluate(Path path) {

                // Current node (end of the path)
                Node curNode = path.endNode();

                // A zero-length path is a start node: it has no previous
                // relationship, so include it and keep traversing
                if (path.length() == 0) {
                    return Evaluation.INCLUDE_AND_CONTINUE;
                }

                // Relationship used to reach the current node, and its weight
                Relationship curRel = path.lastRelationship();
                double wc = ((Number) curRel.getProperty("wc")).doubleValue();

                // Random number in [0, 1); reuses a per-thread generator
                // instead of creating a new Random on every call
                double rndNum = ThreadLocalRandom.current().nextDouble();

                // The edge "exists" if the random number is below its weight wc
                if (rndNum < wc) {
                    Node prevNode = curRel.getOtherNode(curNode);
                    System.out.println("(" + prevNode.getProperty("name") + ")-[" + wc + "]->("
                            + curNode.getProperty("name") + ") :" + rndNum);

                    // Keep node and keep traversing
                    return Evaluation.INCLUDE_AND_CONTINUE;
                }

                // Edge does not exist: exclude the node and stop this branch
                return Evaluation.EXCLUDE_AND_PRUNE;
              }
            });

I keep track of the number of nodes reached like so:

long score = 0;
for (Node currentNode : traversal.traverse( nodeList ).nodes())
{
    System.out.print(" <" + currentNode.getProperty("name") + "> ");
    score += 1;
}

The problem with this code is that, although NODE_PATH is set, the traversal still produces cycles, which I don't want, so the same node can be counted more than once.

Therefore, I would like to know:

  • Is there a way to avoid cycles and count exactly the number of nodes reached?
  • Ideally, is it possible (or better) to do the same thing with a PathExpander, and if so, how would I code that?

Thanks


Solution

  • This certainly isn't the best answer.

    Instead of iterating over nodes(), I iterate over the paths, add each path's endNode() to a set, and take the size of the set as the number of unique nodes.

    // Several paths can end at the same node; the set keeps each node once.
    // Node identity is used instead of the "name" property in case names repeat.
    Set<Node> reached = new HashSet<>();
    for (Path path : traversal.traverse(nodeList))
    {
        reached.add(path.endNode());
        System.out.println(path);
        System.out.println();
    }
    score = reached.size();
    

    Hopefully someone can suggest a better solution.

    I'm still surprised, though, that NODE_PATH did not prevent cycles from forming.
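The example graph has no directed cycle at all, so the repeated nodes are not really cycles: they are the same node appearing as the end node of several different paths. `Uniqueness.NODE_PATH` only forbids a node from appearing twice within one path; it does not stop two different paths from ending at the same node. This plain-Java sketch (hypothetical class and method names, no Neo4j involved) enumerates every path from (A) under that per-path rule, assuming every edge succeeds, and counts how often each node is an end node.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

public class PathUniquenessSketch {

    /** Adjacency list of the example graph; probabilities ignored (every edge succeeds). */
    static Map<String, List<String>> exampleGraph() {
        Map<String, List<String>> g = new HashMap<>();
        String[][] edges = {
            {"A", "B"}, {"A", "C"}, {"B", "C"}, {"B", "D"},
            {"C", "E"}, {"C", "D"}, {"E", "D"}
        };
        for (String[] e : edges) {
            g.computeIfAbsent(e[0], k -> new ArrayList<>()).add(e[1]);
        }
        return g;
    }

    /**
     * Enumerates every path from start in which no node repeats (NODE_PATH-style)
     * and counts how many of those paths end at each node.
     */
    static Map<String, Integer> endNodeCounts(Map<String, List<String>> g, String start) {
        Map<String, Integer> counts = new TreeMap<>(); // sorted keys for stable printing
        dfs(g, start, new HashSet<>(), counts);
        return counts;
    }

    private static void dfs(Map<String, List<String>> g, String node,
                            Set<String> onPath, Map<String, Integer> counts) {
        onPath.add(node);
        counts.merge(node, 1, Integer::sum); // one more path ends at this node
        for (String next : g.getOrDefault(node, Collections.emptyList())) {
            if (!onPath.contains(next)) { // only forbids repeats within the current path
                dfs(g, next, onPath, counts);
            }
        }
        onPath.remove(node);
    }

    public static void main(String[] args) {
        Map<String, Integer> counts = endNodeCounts(exampleGraph(), "A");
        int total = counts.values().stream().mapToInt(Integer::intValue).sum();
        System.out.println(counts); // {A=1, B=1, C=2, D=5, E=2}
        System.out.println(total + " paths, " + counts.size() + " distinct nodes"); // 11 paths, 5 distinct nodes
    }
}
```

So with per-path uniqueness, counting every returned path can give up to 11 for a graph with only 5 nodes, which is why collecting the end nodes in a set fixes the count. NODE_GLOBAL is not a drop-in replacement either: as far as I can tell from the Neo4j traversal implementation, the uniqueness check runs before the evaluator, so a node whose coin flip failed on its first incoming edge would already be marked as visited and would never get the second chance the rules in the question require.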