微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

大型DataFrame的高效嵌入计算

如何解决大型DataFrame的高效嵌入计算

给出DataFrame

    id  articleno           target
0   1   [607303]            607295
1   1   [607295]            607303
2   2   [243404,617953]    590448
3   2   [590448,617953]    243404

对于每一行,通过在字典的列表中查找每个项目来计算平均文章嵌入:

embeddings = {"607303": np.array([0.19,0.25,0.45]),"607295": np.array([0.77,0.76,0.55]),"243404": np.array([0.35,0.44,0.32]),"617953": np.array([0.23,0.78,0.24]),"590448": np.array([0.67,0.12,0.10])}

因此,例如,为了澄清,对于第三行(索引2),243404617953文章嵌入分别为[0.35,0.32][0.23,0.24]。平均文章嵌入是将所有元素逐元素相加,然后除以文章数,即([0.35,0.32]+[0.23,0.24])/2=[0.29,0.61,0.28]

预期输出

    id  dim1     dim2     dim3      target
0   1   0.19     0.25     0.45      607295
1   1   0.77     0.76     0.55      607303
2   2   0.29     0.61     0.28      590448
3   2   0.45     0.45     0.17      243404

实际上,我的DataFrame具有数百万行,并且articleno中的列表可以包含更多项目。因此,在行上进行迭代可能太慢,因此可能需要更高效解决方案(也许是矢量化的)。

此外,尺寸(嵌入尺寸)的数量是已知的,但是只有几百个,因此列数也是如此。 dim1dim2dim3... dimN应该动态,具体取决于嵌入的尺寸({{1 }}。

解决方法

在上一个问题中,您付出了更多努力来分离articleno列表中的元素,然后从target列表中删除articleno。现在,如果您要访问articleno列表中的元素,则需要再加倍努力以分隔它们。

为了说明我的意思,这是一种从两个问题生成两个输出,同时添加最少的额外代码的方法:

# construct the embeddings dataframe:
embedding_df = pd.DataFrame(embeddings).T.add_prefix('dim')

# aggregation dictionary
agg_dict = {'countrycode':'first','articleno':list}

# taking mean over embedddings
for i in embedding_df.columns: agg_dict[i] = 'mean'

new_df = df.explode('articleno')

(new_df.join(new_df['articleno'].rename('target'))
    .query('articleno != target')
    .merge(embedding_df,left_on='articleno',right_index=True)  # this line is extra from the previous question
    .groupby(['id','target'],as_index=False)
    .agg(agg_dict)
)

输出:

   id  target countrycode         articleno  dim0  dim1  dim2
0   2  590448          US  [617953,617953]  0.23  0.78  0.24
1   2  617953          US  [590448,590448]  0.67  0.12  0.10

现在,如果您不在乎最终输出中的articleno列,您甚至可以在降低内存/运行时间的同时简化代码,如下所示:

total_embeddings = g[embedding_df.columns].sum()
article_counts = g['id'].transform('size')

new_df[embedding_df.columns] = (total_embeddings.sub(new_df[embedding_df.columns])
                                  .div(article_counts-1,axis=0)
                               )

您将获得相同的输出。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?