scala, apache-spark, spark-graphx

How to attach properties to vertices in a GraphX graph and retrieve the neighbourhood


I am rather new to Spark and Scala. I have a graph: Graph[Int, String] and I'd like to attach to its vertices some properties I have in a DataFrame.

What I need to do is, for each vertex, find the average value of each property over its neighbourhood. This is my approach so far, but I don't understand how to correctly map the Row I get from the join of the two data frames:

val res = graph.collectNeighbors(EdgeDirection.Either)
         .toDF("ID", "neighbours")
         .join(aDataFrameWithProperties, "ID")
         .map{x => // this is where I am lost
         }

I don't think my approach is right: I join the properties of each vertex with the array of its neighbours, but I still don't know the property values of the neighbours.

EDIT

Some data to help understand what I want to accomplish. Say you build the graph as in this answer to "How to create EdgeRDD from data frame in Spark":

val sqlc : SQLContext = ???

case class Person(id: Long, country: String, age: Int)

val testPeople = Seq(
   Person(1, "Romania"    , 15),
   Person(2, "New Zealand", 30),
   Person(3, "Romania"    , 17),
   Person(4, "Iceland"    , 20),
   Person(5, "Romania"    , 40),
   Person(6, "Romania"    , 44),
   Person(7, "Romania"    , 45),
   Person(8, "Iceland"    , 21),
   Person(9, "Iceland"    , 22)
 )

 val people = sqlc.createDataFrame(testPeople)
 val peopleR = people
   .withColumnRenamed("id"     , "idR")
   .withColumnRenamed("country", "countryR")
   .withColumnRenamed("age"    , "ageR")

 import org.apache.spark.sql.functions._

 val relations = people.join(peopleR,
       (people("id") < peopleR("idR")) &&
         (people("country") === peopleR("countryR")) &&
         (abs(people("age") - peopleR("ageR")) < 5))

 import org.apache.spark.graphx._

 val edges = EdgeRDD.fromEdges(relations.map(row => Edge(
       row.getAs[Long]("id"), row.getAs[Long]("idR"), ())))

 // id is a Long column, so read it as Long; getAs[Int] would fail at runtime
 val users = VertexRDD.apply(people.map(row => (row.getAs[Long]("id"), row.getAs[Long]("id").toInt)))

 val graph = Graph(users, edges)

Then you have a data frame like:

case class Person(id:Long, gender:Int, income:Int)
val properties = Seq(
  Person(1, 0, 321),
  Person(2, 1, 212),
  Person(3, 0, 212),
  Person(4, 0, 122),
  Person(5, 1, 898),
  Person(6, 1, 212),
  Person(7, 1, 22),
  Person(8, 0, 8),
  Person(9, 0, 212)
)

val people = sqlc.createDataFrame(properties)

I'd like to compute, for each vertex, the average gender and the average income of its neighbours, returned as a DataFrame.


Solution

  • Generally speaking, you should use graph operators instead of converting everything to a DataFrame, but something like this should do the trick:

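    import org.apache.spark.sql.functions.{explode, avg}
    // Needed for toDF and the $"..." column syntax (sqlc is the SQLContext from the question)
    import sqlc.implicits._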
    
    val statsDF = graph.collectNeighbors(EdgeDirection.Either)
      .toDF("ID", "neighbours")
      // Flatten neighbours column
      .withColumn("neighbour", explode($"neighbours"))
      // and extract neighbour id
      .select($"ID".alias("this_id"), $"neighbour._1".alias("other_id"))
      // join with people 
      .join(people, people("ID") === $"other_id")
      .groupBy($"this_id")
      .agg(avg($"gender"), avg($"income"))
    

    What if, instead of an average, I'd like to count, say, the number of neighbours with the same gender as myself, and then find the average over all connections?

    To do this you would need two separate joins: one on this_id and one on other_id. Next, you can simply aggregate with the following expression:

    avg(($"this_gender" === $"other_gender").cast("integer"))
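
The two joins described above could look roughly like the following. This is only a sketch under assumptions: it expects a DataFrame of (this_id, other_id) pairs like the one built before the aggregation in the first snippet, and a properties DataFrame with id and gender columns. The helper name sameGenderShare and the aliases t_id, o_id and same_gender_share are hypothetical and chosen for illustration.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{avg, col}

// Hypothetical helper, not part of the original answer.
// pairs:      columns this_id, other_id (one row per vertex/neighbour pair)
// properties: columns id, gender (other columns are ignored)
def sameGenderShare(pairs: DataFrame, properties: DataFrame): DataFrame = {
  // Two renamed copies of the gender column, one for each side of a pair
  val thisProps = properties.select(
    col("id").alias("t_id"), col("gender").alias("this_gender"))
  val otherProps = properties.select(
    col("id").alias("o_id"), col("gender").alias("other_gender"))

  pairs
    .join(thisProps, col("this_id") === col("t_id"))
    .join(otherProps, col("other_id") === col("o_id"))
    .groupBy(col("this_id"))
    // true/false cast to 1/0, so the average is the share of matching neighbours
    .agg(avg((col("this_gender") === col("other_gender")).cast("integer"))
      .alias("same_gender_share"))
}
```

Renaming the columns before joining avoids ambiguous references when the same properties DataFrame is joined twice.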
    

    Regarding graph operators, there are a few operations you can use. For starters, you can use a join operation to add properties to vertices:

    import org.apache.spark.rdd.RDD

    val properties: RDD[(VertexId, (Int, Int))] = sc.parallelize(Seq(
      (1L, (0, 321)), (2L, (1, 212)), (3L, (0, 212)),
      (4L, (0, 122)), (5L, (1, 898)), (6L, (1, 212)),
      (7L, (1, 22)), (8L, (0, 8)), (9L, (0, 212))
    ))
    
    val graphWithProperties = graph
      .outerJoinVertices(properties)((_, _, prop) => prop)
      // For simplicity this assumes no missing values 
      .mapVertices((_, props) => props.get) 
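
If some vertices may lack properties, calling get on the Option will throw. One possible way around this, sketched here under the assumption that such vertices (and the edges touching them) can simply be dropped, is to filter with subgraph before unwrapping. The helper name attachKnownProperties is hypothetical.

```scala
import org.apache.spark.graphx.{Graph, VertexId}
import org.apache.spark.rdd.RDD

// Hypothetical helper, not part of the original answer.
// Attaches (gender, income) to each vertex and drops vertices without an entry in props.
def attachKnownProperties[VD, ED](
    g: Graph[VD, ED],
    props: RDD[(VertexId, (Int, Int))]): Graph[(Int, Int), ED] =
  g.outerJoinVertices(props)((_, _, prop) => prop)
    // Keep only vertices that have properties; edges to removed vertices are dropped too
    .subgraph(vpred = (_, prop) => prop.isDefined)
    // Safe now: every remaining vertex has a defined value
    .mapVertices((_, prop) => prop.get)
```

Whether dropping is appropriate depends on the data; substituting a default value instead would skew the neighbourhood averages.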
    

    Next, we can aggregate messages to create a new VertexRDD:

    val neighboursAggregated = graphWithProperties
      .aggregateMessages[(Int, (Int, Int))](
        triplet => {
          // Each message is (1, (gender, income)) of the sending vertex
          triplet.sendToDst((1, triplet.srcAttr))
          triplet.sendToSrc((1, triplet.dstAttr))
        },
        // Sum neighbour counts, genders and incomes
        {case ((cnt1, (gender1, inc1)), (cnt2, (gender2, inc2))) =>
          (cnt1 + cnt2, (gender1 + gender2, inc1 + inc2))}
      )
    

    Finally, we can replace the existing properties:

    graphWithProperties.outerJoinVertices(neighboursAggregated)(
      (_, oldProps, newProps) => newProps match {
        case Some((cnt, (gender, inc))) => Some((
          // Share of neighbours with the same gender as this vertex
          if (oldProps._1 == 1) gender.toDouble / cnt
          else  1 - gender.toDouble / cnt,
          // Average income of the neighbours
          inc.toDouble / cnt
        ))
        case _ => None
      })
    

    If you're interested only in the values, you can pass all required values in aggregateMessages and omit the second outerJoinVertices.
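
As a rough sketch of that last variant: the gender comparison can be done while sending, since each triplet exposes both endpoints, so the result comes straight out of aggregateMessages. It assumes vertex attributes of the form (gender, income), as in graphWithProperties; the helper name neighbourStats is hypothetical.

```scala
import org.apache.spark.graphx.{Graph, VertexRDD}

// Hypothetical helper, not part of the original answer.
// Returns, per vertex: (share of neighbours with the same gender, average neighbour income).
def neighbourStats[ED](g: Graph[(Int, Int), ED]): VertexRDD[(Double, Double)] =
  g.aggregateMessages[(Int, Int, Int)](
    triplet => {
      val (srcGender, srcIncome) = triplet.srcAttr
      val (dstGender, dstIncome) = triplet.dstAttr
      val same = if (srcGender == dstGender) 1 else 0
      // Message: (neighbour count, same-gender count, income sum)
      triplet.sendToDst((1, same, srcIncome))
      triplet.sendToSrc((1, same, dstIncome))
    },
    { case ((c1, s1, i1), (c2, s2, i2)) => (c1 + c2, s1 + s2, i1 + i2) }
  ).mapValues((v: (Int, Int, Int)) => (v._2.toDouble / v._1, v._3.toDouble / v._1))
```

Vertices without neighbours receive no messages and are therefore absent from the resulting VertexRDD.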