如何解决AVX512向量乘法速度
我具有这样的功能:
#define SPLIT(zmm,ymmA,ymmB) \
ymmA = _mm512_castsi512_si256(zmm); \
ymmB = _mm512_extracti32x8_epi32(zmm,1)
#define PAIR_AND_BLEND(src1,src2,dst1,dst2) \
dst1 = _mm256_blend_epi32(src1,0b11110000); \
dst2 = _mm256_permute2x128_si256(src1,0b00100001);
#define OPERATE_ROW_2(i,ymmB) \
zmm##i = _mm512_maddubs_epi16(zmm30,zmm##i); \
zmm##i = _mm512_madd_epi16(zmm##i,zmm31); \
SPLIT(zmm##i,ymmB);
/*
* multiply query to each code in codes.
* @param n: number of codes
* @param query: 64 x uint8_t array data
* @param codes: 64 x n x uint8_t array data
* @param output: n x int32_t array data,to store output data.
*/
void avx_IP_distance_64_2(size_t n,const uint8_t *query,const uint8_t *codes,int32_t *output){
__m512i zmm0,zmm1,zmm2,zmm3,zmm4,zmm5,zmm6,zmm7,zmm8,zmm9,zmm10,zmm11,zmm12,zmm13,zmm14,zmm15,zmm16,zmm17,zmm18,zmm19,zmm20,zmm21,zmm22,zmm23,zmm24,zmm25,zmm26,zmm27,zmm28,zmm29,zmm30,zmm31;
__m256i ymm0,ymm1,ymm2,ymm3,ymm4,ymm5,ymm6,ymm7,ymm8,ymm9,ymm10,ymm11,ymm12,ymm13,ymm14,ymm15;
zmm30 = _mm512_loadu_si512(query);
zmm31 = _mm512_set1_epi16(1);
int k_8 = n / 8;
int left = n % 8;
for (int i = 0; i < k_8; ++i){
zmm0 = _mm512_loadu_si512(codes);
zmm1 = _mm512_loadu_si512(codes + 64 * 1);
zmm2 = _mm512_loadu_si512(codes + 64 * 2);
zmm3 = _mm512_loadu_si512(codes + 64 * 3);
zmm4 = _mm512_loadu_si512(codes + 64 * 4);
zmm5 = _mm512_loadu_si512(codes + 64 * 5);
zmm6 = _mm512_loadu_si512(codes + 64 * 6);
zmm7 = _mm512_loadu_si512(codes + 64 * 7);
OPERATE_ROW_2(0,ymm0,ymm1);
OPERATE_ROW_2(1,ymm3);
OPERATE_ROW_2(2,ymm5);
OPERATE_ROW_2(3,ymm7);
OPERATE_ROW_2(4,ymm9);
OPERATE_ROW_2(5,ymm11);
OPERATE_ROW_2(6,ymm13);
OPERATE_ROW_2(7,ymm15);
ymm0 = _mm256_add_epi32(ymm0,ymm1);
ymm2 = _mm256_add_epi32(ymm2,ymm3);
ymm4 = _mm256_add_epi32(ymm4,ymm5);
ymm6 = _mm256_add_epi32(ymm6,ymm7);
ymm8 = _mm256_add_epi32(ymm8,ymm9);
ymm10 = _mm256_add_epi32(ymm10,ymm11);
ymm12 = _mm256_add_epi32(ymm12,ymm13);
ymm14 = _mm256_add_epi32(ymm14,ymm15);
PAIR_AND_BLEND(ymm0,ymm9);
PAIR_AND_BLEND(ymm2,ymm11);
PAIR_AND_BLEND(ymm4,ymm13);
PAIR_AND_BLEND(ymm6,ymm15);
ymm1 = _mm256_add_epi32(ymm1,ymm9);
ymm3 = _mm256_add_epi32(ymm3,ymm11);
ymm5 = _mm256_add_epi32(ymm5,ymm13);
ymm7 = _mm256_add_epi32(ymm7,ymm15);
ymm1 = _mm256_hadd_epi32(ymm1,ymm3);
ymm5 = _mm256_hadd_epi32(ymm5,ymm7);
ymm1 = _mm256_hadd_epi32(ymm1,ymm5);
_mm256_storeu_si256((__m256i *)(output),ymm1);
codes += 8 * 64;
output += 8;
}
for (int i = 0; i < left; ++i){
OPERATE_ROW_1(0);
}
}
#define LOOP 10
int main(){
int d = 64;
int q = 1;
int n = 100000;
std::mt19937 rng;
std::uniform_real_distribution<> distrib;
uint8_t *codes = new uint8_t[d * n];
uint8_t *query = new uint8_t[d * q];
int32_t *output = new int32_t[n];
for (int i = 0; i < n; ++i){
for (int j = 0; j < d; ++j){
// codes[d*i+j] = j;
codes[d * i + j] = int(distrib(rng)) * 127;
}
}
for (int i = 0; i < q; ++i){
for (int j = 0; j < d; ++j){
// query[d*i+j] = j;
query[d * i + j] = int(distrib(rng)) * 127 - 64;
}
}
Timer timer;
timer.start();
for (int i = 0; i < LOOP; ++i){
avx_IP_distance_64_2(n,query,codes,output);
}
timer.end("Second type");
return 0;
}
当n = 10k时,持续时间为:0.143917 ms
当n = 100k时,持续时间为:3.2002毫秒
当N小于10k时,时间消耗基本上呈线性增加。
我怀疑这是一个缓存问题,但我不确定。
我想知道为什么时间消耗不随n线性增加?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。