将大型 sparknlp 管道加载到 Apache Spark 批处理作业中花费的时间太长

如何解决将大型 sparknlp 管道加载到 Apache Spark 批处理作业中花费的时间太长

我使用 johnsNowlabs 的 SparkNLP 从我的文本数据中提取嵌入,下面是管道。保存到hdfs后模型大小为1.8g

embeddings = BertSentenceEmbeddings.pretrained("labse","xx") \
      .setInputCols("sentence") \
      .setoutputCol("sentence_embeddings")
nlp_pipeline = Pipeline(stages=[document_assembler,sentence_detector,embeddings])
pipeline_model = nlp_pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

我使用 pipeline_modelHDFS 保存到 pipeline_model.save("hdfs:///<path>")

上面只执行了一次

在另一个脚本中,我正在使用 HDFSpipeline_model = PretrainedPipeline.from_disk("hdfs:///<path>") 加载存储的管道。

上面的代码加载了模型,但是占用的太多了。我在 spark 本地模型(无集群)上对其进行了测试,但我拥有 94g RAM、32 核的高资源。

后来,我在yarn上部署了脚本,其中有 12 个 Executor,每个 Executor 有 3 个内核和 7g ram。我分配了 10g 的驱动程序内存。

脚本再次从 HDFS 加载保存的模型需要太多时间。

When the spark reaches at this point,it takes too much time

当火花到达这一点时(见上面的截图),需要太多时间

我想到了一个方法

预加载

我认为的方法是以某种方式将模型一次预加载到内存中,并且当脚本想要在数据帧上应用转换时,我可以以某种方式调用对预训练管道的引用并在旅途中使用它,而无需执行任何磁盘 I/O。我搜索过,但无处可寻。

请告诉我您对此解决方案的看法以及实现这一目标的最佳方法

YARN 资源

节点名称 计数 RAM(每个) 核心(每个)
主节点 1 38g 8
次节点 1 38 g 8
工作节点 4 24 g 4
总计 6 172g 32

谢谢

解决方法

正如评论中所讨论的,这是一个基于 PyTorch 的解决方案,而不是 SparkNLP。简化代码:

# labse_spark.py

LABSE_MODEL,LABSE_TOKENIZER = None


def transform(spark,df,input_col='text',output_col='output'):
    spark.sparkContext.addFile('hdfs:///path/to/labse_model')
    output_schema = T.StructType(df.schema.fields + [T.StructField(output_col,T.ArrayType(T.FloatType()))])

    rdd = df.rdd.mapPartitions(_map_partitions_func(input_col,output_col))
    res = spark.createDataFrame(data=rdd,schema=output_schema)
    return res


def _map_partitions_func(input_col,output_col):
    def executor_func(rows):
        # load everything to memory (partitions should be small,~1k rows per partition):
        pandas_df = pd.DataFrame([r.asDict() for r in rows])
        global LABSE_MODEL,LABSE_TOKENIZER
        if not (LABSE_TOKENIZER or LABSE_MODEL):  # should happen once per executor core
            LABSE_TOKENIZER = AutoTokenizer.from_pretrained(SparkFiles.get('labse_model'))
            LABSE_MODEL = AutoModel.from_pretrained(SparkFiles.get('labse_model'))
        
        # copied from HF model card:
        encoded_input = LABSE_TOKENIZER(
            pandas_df[input_col].tolist(),padding=True,truncation=True,max_length=64,return_tensors='pt')
        with torch.no_grad():
            model_output = LABSE_MODEL(**encoded_input)
        embeddings = model_output.pooler_output
        embeddings = torch.nn.functional.normalize(embeddings)

        pandas_df[output_col] = pd.Series(embeddings.tolist())
        return pandas_df.to_dict('records')

    return executor_func

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?