微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

AVX512向量乘法速度

如何解决AVX512向量乘法速度

我具有这样的功能

#define SPLIT(zmm,ymmA,ymmB) \
ymmA = _mm512_castsi512_si256(zmm); \
ymmB = _mm512_extracti32x8_epi32(zmm,1)

#define PAIR_AND_BLEND(src1,src2,dst1,dst2) \
dst1 = _mm256_blend_epi32(src1,0b11110000); \
dst2 = _mm256_permute2x128_si256(src1,0b00100001);

#define OPERATE_ROW_2(i,ymmB)        \
zmm##i = _mm512_maddubs_epi16(zmm30,zmm##i);     \
zmm##i = _mm512_madd_epi16(zmm##i,zmm31);        \
SPLIT(zmm##i,ymmB);

/*
 * multiply query to each code in codes.
 * @param n: number of codes
 * @param query: 64 x uint8_t array data
 * @param codes: 64 x n x uint8_t array data
 * @param output: n x int32_t array data,to store output data.
 */
void avx_IP_distance_64_2(size_t n,const uint8_t *query,const uint8_t *codes,int32_t *output){
    __m512i zmm0,zmm1,zmm2,zmm3,zmm4,zmm5,zmm6,zmm7,zmm8,zmm9,zmm10,zmm11,zmm12,zmm13,zmm14,zmm15,zmm16,zmm17,zmm18,zmm19,zmm20,zmm21,zmm22,zmm23,zmm24,zmm25,zmm26,zmm27,zmm28,zmm29,zmm30,zmm31;

    __m256i ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6,ymm7,ymm8,ymm9,ymm10,ymm11,ymm12,ymm13,ymm14,ymm15;

    zmm30 = _mm512_loadu_si512(query);
    zmm31 = _mm512_set1_epi16(1);

    int k_8 = n / 8;
    int left = n % 8;
    for (int i = 0; i < k_8; ++i){
        zmm0 = _mm512_loadu_si512(codes);
        zmm1 = _mm512_loadu_si512(codes + 64 * 1);
        zmm2 = _mm512_loadu_si512(codes + 64 * 2);
        zmm3 = _mm512_loadu_si512(codes + 64 * 3);
        zmm4 = _mm512_loadu_si512(codes + 64 * 4);
        zmm5 = _mm512_loadu_si512(codes + 64 * 5);
        zmm6 = _mm512_loadu_si512(codes + 64 * 6);
        zmm7 = _mm512_loadu_si512(codes + 64 * 7);

        OPERATE_ROW_2(0,ymm0,ymm1);
        OPERATE_ROW_2(1,ymm3);
        OPERATE_ROW_2(2,ymm5);
        OPERATE_ROW_2(3,ymm7);
        OPERATE_ROW_2(4,ymm9);
        OPERATE_ROW_2(5,ymm11);
        OPERATE_ROW_2(6,ymm13);
        OPERATE_ROW_2(7,ymm15);

        ymm0 = _mm256_add_epi32(ymm0,ymm1);
        ymm2 = _mm256_add_epi32(ymm2,ymm3);
        ymm4 = _mm256_add_epi32(ymm4,ymm5);
        ymm6 = _mm256_add_epi32(ymm6,ymm7);
        ymm8 = _mm256_add_epi32(ymm8,ymm9);
        ymm10 = _mm256_add_epi32(ymm10,ymm11);
        ymm12 = _mm256_add_epi32(ymm12,ymm13);
        ymm14 = _mm256_add_epi32(ymm14,ymm15);

        PAIR_AND_BLEND(ymm0,ymm9);
        PAIR_AND_BLEND(ymm2,ymm11);
        PAIR_AND_BLEND(ymm4,ymm13);
        PAIR_AND_BLEND(ymm6,ymm15);

        ymm1 = _mm256_add_epi32(ymm1,ymm9);
        ymm3 = _mm256_add_epi32(ymm3,ymm11);
        ymm5 = _mm256_add_epi32(ymm5,ymm13);
        ymm7 = _mm256_add_epi32(ymm7,ymm15);

        ymm1 = _mm256_hadd_epi32(ymm1,ymm3);
        ymm5 = _mm256_hadd_epi32(ymm5,ymm7);

        ymm1 = _mm256_hadd_epi32(ymm1,ymm5);
        _mm256_storeu_si256((__m256i *)(output),ymm1);

        codes += 8 * 64;
        output += 8;
    }

    for (int i = 0; i < left; ++i){
        OPERATE_ROW_1(0);
    }
}


#define LOOP 10

int main(){
    int d = 64; 
    int q = 1;
    int n = 100000;

    std::mt19937 rng;
    std::uniform_real_distribution<> distrib;

    uint8_t *codes = new uint8_t[d * n]; 
    uint8_t *query = new uint8_t[d * q]; 

    int32_t *output = new int32_t[n];

    for (int i = 0; i < n; ++i){
        for (int j = 0; j < d; ++j){
            // codes[d*i+j] = j;
            codes[d * i + j] = int(distrib(rng)) * 127;
        }   
    }   

    for (int i = 0; i < q; ++i){
        for (int j = 0; j < d; ++j){
            // query[d*i+j] = j;
            query[d * i + j] = int(distrib(rng)) * 127 - 64; 
        }   
    }

    Timer timer;
    timer.start();
    for (int i = 0; i < LOOP; ++i){
        avx_IP_distance_64_2(n,query,codes,output);
    }
    timer.end("Second type");
    return 0;
}

当n = 10k时,持续时间为:0.143917 ms

当n = 100k时,持续时间为:3.2002毫秒

当N小于10k时,时间消耗基本上呈线性增加

我怀疑这是一个缓存问题,但我不确定。

我想知道为什么时间消耗不随n线性增加

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。