java huggingface-transformers onnx

Empty result from Apache OpenNLP with an ONNX model


I am trying to convert a Hugging Face model to ONNX so I can classify text in a Java app, but I can't understand why I get no result (the result array is just empty). The readme.md links to a model that works well, but I have to use a different one because that model does not support the language I need.

Working sample Python code:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "tabularisai/multilingual-sentiment-analysis"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

def predict_sentiment(texts):
    inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    sentiment_map = {0: "Very Negative", 1: "Negative", 2: "Neutral", 3: "Positive", 4: "Very Positive"}
    return [sentiment_map[p] for p in torch.argmax(probabilities, dim=-1).tolist()]

print(predict_sentiment(["I absolutely love the new design of this app!", "The customer service was disappointing."]))
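
For comparison, the post-processing step of the Python sample (softmax, then argmax, then label lookup) can be sketched in plain Java. This is a standalone illustration, not part of the OpenNLP API: the class name `SentimentPostprocess`, its helper methods, and the logits in `main` are hypothetical. It also rejects an empty score array, which is the symptom described in this question.

```java
import java.util.Map;

public class SentimentPostprocess {
    // Same index-to-label mapping as sentiment_map in the Python sample.
    static final Map<Integer, String> LABELS = Map.of(
            0, "Very Negative",
            1, "Negative",
            2, "Neutral",
            3, "Positive",
            4, "Very Positive");

    // Numerically stable softmax: subtract the maximum before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Index of the largest value; the first index wins on ties.
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Maps raw scores to a label, failing loudly if no scores were produced.
    static String label(double[] logits) {
        if (logits.length == 0) {
            throw new IllegalArgumentException("no scores returned by the model");
        }
        return LABELS.get(argmax(softmax(logits)));
    }

    public static void main(String[] args) {
        double[] logits = {-1.2, -0.4, 0.1, 1.5, 3.0}; // hypothetical values
        System.out.println(label(logits)); // prints "Very Positive"
    }
}
```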

I tried converting a few models in two ways:

python -m optimum.exporters.onnx --model tabularisai/multilingual-sentiment-analysis --task sequence-classification onnx_model

and

from optimum.onnxruntime import ORTModelForFeatureExtraction

model = ORTModelForFeatureExtraction.from_pretrained("tabularisai/multilingual-sentiment-analysis", from_transformers=True)
model.save_pretrained("onnx_model")

Both give the same result: an onnx_model folder containing the model, vocab, etc.

nlptown_bert-base-multilingual-uncased-sentiment is the model from the readme, and it works as expected even when I use the converted model's vocab file.

Java code sample:

    public void def() {
        try (final DocumentCategorizerDL documentCategorizerDL =
                     new DocumentCategorizerDL(
                             new File("onnx_model/model.onnx"),
//                             new File("nlptown_bert-base-multilingual-uncased-sentiment.onnx"),
                             new File("onnx_model/vocab.txt"),
                             getCategories(),
                             new AverageClassificationScoringStrategy(),
                             new InferenceOptions())) {

            final double[] result = documentCategorizerDL.categorize(new String[] {"I absolutely love the new design of this app!", "The customer service was disappointing."});
            System.out.println("done");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private Map<Integer, String> getCategories() {
        final Map<Integer, String> categories = new HashMap<>();
        categories.put(0, "Very Negative");
        categories.put(1, "Negative");
        categories.put(2, "Neutral");
        categories.put(3, "Positive");
        categories.put(4, "Very Positive");
        return categories;
    }

I don't know how important this is, but the opset version of the working model is 11, while mine is 14. My opennlp-dl library version is 2.5.1.


Solution

  • Passing setIncludeTokenTypeIds(false) in the InferenceOptions solved the problem:

    public void def() {
        var options = new InferenceOptions();
        // Do not send token_type_ids to the model.
        options.setIncludeTokenTypeIds(false);
        try (final DocumentCategorizerDL documentCategorizerDL =
                     new DocumentCategorizerDL(
                             new File("onnx_model/model.onnx"),
                             new File("onnx_model/vocab.txt"),
                             getCategories(),
                             new AverageClassificationScoringStrategy(),
                             options)) {
    
            final double[] result = documentCategorizerDL.categorize(new String[] {"I absolutely love the new design of this app!", "The customer service was disappointing."});
            System.out.println("s");
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
    }
    
    private Map<Integer, String> getCategories() {
        final Map<Integer, String> categories = new HashMap<>();
        categories.put(0, "Very Negative");
        categories.put(1, "Negative");
        categories.put(2, "Neutral");
        categories.put(3, "Positive");
        categories.put(4, "Very Positive");
        return categories;
    }
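
The answer does not say why setIncludeTokenTypeIds(false) helps. One plausible explanation, offered here as an assumption, is that the exported model does not declare a token_type_ids input (DistilBERT-style exports typically take only input_ids and attention_mask), while the BERT-based model from the readme does. The sketch below shows that decision in plain Java. The class `TokenTypeCheck` and its method are hypothetical, and the input names are hard-coded. In a real check they could be read from the model, for example with `OrtSession.getInputNames()` from ONNX Runtime's Java API.

```java
import java.util.Set;

public class TokenTypeCheck {
    // Returns true when the model declares a token_type_ids input,
    // i.e. when token type ids should be included at inference time.
    static boolean expectsTokenTypeIds(Set<String> inputNames) {
        return inputNames.contains("token_type_ids");
    }

    public static void main(String[] args) {
        // Hypothetical input sets; real names would come from the ONNX model.
        Set<String> bertStyle = Set.of("input_ids", "attention_mask", "token_type_ids");
        Set<String> distilBertStyle = Set.of("input_ids", "attention_mask");

        System.out.println(expectsTokenTypeIds(bertStyle));       // prints "true"
        System.out.println(expectsTokenTypeIds(distilBertStyle)); // prints "false"
    }
}
```

The result of such a check could then be passed straight to options.setIncludeTokenTypeIds(...), so the same Java code works for both kinds of model.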