微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Python:优化函数以找到给定候选项集的大小为k的频繁项集 算法 0

如何解决Python:优化函数以找到给定候选项集的大小为k的频繁项集 算法 0

我编写了一个函数来查找给定候选项集的大小为 k 的项集的频率。数据集包含超过 16000 笔交易。有人可以帮助我优化此功能,因为当前表单在 minSupport=1 时执行大约需要 45 分钟。

样本数据集

Dataset

解决方法

算法 0(参见下面的其他算法)

使用 Numba 实现了算法的提升。 Numba 是一个 JIT 编译器,可将 Python 代码转换为高度优化的 C++ 代码,然后编译为机器代码。对于许多算法,Numba 实现了 50-200 倍的速度提升。

要使用 numba,您必须通过 pip install numba 安装它,注意 Numba 仅支持 Python

为了满足 Numba 编译要求,我稍微重写了您的代码,我的代码在行为上应该与您的相同,请做一些测试。

我的 numba 优化代码应该可以给你很好的加速!

我也创建了一些人工的简短示例输入数据,以进行测试。

Try it online!

import numba,numpy as np,pandas as pd

@numba.njit(cache = True)
def selectLkNm(dataSet,Ck,minSupport):
    dict_data = {}
    transactions = dataSet.shape[0]
    for items in Ck:
        count = 0
        while count < transactions:
            if items not in dict_data:
                dict_data[items] = 0
            for item in items:
                for e in dataSet[count,:]:
                    if item == e:
                        break
                else:
                    break
            else:
                dict_data[items] += 1
            count += 1
    Lk = {}
    for k,v in dict_data.items():
        if v >= minSupport:
            Lk[k] = v
    return Lk
    
def selectLk(dataSet,minSupport):
    tCk = numba.typed.List()
    for e in Ck:
        tCk.append(e)
    return selectLkNm(dataSet.values,tCk,minSupport)

dataset = pd.DataFrame([[100,160,100,160],[170,180,190,200],[100,200]])
C1 = set()
C1.add((100,160))
C1.add((170,180))
C1.add((190,200))
Lk = selectLk(dataset,C1,2)
print(Lk)

输出:

{(100,160): 2,(190,200): 2}

算法 1(请参阅下面的其他算法)

我通过对您的数据进行排序来改进算法 0(上面),如果您的 Ck 中有很多值或者 Ck 中的每个元组都很长,它会提供很好的加速。

Try it online!

import numba,minSupport):
    assert dataSet.ndim == 2
    dataSet2 = np.empty_like(dataSet)
    for i in range(dataSet.shape[0]):
        dataSet2[i] = np.sort(dataSet[i])
    dataSet = dataSet2
    dict_data = {}
    transactions = dataSet.shape[0]
    for items in Ck:
        count = 0
        while count < transactions:
            if items not in dict_data:
                dict_data[items] = 0
            for item in items:
                ix = np.searchsorted(dataSet[count,:],item)
                if not (ix < dataSet.shape[1] and dataSet[count,ix] == item):
                    break
            else:
                dict_data[items] += 1
            count += 1
    Lk = {}
    for k,200): 2}

算法 2(请参阅下面的其他算法)

如果您不被允许使用 Numba,那么我建议您对算法进行下一步改进。我对您的数据集进行了预排序,以便不是在 O(N) 时间内而是在 O(Log(N)) 时间内搜索每个项目,这要快得多。

我在你的代码中看到你使用了pandas数据框,这意味着你已经安装了pandas,如果你安装了pandas那么你肯定有Numpy,所以我决定使用它。如果您要处理 Pandas 数据框,就不能没有 Numpy。

Try it online!

import numpy as np,pandas as pd,collections

def selectLk(dataSet,minSupport):
    dataSet = np.sort(dataSet.values,axis = 1)
    dict_data = collections.defaultdict(int)
    transactions = dataSet.shape[0]
    for items in Ck:
        count = 0
        while count < transactions:
            for item in items:
                ix = np.searchsorted(dataSet[count,ix] == item):
                    break
            else:
                dict_data[items] += 1
            count += 1
    Lk = {k : v for k,v in dict_data.items() if v >= minSupport}
    return Lk
    
dataset = pd.DataFrame([[100,200): 2}

算法 3

我只是有一个想法,算法 2 的排序部分可能不是瓶颈,可能事务 while 循环可能是瓶颈。

所以为了改善情况,我决定实现并使用更快的算法和 2D searchsorted 版本(没有内置的 2D 版本,所以它必须单独实现),它没有任何长的纯 python 循环,大部分时间都花在了 Numpy 函数上。

请试试这个 Algo 3 是否会更快,如果排序不是瓶颈而是内部 while 循环,它应该会更快。

Try it online!

import numpy as np,minSupport):
    def searchsorted2d(a,bs):
        s = np.r_[0,(np.maximum(a.max(1) - a.min(1) + 1,bs.ravel().max(0)) + 1).cumsum()[:-1]]
        a_scaled = (a + s[:,None]).ravel()
        def sub(b):
            b_scaled = b + s
            return np.searchsorted(a_scaled,b_scaled) - np.arange(len(s)) * a.shape[1]
        return sub

    assert dataSet.values.ndim == 2,dataSet.values.ndim
    dataSet = np.sort(dataSet.values,axis = 1)
    dict_data = collections.defaultdict(int)
    transactions = dataSet.shape[0]
    Ck = np.array(list(Ck))
    assert Ck.ndim == 2,Ck.ndim
    ss = searchsorted2d(dataSet,Ck)
    for items in Ck:
        cnts = np.zeros((dataSet.shape[0],),dtype = np.int64)
        for item in items:
            bs = item.repeat(dataSet.shape[0])
            ixs = np.minimum(ss(bs),dataSet.shape[1] - 1)
            cnts[...] += (dataSet[(np.arange(dataSet.shape[0]),ixs)] == bs).astype(np.uint8)
        dict_data[tuple(items)] += int((cnts == len(items)).sum())
    return {k : v for k,v in dict_data.items() if v >= minSupport}
    
dataset = pd.DataFrame([[100,200): 2}
,

我已经更改了您代码的执行顺序。但是,由于我无法访问您的实际输入数据,因此很难检查优化后的代码是否产生了预期的输出以及您获得了多少速度。

算法 0

import pandas as pd
import numpy as np
from collections import defaultdict

def selectLk(dataSet,minSupport):
    dict_data = defaultdict(int)
    for _,row in dataSet.iterrows():
        for items in Ck:
            dict_data[items] += all(item in row.values for item in items)
    Lk = { k : v for k,v in dict_data.items() if v > minSupport}
    return Lk

if __name__ == '__main__':
    data = list(range(0,1000,10))
    df_data = {}
    for i in range(26):
        sample = np.random.choice(data,size=16000,replace=True)
        df_data[f"d{i}"] = sample
    dataset = pd.DataFrame(df_data)
    C1 = set()
    C1.add((100,160))
    C1.add((170,180))
    C1.add((190,200))
    Lk1 = selectLk(dataset,1)
    dataset = pd.DataFrame([[100,200]])
    Lk2 = selectLk(dataset,1)
    print(Lk1)
    print(Lk2)

算法 1

算法 1 使用 numpy.equal.outer,它为 Ck 元组中的任何匹配元素创建一个布尔掩码。然后,应用 .all() 操作。

def selectLk(dataSet,minSupport):
    dict_data = defaultdict(int)
    dataSet_np = dataSet.to_numpy(copy=False)
    for items in Ck:
        dict_data[items] = dataSet[np.equal.outer(dataSet_np,items).any(axis=1).all(axis=1)].shape[0]
    Lk = { k : v for k,v in dict_data.items() if v > minSupport}
    return Lk

结果:

{(190,200): 811,(170,180): 797,(100,160): 798}
{(190,200): 2,160): 2}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?