
Cannot use SparkNLP pre-trained T5Transformer, executor fails with error "No Operation named [encoder_input_ids] in the Graph"


I downloaded the T5-small model from the Spark NLP website and am using this code (taken almost entirely from the examples):

    import com.johnsnowlabs.nlp.{DocumentAssembler, SparkNLP}
    import com.johnsnowlabs.nlp.annotators.seq2seq.T5Transformer
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryoserializer.buffer.max", "500M")
      .master("local").getOrCreate()
    SparkNLP.start()

    val testData = spark.createDataFrame(Seq(
      (1, "Google has announced the release of a beta version of the popular TensorFlow machine learning library"),
      (2, "The Paris metro will soon enter the 21st century, ditching single-use paper tickets for rechargeable electronic cards.")
    )).toDF("id", "text")

    val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setOutputCol("documents")

    val t5 = T5Transformer.load("/tmp/t5-small")
      .setTask("summarize:")
      .setInputCols(Array("documents"))
      .setOutputCol("summaries")

    new Pipeline().setStages(Array(documentAssembler, t5))
      .fit(testData)
      .transform(testData)
      .select("summaries.result").show(truncate = false)

I get this error from the executor:

    Caused by: java.lang.IllegalArgumentException: No Operation named [encoder_input_ids] in the Graph
        at org.tensorflow.Session$Runner.operationByName(Session.java:384)
        at org.tensorflow.Session$Runner.parseOutput(Session.java:398)
        at org.tensorflow.Session$Runner.feed(Session.java:132)
        at com.johnsnowlabs.ml.tensorflow.TensorflowT5.process(TensorflowT5.scala:76)

I initially ran this with Spark 2.3.0, but the issue also reproduces with Spark 2.4.4. Other Spark NLP features work well; only this T5 model fails. The model on disk:

    $ ll /tmp/t5-small
    drwxr-xr-x@ 6 XXX  XXX        192 Dec 25 12:36 metadata
    -rw-r--r--@ 1 XXX  XXX     791656 Dec 22 18:32 t5_spp
    -rw-r--r--@ 1 XXX  XXX  175686374 Dec 22 18:32 t5_tensorflow

    $ cat /tmp/t5-small/metadata/part-00000
    {"class":"com.johnsnowlabs.nlp.annotators.seq2seq.T5Transformer","timestamp":1608475002145,
     "sparkVersion":"2.4.4","uid":"T5Transformer_1e0a16435680","paramMap":{},
     "defaultParamMap":{"task":"","lazyAnnotator":false,"maxOutputLength":200}}

I'm new to Spark NLP, so I'm not sure whether this is an actual issue or whether I'm doing something wrong. I'd appreciate any help.


Solution

  • The offline T5 model (t5_base_en_2.7.1_2.4_1610133506835) was built with Spark NLP 2.7.1, and there was a breaking change in 2.7.2.

    Solved by downloading the new version of the model and re-saving it:

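    // dev (machine with internet access): download the current model and save it locally
    T5Transformer.pretrained("t5_small").save(path)

    // prod (offline): load the re-saved copy from the same path
    T5Transformer.load(path)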
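To catch this mismatch before the executor fails, one option is to compare the Spark NLP version a model was exported with against the version you run. The sketch below is a hypothetical helper, not part of Spark NLP. It assumes the download name follows the pattern `<model>_<lang>_<sparkNlpVersion>_<sparkVersion>_<timestamp>`, inferred only from the single name `t5_base_en_2.7.1_2.4_1610133506835` above, and it assumes the failure occurs when a model exported before 2.7.2 is loaded by 2.7.2 or later.

```scala
// Hypothetical helper (not a Spark NLP API). Assumes model download names look like
// <model>_<lang>_<sparkNlpVersion>_<sparkVersion>_<timestamp>, as in
// "t5_base_en_2.7.1_2.4_1610133506835".
object T5VersionCheck {

  // Splits "2.7.1" into List(2, 7, 1).
  def parseVersion(v: String): List[Int] =
    v.split("\\.").toList.map(_.toInt)

  // Compares versions component by component, padding missing components with 0.
  def compareVersions(a: String, b: String): Int = {
    val (x, y) = (parseVersion(a), parseVersion(b))
    val len = math.max(x.length, y.length)
    x.padTo(len, 0).zip(y.padTo(len, 0))
      .collectFirst { case (p, q) if p != q => Integer.compare(p, q) }
      .getOrElse(0)
  }

  // Captures the Spark NLP version and the Spark version segments of the name.
  private val NamePattern =
    """.*_([0-9]+\.[0-9]+\.[0-9]+)_([0-9]+\.[0-9]+)_[0-9]+""".r

  // Returns the Spark NLP version encoded in the name, or None if the name
  // does not follow the assumed pattern.
  def exportedWith(modelName: String): Option[String] = modelName match {
    case NamePattern(nlpVersion, _) => Some(nlpVersion)
    case _                          => None
  }

  // True when the model predates the 2.7.2 breaking change but the runtime
  // is 2.7.2 or newer (assumed failure condition).
  def needsResave(modelName: String, runtimeVersion: String): Boolean =
    exportedWith(modelName).exists { built =>
      compareVersions(built, "2.7.2") < 0 && compareVersions(runtimeVersion, "2.7.2") >= 0
    }

  def main(args: Array[String]): Unit = {
    val name = "t5_base_en_2.7.1_2.4_1610133506835"
    println(exportedWith(name))         // Some(2.7.1)
    println(needsResave(name, "2.7.2")) // true
    println(needsResave(name, "2.7.1")) // false
    println(exportedWith("t5-small"))   // None
  }
}
```

In a real job you would pass the runtime version from your Spark NLP installation (for example `SparkNLP.version()`, assuming your release exposes it). Note that the directory in the question was renamed to `t5-small`, so its name carries no version and the helper returns `None`; keeping the original download name makes this kind of check possible.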