Tags: python, scala, apache-spark-mllib, xgboost

I have an XGBoost model trained in Python, but it gives different predictions when loaded in Scala with the same features — why?


I have an XGBoost model trained with the Python API, saved as my_fpd20.model. Now I want to use it in Scala to run predictions, but when I test with the same features, I get different predicted results. (The XGBoost version is 0.82.)

The code is as below:

import ml.dmlc.xgboost4j.scala.XGBoost
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.LabeledPoint

object MSFScoreCompute {
    def main(args: Array[String]): Unit = {
        val model = XGBoost.loadModel("/home/mixxbox/my_fpd20.model")

        val t1 = (1007988895058497544L, 0.928856,   ",,5436.0,559.0,2169.0,15267.0,360.0,3619.0,3508.0,1412.0,,,,118.0,4.0,,,,,,,511.0,648.0,312.0,57.0,4573.0,3116.0,3530.0,124.0,1774.0,521.0,,,,,625.0,124.0,,246.0,,,,,,,,,,,,,,,,,,,,,,608.0,8020.0,7246.0,6427.0,24328.0,22359.0,14074.0,66638.0,3636.0,,608.0,793.0,9183.0,,,,,2451.0,375.0,1127.0,4621.0,,,2.0,,,126.0,,1610.0,,,,,10469.0,,166.0,,2.0,,,1610.0,4869.0,7150.0,,,,,,9.0,,,4179.0,97.0,147.0,938.0,172.0,228.0,543.0,1226.0,115.0,90.0,163.0,199.0,4073.0,2860.0,598.0,469.0,,,172.0,67.0,18.0,,2725.0,6.0,296.0,435.0,,,,273.0,,,147.0,25.0,1397.0,103.0,,,,,52.0,,,,,,,50.0,159.0,134.0,420.0,143.0,64.0,36.0,228.0,,23.0,837.0,,,1650.0,3781.0,1019.0,40.0,116.0,186.0,13826.0,2783.0,,26.0,,,1394.0,,1056.0,135.0,,632.0,,,,87059.0,3821.0,6121.0,183069.0,284618.0,273332.0,292360.0,1068.0,1547.0,1206.0,32231.0,70824.0,70372.0,84594.0,59782.0,61287.0,64698.0,74435.0,471844.0,2602.0,225.0,14518.0,22737.0,24931.0,16938.0,9972.0,17280.0,14224.0,5.0,,,1379.0,1415.0,,2685.0,2977.0,71.0,16853.0,,,,,141.0,71803.0,88547.0,264464.0,196441.0,280080.0,187640.0,272245.0,259110.0,180342.0,248236.0,297921.0,1755.0,88677.0,90063.0,88413.0,58363.0,96684.0,94109.0,21.0,,13925.0,2829.0,16684.0,16453.0,22001.0,23828.0,13205.0,3314.0,14718.0,38169.0,,,7661.0,13241.0,17582.0,222.0,10880.0,,,5104.407676174497,12.4,12.4,7845.433933358866,4060.872788134977,6.76,,,,,,,,,,,1.0,1.0,8.0,189.0,173.0,146.0,70.0,14.0,25.0,38.0,17.0,10858.0,,,432.0,3698.0,3381.0,942.0,109.0,870.0,8.0,8.0,,,,,,,,2.0,4.0,4.0,3.0,1.0,1.0,5.0,8.0")
        val t2 = (1002897541026550024L, 0.969927,   ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0")
        val t3 = (1005094818872823816L, 0.96601504, ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,5.0,,6.0,,,,,,,1.0,1.0,1.0,2.0,1.0,1.0,2.0,2.0")
        val t4 = (1005094818872823816L, 0.96601504, ",,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0")
        val t5 = (1005966203849544456L, 0.96515334, "13.0,577.0,371.0,121.0,132.0,3052.0,210.0,,882.0,1308.0,,,,683.0,2079.0,,,,,,,,,37.0,25.0,4035.0,225.0,530.0,15.0,,17.0,,,,,15.0,,,,,,,,,,,,,,555.0,164.0,1243.0,,,,524.0,,,,,,,,1411.0,,,7094.0,44568.0,,,,,,,,,,,,,109.0,,50.0,187.0,130.0,122.0,2950.0,1253.0,3326.0,316.0,232.0,,1407.0,15413.0,6595.0,881.0,25.0,57.0,8.0,34.0,72.0,1327.0,4110.0,4138.0,2346.0,,57.0,,414.0,92.0,,755.0,8451.0,15674.0,18366.0,1120.0,2486.0,17526.0,21293.0,49350.0,38514.0,28455.0,77143.0,50427.0,27724.0,53642.0,59878.0,1500.0,4926.0,809.0,21877.0,5977.0,20509.0,13501.0,1557.0,100799.0,25869.0,23.0,1500.0,4926.0,12115.0,3.0,10918.0,1744.0,2851.0,23947.0,16003.0,1.0,,,427.0,815.0,,,,633.0,188.0,43.0,,,78.0,15.0,52.0,63.0,,50.0,,,3679.0,285.0,912.0,4852.0,763.0,2193.0,3164.0,4977.0,2425.0,20147.0,304.0,872.0,,,52.0,60.0,104.0,225.0,641.0,,,,,,34043.0,11996.0,7274.0,32689.0,61913.0,49065.0,85510.0,6076.0,622.0,5298.0,12458.0,8957.0,16414.0,29491.0,31258.0,23791.0,24628.0,34797.0,162136.0,1881.0,516.0,4343.0,2442.0,5408.0,2173.0,7204.0,883.0,9821.0,58.0,,,885.0,995.0,431.0,581.0,1566.0,736.0,10097.0,,,,,,104805.0,138960.0,151672.0,94589.0,129486.0,149334.0,213810.0,180026.0,130354.0,152028.0,197758.0,422.0,64783.0,75352.0,57723.0,33850.0,83382.0,57379.0,9822.0,1141.0,12815.0,857.0,886.0,15905.0,9014.0,11621.0,14919.0,1661.0,9428.0,3358.0,50.0,3.0,1112.0,2501.0,4630.0,300.0,1413.0,,,4126.941409266409,4.87,4.87,4740.780104712042,3864.2336822246352,4.62,,,,,,,,,,,2.0,5.0,5.0,23.0,15.0,25.0,6.0,,,10.0,3.0,285.0,13.0,9.0,232.0,1504.0,820.0,297.0,296.0,461.0,9.0,8.0,,,1.0,,,2.0,,2.0,4.0,5.0,5.0,1.0,1.0,6.0,10.0")
        val t6 = (1003746555179569416L, 0.97295755, ",90.0,,293.0,148.0,2043.0,56.0,70.0,113.0,482.0,,,,160.0,2460.0,354.0,,,2509.0,29.0,71.0,816.0,255.0,885.0,1072.0,5563.0,4042.0,2431.0,268.0,318.0,653.0,591.0,77.0,151.0,710.0,2675.0,2.0,435.0,157.0,,,,,,,,,,,,,,,,,,,,,,,,1304.0,2147.0,4777.0,5847.0,38116.0,168435.0,,,,513.0,,,,,,,,,1457.0,,1191.0,68.0,122.0,338.0,1893.0,216.0,267.0,972.0,1014.0,119.0,4.0,7607.0,3772.0,,1191.0,,68.0,46.0,189.0,663.0,4619.0,2320.0,668.0,,,,,,85.0,511.0,28238.0,419.0,8897.0,12713.0,35535.0,13884.0,19754.0,16271.0,33518.0,12715.0,34050.0,34536.0,36717.0,27203.0,45865.0,7030.0,10168.0,1305.0,410.0,4968.0,335.0,28300.0,3720.0,13834.0,26539.0,29.0,1967.0,,1503.0,53.0,219.0,1952.0,,9862.0,4314.0,,,81.0,,194.0,,,,,,,,,260.0,416.0,416.0,1086.0,425.0,,,,3082.0,212.0,427.0,2387.0,1776.0,3846.0,9103.0,8933.0,6526.0,13894.0,1150.0,160.0,,,,,,294.0,,15.0,174.0,,,,34210.0,4986.0,1217.0,26738.0,63735.0,62730.0,187933.0,239.0,4590.0,157.0,6212.0,12547.0,23008.0,19881.0,19423.0,17632.0,28744.0,61775.0,172619.0,172.0,,2211.0,3989.0,2641.0,3405.0,7307.0,17613.0,3274.0,1594.0,,,1191.0,1027.0,3030.0,1215.0,653.0,2677.0,6224.0,,,,321.0,326.0,42824.0,61799.0,119313.0,82310.0,150184.0,83475.0,148886.0,131373.0,96711.0,155795.0,189394.0,3944.0,47375.0,50923.0,42097.0,26121.0,71303.0,49468.0,,,9026.0,3180.0,2409.0,4598.0,10544.0,14670.0,6743.0,4249.0,2354.0,14226.0,,,320.0,645.0,1075.0,100.0,38.0,,,4311.241637010675,11.18,11.18,5006.421052631579,2395.94753248642,5.95,,,,,,,5.0,4.0,4.0,4.0,3.0,,382.0,890.0,768.0,147.0,358.0,13.0,101.0,50.0,108.0,15215.0,4.0,3.0,317.0,2486.0,1954.0,400.0,554.0,79.0,6.0,,,,19.0,,,108.0,,3.0,5.0,4.0,4.0,3.0,3.0,6.0,5.0")
        Array(t1, t2, t3, t4, t5, t6)
            .map { case (key, pyScore, featureStr) =>
                val features = featureStr
                    .split(",")
                    .map(f => if (f.isEmpty) Double.NaN else f.toDouble)
                    // note: equals(0) / equals(-1) box to Integer, so this comparison
                    // against a boxed Double is always false and never fires
                    .map(v => if (v.equals(0) || v.equals(-1)) Double.NaN else v)
                val labelPoint = LabeledPoint(0, (0 to 326).toArray, features.map(_.toFloat))
                val result = model.predict(new DMatrix(Iterator(labelPoint)), false, 1000)
                Array(key.toString, pyScore, 1.0 - result(0)(0), pyScore - (1 - result(0)(0))).mkString(",")
            }
            .foreach(line => println("result:++++" + line))
    }
}

The result is as below. The 1st column is the key, the 2nd is the Python prediction, the 3rd is the Scala prediction, and the 4th is the difference:

result:++++1007988895058497544,0.928856,0.9201448783278465,0.008711144023895279
result:++++1002897541026550024,0.969927,0.8605302572250366,0.10939674277496336
result:++++1005094818872823816,0.96601504,0.8911455124616623,0.07486951263717656
result:++++1005094818872823816,0.96601504,0.8605302572250366,0.10548478277496343
result:++++1005966203849544456,0.96515334,0.9438987076282501,0.021254662174072236
result:++++1003746555179569416,0.97295755,0.9416626989841461,0.03129488081817622

My question is: what can I do to get consistent results? (I guess it may be caused by a difference in NaN handling between Scala and Python, but how do I deal with this problem?)


Solution

  • I have solved this problem by setting the "missing" value of the DMatrix. This "missing" parameter specifies which value is treated as a missing feature. The constructor is as below:

      @throws(classOf[XGBoostError])
      def this(data: Array[Float], nrow: Int, ncol: Int, missing: Float) {
        this(new JDMatrix(data, nrow, ncol, missing))
      }
    

    For example:

      val ma = new DMatrix(my_features_array, 1, 327, Float.NaN)
      val result = model.predict(ma, false)
    

    As in the code, the missing value is NaN, meaning that any NaN in your feature array is regarded as missing. The result is now consistent, as below:

    result:++++1007988895058497544,0.928856,0.9288560450077057,-1.5205383285810115E-8
    result:++++1010594380924326152,0.96601504,0.9660150445997715,-8.744811541561148E-10
    result:++++1002897541026550024,0.969927,0.9699269998818636,-1.292037965505699E-8
    result:++++1005094818872823816,0.96601504,0.9660150445997715,-8.744811541561148E-10
    result:++++1005966203849544456,0.96515334,0.9651533514261246,3.4750365918156945E-9
    result:++++1003746555179569416,0.97295755,0.9729575514793396,-1.4793396507783996E-9
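Putting it together, the per-row feature strings from the question can feed the dense-array constructor directly. Below is a minimal sketch of the parsing step; the xgboost4j calls are shown only as comments because they need the library and the model file on hand, and `toFeatureArray` is a hypothetical helper name, not part of the original code:

```scala
// Convert one comma-separated feature string into a dense Float array,
// mapping empty fields to Float.NaN so DMatrix can treat them as missing.
// The -1 limit keeps trailing empty fields, which plain split(",") drops.
def toFeatureArray(featureStr: String): Array[Float] =
  featureStr.split(",", -1).map(f => if (f.isEmpty) Float.NaN else f.toFloat)

val features = toFeatureArray(",,5436.0,559.0")
// features has 4 elements; the first two are NaN, then 5436.0f and 559.0f

// With xgboost4j on the classpath, this array feeds the constructor that
// takes an explicit `missing` value:
//   val dmat   = new DMatrix(features, 1, features.length, Float.NaN)
//   val result = model.predict(dmat, false)
```

This mirrors what the Python API does by default (its DMatrix uses `missing=NaN`), which is why the scores then match.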