
Can we train a CreateML model on an iOS device using the user's data?


I understand that we can train an ML model on macOS with the Xcode CreateML GUI, as well as in a macOS Playground. My problem is training a similar model on the user's device, using their own data. Is that possible?

Can we train a CreateML text classifier on the user's device? I did some research but could not find an answer. Most people talk about deploying a trained model to iOS, but I want to train on iOS.

P.S. I also had a look at updatable CoreML models, which do not seem to support text classifiers. They only support KNN models and shallow neural networks.

More specifically: can we even use MLTextClassifier to create a model on iOS? The information is conflicting: Apple's CreateML main page says you need to train on a Mac, but this API seems to indicate that it supports iOS, which really confuses me.

init(trainingData: [String : [String]], parameters: MLTextClassifier.ModelParameters) 
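
For reference, the `trainingData` parameter in that signature is a dictionary mapping each label to its example texts. Here is a minimal sketch of building that shape from user-entered data. The `LabeledText` record and the `makeTrainingData(from:)` helper are hypothetical names used only for illustration, not part of CreateML.

```swift
/// Hypothetical record for one piece of user-labeled text.
struct LabeledText {
    let label: String
    let text: String
}

/// Groups samples by label into the [String: [String]] shape
/// that the initializer's trainingData parameter takes.
func makeTrainingData(from samples: [LabeledText]) -> [String: [String]] {
    var grouped: [String: [String]] = [:]
    for sample in samples {
        // Append each text under its label, creating the array on first use.
        grouped[sample.label, default: []].append(sample.text)
    }
    return grouped
}

// Made-up sample data, standing in for whatever the user has entered.
let samples = [
    LabeledText(label: "positive", text: "Loved it"),
    LabeledText(label: "negative", text: "Not for me"),
    LabeledText(label: "positive", text: "Would watch again"),
]
let trainingData = makeTrainingData(from: samples)
```

The resulting dictionary could then be passed as the `trainingData` argument wherever the classifier is created.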

Solution

  • The CreateML module does work on iOS (since iOS 15). It just doesn't work on the iOS simulator.

    You can surround all your training code with

    #if canImport(CreateML)
    
    ...
    
    #endif
    

    so that it only runs when you are on a real device. Admittedly, this is rather inconvenient...

    As for how to use the CreateML API, you can follow the guide here. The code would look something like this. Note that I've updated some of the deprecated (since iOS 16) code in the guide to use the newest APIs.

    import CoreML
    import CreateML
    import NaturalLanguage
    import TabularData
    
    // training...
    let sentimentClassifier = try MLTextClassifier(trainingData: [
        "positive": [...],
        "negative": [...],
        "neutral": [...],
    ])
    
    // write to file for later use...
    let metadata = MLModelMetadata(author: "John Appleseed",
                                   shortDescription: "A model trained to classify movie review sentiment",
                                   version: "1.0")
    try sentimentClassifier.write(to: URL(fileURLWithPath: "/path/to/save/SentimentClassifier.mlmodel"),
                                  metadata: metadata)
    // or use it immediately:
    // prediction(from:) can throw, so it needs `try`
    print(try sentimentClassifier.prediction(from: "foo bar baz"))
    
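    //... at some later point
    
    // MLModel(contentsOf:) expects a compiled model (.mlmodelc), so compile the saved .mlmodel first
    let compiledURL = try MLModel.compileModel(at: URL(fileURLWithPath: "/path/to/save/SentimentClassifier.mlmodel"))
    let model = try MLModel(contentsOf: compiledURL)
    let nlModel = try NLModel(mlModel: model)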
    print(nlModel.predictedLabel(for: "foo bar baz") ?? "no label")
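
Putting the `#if canImport(CreateML)` guard and the training call together, a minimal sketch might look like the following. The helper names `canTrainOnDevice()` and `trainAndClassify(_:sample:)` are hypothetical, and the availability check assumes an iOS or macOS target. In builds where CreateML cannot be imported, the helper simply returns `nil` so the rest of the app still compiles.

```swift
#if canImport(CreateML)
import CreateML
#endif

/// Hypothetical helper: reports whether training code was compiled into this build.
func canTrainOnDevice() -> Bool {
    #if canImport(CreateML)
    return true
    #else
    return false
    #endif
}

/// Hypothetical helper: trains a text classifier on the given labeled texts and
/// classifies one sample. Returns nil when CreateML is unavailable or training fails.
func trainAndClassify(_ trainingData: [String: [String]], sample: String) -> String? {
    #if canImport(CreateML)
    // Assumption: iOS 15 / macOS 10.14 are the minimum versions for MLTextClassifier.
    guard #available(iOS 15.0, macOS 10.14, *) else { return nil }
    do {
        let classifier = try MLTextClassifier(trainingData: trainingData)
        return try classifier.prediction(from: sample)
    } catch {
        // Training or prediction failed (for example, too little data).
        return nil
    }
    #else
    // CreateML is not part of this build, so there is nothing to train with.
    return nil
    #endif
}
```

Callers can check `canTrainOnDevice()` to decide whether to offer the training feature at all, and treat a `nil` result as "no prediction available".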