swift macos coreml createml

CoreML: why are the predictions an array?


Consider a CSV file like this:

number,weigth,length,depth,diameter
1,100,202,314,455
2,1040,2062,3314,4585
3,1200,2502,3134,4557
4,1500,2052,3143,4655
...

and code like this:

import CreateML
import Foundation

let csvFile = Bundle.main.url(forResource: "myData", withExtension: "csv")!
let dataTable = try! MLDataTable(contentsOf: csvFile)

// print(dataTable)

let regressorColumns = ["weigth", "length", "depth", "diameter"]
let regressorTable = dataTable[regressorColumns]

let (regressorEvaluationTable, regressorTrainingTable) = regressorTable.randomSplit(by: 0.20, seed: 5)

let regressor = try! MLLinearRegressor(trainingData: regressorTrainingTable,
                                      targetColumn: "weigth")

let prediction = try! regressor.predictions(from: dataTable)
print(prediction)

prediction is an array of floats with as many elements as the CSV file has rows.

Four questions:

  1. Why is it an array?
  2. Why floats?
  3. Why does the array have the same number of elements as the input CSV?
  4. What exactly does this array represent?

Solution

  • The code you posted trains a machine learning model (specifically, a linear regression model) on some input data (regressorTrainingTable) so that it can predict a weight value (the "dependent" or "target" value) from length, depth and diameter (the "independent" or "feature" values). The trained model is then used to calculate a weight value for every row of data (length, depth and diameter) stored in dataTable.

    So prediction is a collection of predicted weight values, one for each row stored in dataTable, each computed from that row's length, depth and diameter. Hopefully this answers questions 1, 3 and 4.

    As for the second question, it has to do with how linear regression works under the hood. When building (training) a model, it treats all input values (both dependent and independent) as continuous numeric values (that is, floats), even if they are expressed as integers in the data file. The predicted values are therefore floats as well.
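To make the points above concrete, here is a minimal sketch in plain Swift that does not use Create ML at all. It assumes a single feature (length) instead of three, fits it with ordinary least squares, and uses only the four sample rows from the question. SimpleLinearModel and fit(xs:ys:) are hypothetical names made up for this illustration, and this is not necessarily how MLLinearRegressor is implemented internally.

```swift
// Hypothetical stand-in for a trained regressor:
// weight ≈ intercept + slope * length.
struct SimpleLinearModel {
    let intercept: Double
    let slope: Double

    // One input value in, one predicted value out.
    func predict(_ x: Double) -> Double {
        return intercept + slope * x
    }
}

// Ordinary least squares fit for a single feature.
func fit(xs: [Double], ys: [Double]) -> SimpleLinearModel {
    let n = Double(xs.count)
    let meanX = xs.reduce(0, +) / n
    let meanY = ys.reduce(0, +) / n
    var covariance = 0.0
    var variance = 0.0
    for (x, y) in zip(xs, ys) {
        covariance += (x - meanX) * (y - meanY)
        variance += (x - meanX) * (x - meanX)
    }
    let slope = covariance / variance
    return SimpleLinearModel(intercept: meanY - slope * meanX, slope: slope)
}

// The four sample rows from the question, as whole numbers like in the CSV.
let lengthColumn: [Int] = [202, 2062, 2502, 2052]
let weightColumn: [Int] = [100, 1040, 1200, 1500]

// Regression arithmetic works on continuous values,
// so the integers are converted to Double first.
let lengths = lengthColumn.map { Double($0) }
let weights = weightColumn.map { Double($0) }

let model = fit(xs: lengths, ys: weights)

// Predicting for a table means predicting once per row, so the result
// has exactly as many elements as there are input rows.
let predictions = lengths.map { model.predict($0) }
print(predictions.count)  // 4, the number of input rows
print(predictions)        // an array of Double values, one per row
```

The result is an array because the model is applied to every row, it holds floating-point values because the fitted coefficients are generally fractional, and its length matches the number of rows that were passed in.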