微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

仅当元素非零时如何进行 AVX-512 整数递增

如何解决仅当元素非零时如何进行 AVX-512 整数递增

当且仅当元素的值不为零时,我必须向 AVX 寄存器的元素添加一个值。下面是我拥有的代码,但似乎我必须遇到很多额外的麻烦,并且应该有更好的方法来做到这一点。注释掉的循环是我想要做的纯 C++ 表达式。

#include <immintrin.h>
#include <array>
#include <vector>
#include <iostream>
#include <bitset>
#include <cmath>
#include <chrono>

uint32_t make_bit_mask(bool T0 = 0,bool T1 = 0,bool T2 = 0,bool T3 = 0,bool T4 = 0,bool T5 = 0,bool T6 = 0,bool T7 = 0,bool T8 = 0,bool T9 = 0,bool T10 = 0,bool T11 = 0,bool T12 = 0,bool T13 = 0,bool T14 = 0,bool T15 = 0,bool T16 = 0,bool T17 = 0,bool T18 = 0,bool T19 = 0,bool T20 = 0,bool T21 = 0,bool T22 = 0,bool T23 = 0,bool T24 = 0,bool T25 = 0,bool T26 = 0,bool T27 = 0,bool T28 = 0,bool T29 = 0,bool T30 = 0,bool T31 = 0)
{
    return  ((T0 << 0) | (T1 << 1) | (T2 << 2) | (T3 << 3) |
        (T4 << 4) | (T5 << 5) | (T6 << 6) | (T7 << 7) |
        (T8 << 8) | (T9 << 9) | (T10 << 10) | (T11 << 11) |
        (T12 << 12) | (T13 << 13) | (T14 << 14) | (T15 << 15) |
        (T16 << 16) | (T17 << 17) | (T18 << 18) | (T19 << 19) |
        (T20 << 20) | (T21 << 21) | (T22 << 22) | (T23 << 23) |
        (T24 << 24) | (T25 << 25) | (T26 << 26) | (T27 << 27) |
        (T28 << 28) | (T29 << 29) | (T30 << 30) | (T31 << 31));
}

int main()
{
    std::vector<uint16_t> testValues{0};
    testValues.resize(65'536);
    for (size_t i{ 0 }; i < testValues.size(); i += 4)
    {
        testValues[i] = static_cast<uint16_t>(i);
    }
    auto start{ std::chrono::high_resolution_clock::Now() };
    auto oneRegister{ _mm512_set1_epi16(1) };
    for (size_t i{ 0 }; i < testValues.size(); i += 32)
    {
        uint32_t loadMask{ make_bit_mask(testValues[i],testValues[i + 1],testValues[i + 2],testValues[i + 3],testValues[i + 4],testValues[i + 5],testValues[i + 6],testValues[i + 7],testValues[i + 8],testValues[i + 9],testValues[i + 10],testValues[i + 11],testValues[i + 12],testValues[i + 13],testValues[i + 14],testValues[i + 15],testValues[i + 16],testValues[i + 17],testValues[i + 18],testValues[i + 19],testValues[i + 20],testValues[i + 21],testValues[i + 22],testValues[i + 23],testValues[i + 24],testValues[i + 25],testValues[i + 26],testValues[i + 27],testValues[i + 28],testValues[i + 29],testValues[i + 30],testValues[i + 31])
        };
        _mm512_storeu_epi16(&testValues[i],_mm512_mask_add_epi16(_mm512_loadu_epi16(&testValues[i]),(__mmask32)loadMask,_mm512_loadu_epi16(&testValues[i]),oneRegister ));
    }
    /*
    for (auto& iter : testValues)
    {
        if (iter)
            iter += 1;
    }
    */
    auto end{ std::chrono::high_resolution_clock::Now() };
    std::cout << "Summation took: " << std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count() << std::endl;
    return 0;
}

解决方法

所以你想增加向量中的每个非零元素? (您实际上并不是在一个向量中对元素求和,只是进行垂直加法)。

听起来您真正的问题是根据非零元素将整数数组转换为掩码。 AVX-512 有这方面的说明,例如比较与 0,或更具体地说 vptestmw k1,zmm,zmm 以根据每个非零元素创建掩码。 (v & v 只是 v,因此您可以将相同的操作数传递两次以绕过 AND 运算)。

  __m512i v = _mm512_loadu_si512(&testValues[i]);
  __mmask32 nonzeros = _mm512_test_epi16_mask(v,v);
  v = _mm512_mask_sub_epi16(v,nonzeros,v,_mm512_set1_epi16(-1));  // set1(-1) is cheaper than 1

或者对于其他值,

 v = _mm512_mask_add_epi16(v,_mm512_set1_epi16( increment ));

在函数中,gcc -O2 -march=skylake-avx512 compiles it like this

foo(unsigned short const*):
        vmovdqu64       zmm0,ZMMWORD PTR [rdi]
        vpternlogd      zmm1,zmm1,0xFF       # set1(-1)
        vptestmw        k1,zmm0,zmm0
        vpsubw  zmm0{k1},zmm1
        ret

set1(-1) 将被编译器提升到循环之外。


有趣的事实:clang 会为您将 add(v,set1(1)) 转换为 sub(v,set1(-1)),但 GCC 错过了该优化。


如果没有 AVX-512,(只有 AVX2 或 SSE2),您可以创建带有比较的 0-1 向量。不幸的是,在 AVX-512 之前,我们只有 cmpeqcmpgt,没有 cmpne,所以我们需要反转 0 / -1

  __m256i v = _mm256_loadu_si256((const __m256i*)&testValues[i]);
  __m256i nonzeros = _mm256_cmpeq_epi16(v,_mm256_setzero_si256());
  nonzeros = _mm256_xor_si256(nonzeros,_mm256_set1_epi32(-1)); 
          // or add 1 to turn 0->1 and -1->0 to use add()
  v = _mm256_sub_epi16(v,nonzeros);

对于任意常量,您可以 _mm256_andnot_si256(nonzeros,_mm256_set1_epi16( increment )) 创建一个向量,该向量对于 v=0 元素为 0,对于非零元素为 increment

,

是的,有一个更简单的方法。简单地说,在 c++ 17 或更高版本中有二进制文字。如果你可以使用它?

尝试这样的事情:

int main()
{
    uint32_t value = 0b1010'0100'0001'1100'0100'0011'1010'0010;
    //do whatever with value...
}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。