Tags: scala, apache-spark, spark-graphx

Spark GraphX : Filtering by passing a vertex value in triplet


I am using Spark 2.1.0 on Windows 10. Since I am new to Spark, I am following this tutorial.

In the tutorial, the author prints all the triplets of the graph using the following code:

graph.triplets.sortBy(_.attr, ascending=false).map(triplet =>
"There were " + triplet.attr.toString + " flights from " + triplet.srcAttr + " to " + triplet.dstAttr + ".").take(10)
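For context, the `graph` used above is assumed to be built roughly as follows. This is a hedged sketch in the tutorial's style, not the tutorial's exact code: the sample rows, the `flightsRDD` name, and the `sc` (SparkContext) reference are placeholders. The key detail it illustrates is that each vertex ID is the `MurmurHash3` hash of the airport code, which is what the answer below relies on:

```scala
import org.apache.spark.graphx.{Edge, Graph, VertexId}
import org.apache.spark.rdd.RDD
import scala.util.hashing.MurmurHash3

// Hypothetical input: (origin, destination, flight count) rows; real data
// would be parsed from the tutorial's flight dataset.
val flightsRDD = sc.parallelize(Seq(("ATL", "LAX", 1388L), ("ATL", "SFO", 1330L)))

// One vertex per airport, keyed by the MurmurHash3 hash of its code.
val airportVertices: RDD[(VertexId, String)] =
  flightsRDD.flatMap { case (src, dst, _) => Seq(src, dst) }
    .distinct()
    .map(code => (MurmurHash3.stringHash(code).toLong, code))

// One edge per (origin, destination) pair, carrying the flight count.
val flightEdges = flightsRDD.map { case (src, dst, count) =>
  Edge(MurmurHash3.stringHash(src).toLong, MurmurHash3.stringHash(dst).toLong, count)
}

// Fallback attribute for vertices referenced by an edge but missing from airportVertices.
val defaultAirport = "nowhere"
val graph = Graph(airportVertices, flightEdges, defaultAirport)
```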

Problem: I would like to give an input ("ATL", for example) and see all the outbound flights from ATL with their counts, as shown below:

res60: Array[String] = Array(There were 1388 flights from ATL to LAX.,
There were 1330 flights from ATL to SFO., There were 1283 flights from ATL to HNL., 
There were 1205 flights from ATL to BOS., There were 1229 flights from ATL to LGA., 
There were 1214 flights from ATL to OGG., There were 1173 flights from ATL to LAS., 
There were 1157 flights from ATL to SAN.)

Solution

  • The following is the code:

    // The tutorial's vertex IDs are MurmurHash3 hashes of the airport codes
    import scala.util.hashing.MurmurHash3

    // Selecting the desired airport
    val input = "ATL"
    // Filtering the edges of the desired airport (here "ATL") from `graph` (which is built on the full data):
    // an edge is kept when its source vertex ID matches the hash of the input code
    val tempEdges = graph.edges.filter { case Edge(src, dst, prop) => src == MurmurHash3.stringHash(input) }
    // Creating a new graph with only the filtered edges
    val tempGraph = Graph(airportVertices, tempEdges, defaultAirport)
    // Printing the top 10
    tempGraph.triplets.sortBy(_.attr, ascending = false).map(triplet =>
      "There were " + triplet.attr.toString + " flights from " + triplet.srcAttr + " to " + triplet.dstAttr + ".").take(10)
    
    

    Or, we can filter the triplets directly. Since we want outbound flights, the predicate should match on the source attribute (`srcAttr`), and filtering before the sort avoids sorting edges we are about to discard:

    graph.triplets.filter { _.srcAttr == input }.sortBy(_.attr, ascending = false).map(triplet =>
      "There were " + triplet.attr.toString + " flights from " + triplet.srcAttr + " to " + triplet.dstAttr + ".").take(10)
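To sanity-check the filter-then-sort logic without a Spark cluster, the same transformation can be mirrored on a plain Scala collection. The toy triplets below stand in for `graph.triplets` as (srcAttr, dstAttr, attr) tuples; the counts are illustrative, not the real data:

```scala
// Toy stand-in for graph.triplets: (srcAttr, dstAttr, attr)
val triplets = Seq(
  ("ATL", "LAX", 1388L), ("ATL", "SFO", 1330L),
  ("ORD", "LAX", 999L),  ("ATL", "BOS", 1205L))

val input = "ATL"

// Keep outbound flights from `input`, sort by count descending, format, take top 10
val top = triplets
  .filter { case (src, _, _) => src == input }
  .sortBy { case (_, _, count) => -count }
  .map { case (src, dst, count) => s"There were $count flights from $src to $dst." }
  .take(10)

top.foreach(println)
// First line: There were 1388 flights from ATL to LAX.
```

The ORD row is dropped by the filter, so only the three ATL rows survive, ordered 1388, 1330, 1205.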