微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

AVX-512 浮点比较和屏蔽

如何解决AVX-512 浮点比较和屏蔽

我对 SIMD 不太熟悉,但我之前用 AVX 写了一些非常简单的东西。现在我也想用 AVX-512 实现一些旧的 AVX 代码

我打算做什么:

// SIZE,LOW_THRESHOLD,HIGH_THRESHOLD and array are defined
// the code works with float data

for ( int index = 0; index < SIZE; ++index )
{
    array[ index ] = LOW_THRESHOLD < array[ index ] && array[ index ] < HIGH_THRESHOLD ? 1.0f : 0.0f;
}

我用 AVX 做了什么:

const __m256 lowThreshold  = _mm256_set1_ps( LOW_THRESHOLD );
const __m256 highThreshold = _mm256_set1_ps( HIGH_THRESHOLD );
const __m256 trueValue     = _mm256_set1_ps( 1.0f );
const __m256 falseValue    = _mm256_set1_ps( 0.0f );

for ( int index = 0; index < SIZE; index += 8 )
{
    // aligned load
    const __m256 val = _mm256_load_ps( array + index );
    // compare
    const __m256 comp1 = _mm256_cmp_ps( lowThreshold,val,_CMP_LT_OQ );
    const __m256 comp2 = _mm256_cmp_ps( val,highThreshold,_CMP_LT_OQ );
    // AND
    const __m256 mask = _mm256_and_ps( comp1,comp2 );
    // blend
    const __m256 result = _mm256_blendv_ps( falseValue,trueValue,mask );
    // aligned store
    _mm256_store_ps( array + index,result );
}

现在我被困在 AVX-512。

const __m512 lowThreshold  = _mm512_set1_ps( LOW_THRESHOLD );
const __m512 highThreshold = _mm512_set1_ps( HIGH_THRESHOLD );
const __m512 trueValue     = _mm512_set1_ps( 1.0f );
const __m512 falseValue    = _mm512_set1_ps( 0.0f );

for ( int index = 0; index < SIZE; index += 16 )
{
    // aligned load
    const __m512 val = _mm512_load_ps( array + index );

    // the result of the comparison goes into a mask?
    const __mmask16 comp1 = _mm512_cmplt_ps_mask( lowThreshold,val );
    const __mmask16 comp2 = _mm512_cmplt_ps_mask( val,highThreshold );

    // how to use these masks?
}

使用__m512 _mm512_and_ps (__m512 a,__m512 b)会很好,但比较后只有__mask16变量,我没有找到任何_mm512函数,例如_mm256_cmp_ps。对于更有经验的 AVX 用户来说,这可能是一个简单的问题。 谢谢!

解决方法

如果您查看 __mmask16 的类型定义,您会看到:typedef unsigned short __mmask16;。因此,您可以将这种类型视为 uint16_t 并使用“&”。 然后您可以使用 __m512 _mm512_mask_blend_ps (__mmask16 k,__m512 a,__m512 b)

这对我来说非常有效:

#include <immintrin.h>
#include <x86intrin.h>
#include <stdio.h>

#define LOW_THRESHOLD 1
#define HIGH_THRESHOLD 3

int main() {
    __m512 lowThreshold  = _mm512_set1_ps( LOW_THRESHOLD );
    __m512 highThreshold = _mm512_set1_ps( HIGH_THRESHOLD );
    __m512 trueValue     = _mm512_set1_ps( 1.0f );
    __m512 falseValue    = _mm512_set1_ps( 0.0f );

    float array[16] = {-5.0f,6.0f,4.0f,1.5f,0.7f,1.0f,-5.0f,6.0f};

    for (int i = 0; i < 16; i += 1) {
        printf("%5.1f ",array[i]);
    }
    printf("\n");

    for ( int index = 0; index < 16; index += 16 )
    {
        __m512 val = _mm512_loadu_ps( array + index );
        __mmask16 comp1 = _mm512_cmplt_ps_mask( lowThreshold,val );
        __mmask16 comp2 = _mm512_cmplt_ps_mask( val,highThreshold );
        __mmask16 mask = comp1 & comp2;
        __m512 result = _mm512_mask_blend_ps(mask,falseValue,trueValue);
        _mm512_storeu_ps( array + index,result );
    }
    for (int i = 0; i < 16; i += 1) {
        printf("%5.1f ",array[i]);
    }
    printf("\n");
}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。