无法使用 SparkNLP 预训练的 T5Transformer,执行程序失败并显示错误“图中没有名为 [encoder_input_ids] 的操作”

如何解决无法使用 SparkNLP 预训练的 T5Transformer,执行程序失败并显示错误“图中没有名为 [encoder_input_ids] 的操作”

从 SparkNLP 网站下载 T5-small 模型,并使用此代码(几乎完全来自示例):

    import com.johnsNowlabs.nlp.SparkNLP
    import com.johnsNowlabs.nlp.annotators.seq2seq.T5Transformer
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .config("spark.serializer","org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryoserializer.buffer.max","500M")
      .master("local").getorCreate()
    SparkNLP.start()

    val testData = spark.createDataFrame(Seq(
      (1,"Google has announced the release of a beta version of the popular TensorFlow machine learning library"),(2,"The Paris metro will soon enter the 21st century,ditching single-use paper tickets for rechargeable electronic cards.")
    )).toDF("id","text")

    val documentAssembler = new DocumentAssembler()
      .setInputCol("text")
      .setoutputCol("documents")

    val t5 = T5Transformer.load("/tmp/t5-small")
      .setTask("summarize:")
      .setInputCols(Array("documents"))
      .setoutputCol("summaries")

    new Pipeline().setStages(Array(documentAssembler,t5))
      .fit(testData)
      .transform(testData)
      .select("summaries.result").show(truncate = false)

我从执行程序那里得到这个错误

Caused by: java.lang.IllegalArgumentException: No Operation named [encoder_input_ids] in the Graph
    at org.tensorflow.Session$Runner.operationByName(Session.java:384)
    at org.tensorflow.Session$Runner.parSEOutput(Session.java:398)
    at org.tensorflow.Session$Runner.Feed(Session.java:132)
    at com.johnsNowlabs.ml.tensorflow.TensorflowT5.process(TensorflowT5.scala:76)

最初使用 Spark-2.3.0 运行,但使用 spark-2.4.4 也重现了该问题。其他 SparkNLP 功能运行良好,只有这个 T5 模型失败。磁盘模型:

$ ll /tmp/t5-small
drwxr-xr-x@ 6 XXX  XXX        192 Dec 25 12:36 Metadata
-rw-r--r--@ 1 XXX  XXX     791656 Dec 22 18:32 t5_spp
-rw-r--r--@ 1 XXX  XXX  175686374 Dec 22 18:32 t5_tensorflow

$ cat /tmp/t5-small/Metadata/part-00000 
{"class":"com.johnsNowlabs.nlp.annotators.seq2seq.T5Transformer","timestamp":1608475002145,"sparkVersion":"2.4.4","uid":"T5Transformer_1e0a16435680","paramMap":{},"defaultParamMap":{"task":"","lazyAnnotator":false,"maxOutputLength":200}}

我是 SparkNLP 的新手,所以我不确定这是一个实际问题还是我做错了什么。将不胜感激任何帮助。

解决方法

T5 的离线模型 - t5_base_en_2.7.1_2.4_1610133506835 - 在 SparkNLP 2.7.1 上训练,2.7.2 中有一个 breaking change

通过下载并重新保存新版本解决

# dev:
T5Transformer().pretrained("t5_small").save(...)

# prod:
T5Transformer.load(path)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?