Search code examples
apache-spark, spark-graphx

How can I print out the shortest path using Spark GraphX?


The following code runs well: it prints the shortest path length between two vertices. But how can I print the actual path, i.e. the individual edges (not only the length), between two vertices?

// Spark context for the example, running on a local master.
val conf = new SparkConf().setAppName("SimpleGraphX").setMaster("local")
val sc = new SparkContext(conf)

// Vertex attributes are (name, age); ids are assigned 1..6 in listing order.
val people = Seq(
  "Alice1" -> 28,
  "Bob2" -> 27,
  "Charlie3" -> 65,
  "David4" -> 42,
  "Ed5" -> 55,
  "Fran6" -> 50
)
val vertexArray: Array[(Long, (String, Int))] =
  people.zipWithIndex.map { case (person, idx) => (idx + 1L, person) }.toArray

// Directed (src, dst) pairs; every edge carries the same attribute value 1.
val edgeArray: Array[Edge[Int]] = Array(
  (2L, 1L), (2L, 4L),
  (3L, 2L), (3L, 6L),
  (4L, 1L),
  (5L, 2L), (5L, 3L), (5L, 6L)
).map { case (src, dst) => Edge(src, dst, 1) }

// Single-source shortest path lengths from `sourceId`, computed with GraphX Pregel.
import org.apache.spark.graphx.Graph

val sourceId: VertexId = 5L

// The snippet referenced `graph` without ever defining it; build it from the
// vertex and edge arrays declared above.
val graph = Graph(sc.parallelize(vertexArray), sc.parallelize(edgeArray))

// Distance 0 at the source, +infinity everywhere else.
val initialGraph = graph.mapVertices(
  (id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity
)

val sssp = initialGraph.pregel(Double.PositiveInfinity)(
  // Vertex program: keep the smaller of the stored and the received distance.
  (id, dist, newDist) => math.min(dist, newDist),
  // Send a message along an edge only when it improves the destination's distance.
  triplet => {
    if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
      Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
    } else {
      Iterator.empty
    }
  },
  // Merge concurrent messages by taking the minimum.
  (a, b) => math.min(a, b)
)

// Prints one (vertexId, distance) pair per line.
println(sssp.vertices.collect.mkString("\n"))

When I ran the code, it output the shortest path length from VertexId=5 to each of the other vertices, as follows:

(4,2.0)
(1,2.0)
(6,1.0)
(3,1.0)
(5,0.0)
(2,1.0)

For example, the result (4,2.0) means the shortest path length between vertex 5 and vertex 4 is 2. But I would like it to print out the detailed path, such as: 5->2->4.


Solution

  • You can use GraphFrames.

    import org.apache.spark.sql.types.{ArrayType, StringType, StructType, StructField, DoubleType}
    import org.graphframes._
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.functions.udf
    import org.graphframes.lib.Pregel
    import org.apache.spark.sql.functions.struct
    import scala.collection.mutable.WrappedArray
    import org.apache.spark.sql.Row
    
    
    // Reads a CSV file whose first line is a header row.
    def loadCsvWithHeader(path: String) =
      spark.read.format("csv").option("header", "true").load(path)

    var vertex_df = loadCsvWithHeader("/home/tm/Documents/vertex1.csv")
    val edge_df = loadCsvWithHeader("/home/tm/Documents/edge1.csv")
    spark.sparkContext.setCheckpointDir("/home/tm/checkpoints")

    // First graph: only used to derive per-vertex in/out degrees.
    var graph = GraphFrame(vertex_df, edge_df)

    val inDegrees = graph.inDegrees
    val outDegrees = graph.outDegrees

    // Left joins keep vertices that have no outgoing or no incoming edges (degree becomes null).
    val withOutDegree = vertex_df.join(outDegrees, outDegrees("id") === vertex_df("id"), "left")
    val withBothDegrees = withOutDegree.join(inDegrees, inDegrees("id") === vertex_df("id"), "left")
    vertex_df = withBothDegrees.select(
      vertex_df("id"),
      vertex_df("name"),
      outDegrees("outDegree"),
      inDegrees("inDegree")
    )

    // No incoming edges => "root"; otherwise no outgoing edges => "leaf"; else "child".
    val nodeTypeExpr =
      when(col("inDegree").isNull, "root").when(col("outDegree").isNull, "leaf").otherwise("child")
    vertex_df = vertex_df.withColumn("nodeType", nodeTypeExpr)

    // Rebuild the graph so its vertices carry the degree and nodeType columns.
    graph = GraphFrame(vertex_df, edge_df)

    val root_node = 5
    
    // Layout of the per-vertex Pregel state: (dist, name, path), all fields nullable.
    val vertColSchema = new StructType().add("dist", DoubleType, true).add("name", StringType, true).add("path", ArrayType(StringType), true)
    
    /**
     * Vertex update step: keeps the current state or adopts the aggregated message.
     *
     * @param vd  current vertex state, read positionally as (dist, name, path)
     * @param msg aggregated incoming message with the same layout, or null when none arrived
     * @return the (dist, name, path) tuple to store; the name is always taken from `vd`
     */
    def vertexProgram(vd: Row, msg: Row): (Double, String, WrappedArray[String]) = {
      // The message wins unless the stored distance is strictly smaller.
      val keepCurrent = msg == null || vd(0).asInstanceOf[Double] < msg(0).asInstanceOf[Double]
      val chosen = if (keepCurrent) vd else msg
      (
        chosen(0).asInstanceOf[Double],
        vd(1).asInstanceOf[String],
        chosen(2).asInstanceOf[WrappedArray[String]]
      )
    }

    val vertexProgramUdf = udf(vertexProgram _)
    
    
    /**
     * Builds the message sent along an edge from `src` to `dst`.
     *
     * @param src state of the edge's source vertex, read positionally as (dist, name, path)
     * @param dst state of the edge's destination vertex, same layout
     * @return (src dist + 1, src name, src path with the dst name appended) when going through
     *         `src` shortens the destination's distance by more than nothing; null otherwise
     */
    def sendMsgToDst(src: Row, dst: Row): (Double, String, WrappedArray[String]) = {
      val srcDist = src(0).asInstanceOf[Double]
      val dstDist = dst(0).asInstanceOf[Double]
      // Every edge counts as one hop.
      val improves = srcDist < (dstDist - 1)

      if (!improves) null
      else {
        val extendedPath = src(2).asInstanceOf[WrappedArray[String]] :+ dst(1).asInstanceOf[String]
        (srcDist + 1, src(1).asInstanceOf[String], extendedPath)
      }
    }

    val sendMsgToDstUdf = udf(sendMsgToDst _)
    
    
    /**
     * Reduces all messages addressed to one vertex to a single message.
     *
     * Picks the message with the smallest distance instead of blindly taking the first
     * collected one, so the result does not depend on collection order if messages ever
     * carry different distances. Ties resolve to the first such message, which matches
     * the previous behaviour when all distances are equal.
     *
     * @param agg collected messages, each read positionally as (dist, name, path);
     *            expected to be non-empty
     * @return the chosen message as a (dist, name, path) tuple
     */
    def aggMsgs(agg: WrappedArray[Row]): (Double, String, WrappedArray[String]) = {
      val best = agg.minBy(row => row(0).asInstanceOf[Double])
      (best(0).asInstanceOf[Double], best(1).asInstanceOf[String], best(2).asInstanceOf[WrappedArray[String]])
    }

    val aggMsgsUdf = udf(aggMsgs _)
    
    val dbl: Double = 0.0

    // Root starts at distance 0 with a path containing only itself.
    val rootState = struct(lit(dbl), col("id"), array(col("id")))
    // Every other vertex starts at +infinity with a placeholder path holding one empty string.
    val unreachedState = struct(lit(Double.PositiveInfinity), col("id"), array(lit("")))
    val initialVertCol = when(col("id") === (lit(root_node)), rootState).otherwise(unreachedState).cast(vertColSchema)

    // Runs at most 3 Pregel iterations over the "vertCol" state column.
    val result = graph.pregel.setMaxIter(3).withVertexColumn(
      colName = "vertCol",
      initialExpr = initialVertCol,
      updateAfterAggMsgsExpr = vertexProgramUdf(col("vertCol"), Pregel.msg)
    ).sendMsgToDst(
      sendMsgToDstUdf(col("src.vertCol"), Pregel.dst("vertCol"))
    ).aggMsgs(
      aggMsgsUdf(collect_list(Pregel.msg))
    ).run()
    
    scala> result.show()
    +---+-------+---------+--------+--------+-------------------+
    | id|   name|outDegree|inDegree|nodeType|            vertCol|
    +---+-------+---------+--------+--------+-------------------+
    |  1|  Alice|     null|       2|    leaf|[2.0, 1, [5, 2, 1]]|
    |  2|    Bob|        2|       2|   child|   [1.0, 2, [5, 2]]|
    |  3|Charlie|        2|       1|   child|   [1.0, 3, [5, 3]]|
    |  4|  David|        1|       1|   child|[2.0, 4, [5, 2, 4]]|
    |  5|     Ed|        3|    null|    root|      [0.0, 5, [5]]|
    |  6|   Fran|     null|       2|    leaf|   [1.0, 6, [5, 6]]|
    +---+-------+---------+--------+--------+-------------------+