scala, apache-spark, matrix-factorization, recommendation-engine

Spark ALS algorithm gives too many zero factors


We are using ALS in Spark 2.2.1 to compute user embeddings and item embeddings. Our experiment samples contain 12 billion instances; clicked instances are marked as positive and all others as negative.

When we evaluate the AUC with the product of the user embeddings and the item embeddings, the metric is not stable across runs on the same training data with the same parameters. After checking the embeddings, we found that ALS may produce all-zero factors even for users who have clicked some items, which is abnormal.

Does anyone have any ideas? Thanks for your help. Here is our code:

val hivedata = sc.sql(sqltext).select(id,dpid,score).coalesce(numPartitions)
val predataItem =  hivedata.rdd.map(r=>(r._1._1,(r._1._2,r._2.sum)))
  .groupByKey().zipWithIndex()
  .persist(StorageLevel.MEMORY_AND_DISK_SER)
val predataUser = predataItem.flatMap(r=>r._1._2.map(y=>(y._1,(r._2.toInt,y._2))))
  .aggregateByKey(zeroValueArr,numPartitions)((a,b)=> a += b,(a,b)=>a ++ b).map(r=>(r._1,r._2.toIterable))
  .zipWithIndex().persist(StorageLevel.MEMORY_AND_DISK_SER)
val trainData = predataUser.flatMap(x => x._1._2.map(y => (x._2.toInt, y._1, y._2.toFloat)))
  .setName(trainDataName).persist(StorageLevel.MEMORY_AND_DISK_SER)

case class ALSData(user:Int, item:Int, rating:Float) extends Serializable
val ratingData = trainData.map(x => ALSData(x._1, x._2, x._3)).toDF()
val als = new ALS
val paramMap = ParamMap(als.alpha -> 25000).
  put(als.checkpointInterval, 5).
  put(als.implicitPrefs, true).
  put(als.itemCol, "item").
  put(als.maxIter, 60).
  put(als.nonnegative, false).
  put(als.numItemBlocks, 600).
  put(als.numUserBlocks, 600).
  put(als.regParam, 4.5).
  put(als.rank, 25).
  put(als.userCol, "user")
als.fit(ratingData, paramMap)

Solution

  • There are two reasons: 1. An item's factor vector is all zeros when the item has no positive samples. 2. The input data may be nondeterministic, so the same job can produce different results from run to run. See the related GitHub issues.
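The two causes above can be checked with a few small helpers. This is only a sketch under stated assumptions, not the code from the linked issues: the object name `AlsZeroFactorChecks` and all of its methods are hypothetical, the column names "user", "item" and "rating" follow the `ALSData` case class in the question, and non-clicked samples are assumed to carry a rating of 0 or less. Sorting the keys before `zipWithIndex` is one possible way to make the integer ids reproducible; it is swapped in here for illustration and is not something the answer prescribes.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, max, udf}

// Hypothetical helpers for diagnosing all-zero ALS factors.
object AlsZeroFactorChecks {

  // True when every component of a factor vector is exactly zero.
  def isZeroVector(v: Seq[Float]): Boolean = v.forall(_ == 0.0f)

  // Counts all-zero rows in model.userFactors or model.itemFactors.
  // Both DataFrames expose the columns "id" and "features" (array of float).
  def countZeroFactors(factors: DataFrame): Long = {
    val isZero = udf(isZeroVector _)
    factors.filter(isZero(col("features"))).count()
  }

  // Ids (users or items) whose ratings are never positive.
  // Assumption: a rating <= 0 means "no positive sample" for that row.
  def idsWithoutPositives(ratings: DataFrame,
                          idCol: String,
                          ratingCol: String): DataFrame =
    ratings
      .groupBy(col(idCol))
      .agg(max(col(ratingCol)).as("maxRating"))
      .filter(col("maxRating") <= 0)

  // Assigns indices from a total order over distinct keys, so the mapping
  // does not depend on partition contents or on shuffle ordering.
  def stableIndex(keys: RDD[String]): RDD[(String, Long)] =
    keys.distinct().sortBy(k => k).zipWithIndex()
}
```

With a fitted `model` (the value returned by `als.fit`), `AlsZeroFactorChecks.countZeroFactors(model.itemFactors)` reports how many item vectors are all zeros, and `AlsZeroFactorChecks.idsWithoutPositives(ratingData, "item", "rating")` lists the items that would be expected to end up that way. If the first count is noticeably larger than the second, or changes between runs, nondeterministic id assignment is the more likely culprit.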