微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 Pregel 构建树层次结构

如何解决使用 Pregel 构建树层次结构

我正在尝试使用图形框架和 pregel API 构建一棵树。

我有以下联系:

  src  |  dst  
   0   |   1
   0   |   7
   1   |   2
   1   |   5
   1   |   8
   2   |   4
   7   |   4
   7   |   5
   8   |   5

并且需要以下输出作为数据帧:

sequence
[0,1,2,4]
[0,8,5]
[0,7,5]

实际代码如下:

#DATA CREATION
raw_data = [
  ("0","1"),("0","7"),("1","2"),"5"),"8"),("2","4"),("7",("8","5")]

schema = ["src","dst"]
data = spark.createDataFrame(data=raw_data,schema = schema)

from graphframes import GraphFrame

vertices=(data.select("src").union(data.select("dst")).distinct().withColumnRenamed('src','id'))
vertices=vertices.union(spark.createDataFrame(["10"],"string").toDF("id"))
edges=data
graph = GraphFrame(vertices,edges)

import pyspark.sql.functions as F

indegrees=graph.indegrees
outdegrees=graph.outdegrees

init_vertices=(vertices
  .join(outdegrees,on="id",how="left")
  .join(indegrees,how="left")
.withColumn("nodeType",F.when(F.col("inDegree").isNull(),"root").otherwise(F.when(F.col("outDegree").isNull(),"leaf").otherwise("child"))))

gx = GraphFrame(init_vertices,edges)
# pregel API
vertColSchema = T.ArrayType( T.ArrayType( T.StringType(),True),True)

def sendMsgToDst(src,dst,dst_id):
  if src:
    src_tuple = [tuple(lst) for lst in src]
    dst_tuple = [tuple(lst) for lst in dst]
    if not set(src_tuple).issubset(set(dst_tuple)):
      return [i + [dst_id] for i in src] 
  return None

def vertexProgram(vd,msg):
  if msg:
    return msg
  return vd

sendMsgToDstUdf = F.udf(
    sendMsgToDst,vertColSchema
)

vertexProgramUdf = F.udf(
    vertexProgram,vertColSchema
)

start = ["0"]
tree=(gx.pregel
    .withVertexColumn("sequence",F.when(F.col("id").isin(start),F.array(F.array(F.col("id")))).otherwise(F.lit(F.array())),updateAfteraggMsgsExpr=vertexProgramUdf(
        F.col("sequence"),pregel.msg()
      )
    )
    .sendMsgToDst(
      sendMsgToDstUdf(
        pregel.src("sequence"),pregel.dst("sequence"),pregel.dst("id")
      )
    )
    .aggMsgs(F.collect_list(pregel.msg()))
    .setMaxIter(10)
    .setCheckpointInterval(2)
    .run()
)
# RESULT
df = cycles.withColumn("sequence",F.explode("sequence"))
result = df.filter(F.col("nodeType")=="leaf").select("sequence")

sequence
["[[0,1],5]"]
["[[0,7],5]"]
["[[[0,8],2],4]"]
["[[0,4]"]

我不知道为什么,但是 dst_id 附加没有按预期工作。我尝试了许多不同的方法,但在处理列表列表时仍然出现相同的错误

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。