如何解决使用 AVX 提高浮点减法、除法、截断为 int32 的性能
尝试使用 AVX 来提高以下性能
__declspec(dllexport) void __cdecl calculate_quantized_vertical_values(long length,float min,float step,float* source,unsigned long* destination)
{
for (long i = 0; i < length; i++)
{
destination[i] = (source[i] - min) / step;
}
}
替换为
__declspec(dllexport) void __cdecl calculate_quantized_vertical_values_avx(long length,unsigned long* destination)
{
long multiple8end = ((long)(length / 8)) * 8;
__m256 min256 = _mm256_broadcast_ss((const float*)&min);
__m256 step256 = _mm256_broadcast_ss((const float*)&step);
for (long i = 0; i < multiple8end; i+=8)
{
__m256 value256 = _mm256_load_ps((const float*)(source + i));
__m256 offset256 = _mm256_sub_ps(value256,min256);
__m256 floatres256 = _mm256_div_ps(offset256,step256);
__m256i long256 = _mm256_cvttps_epi32(floatres256);
_mm256_store_si256((__m256i*)(destination + i),long256);
}
for (long i = multiple8end; i < length; i ++)
{
destination[i] = (source[i] - min) / step;
}
}
原始循环需要大约 330 毫秒,我的 55M 元素源数组和循环内容编译为
loc_180001050:
movss xmm0,dword ptr [r10+rcx-4]
subss xmm0,xmm3
divss xmm0,xmm2
cvttss2si rax,xmm0
mov [rcx-4],eax
movss xmm1,dword ptr [r10+rcx]
subss xmm1,xmm3
divss xmm1,xmm1
mov [rcx],eax
movss xmm0,dword ptr [r10+rcx+4]
subss xmm0,xmm0
mov [rcx+4],dword ptr [r10+rcx+8]
subss xmm1,xmm1
mov [rcx+8],eax
add rcx,10h
sub r8,1
jnz short loc_180001050
AVX 循环在相同的 55M 元素源数组上花费大约 170 毫秒,并且(主)循环的内容编译为:
loc_180001160:
vmovups ymm0,ymmword ptr [r8+rdx]
lea rdx,[rdx+20h]
vsubps ymm1,ymm0,ymm6
vdivps ymm2,ymm1,ymm7
vcvttps2dq ymm3,ymm2
vmovdqu ymmword ptr [rdx-20h],ymm3
sub rax,1
jnz short loc_180001160
所以 AVX 有性能改进,但我想知道是否有可能获得更显着的性能改进,或者这大约是此特定计算的限制
编辑:我还应该提到,如果有任何不同,我将从 .NET 应用程序调用这些 DLL 函数。
编辑:理想情况下,我希望 unsigned char
使用 destination
数组,但现在坚持使用 int32
,因为我还没有找到一种方法来执行 { {1}} -> float
与 AVX 的转换
如果能提高性能,那么乘以 unsigned char
而不是除以 1.f/step
对我来说应该没问题
解决方法
如果您按 1/step
进行缩放而不是除以 step
,您应该会明显更快,除非您受到内存吞吐量的限制。如果您计算出 min
的减法,您还可以使用 FMA 指令(如果可用):
void calculate_quantized_vertical_values_avx(size_t length,float min,float step,float* source,uint32_t* destination)
{
size_t multiple8end = ((length / 8)) * 8;
const float scale = 1.f/step;
const float offset = -min * scale;
const __m256 scale256 = _mm256_set1_ps(scale);
const __m256 offset256 = _mm256_set1_ps(offset);
for (size_t i = 0; i < multiple8end; i+=8)
{
__m256 value256 = _mm256_load_ps((const float*)(source + i));
#ifdef __FMA__
__m256 floatres256 = _mm256_fmadd_ps(value256,scale256,offset256);
#else
__m256 floatres256 = _mm256_add_ps(_mm256_mul_ps(value256,scale256),offset256);
#endif
__m256i long256 = _mm256_cvttps_epi32(floatres256);
_mm256_store_si256((__m256i*)(destination + i),long256);
}
for (size_t i = multiple8end; i < length; i ++)
{
destination[i] = (source[i] * scale) + offset;
}
}
如果您想将结果转换为 uint8
,请查看 _mm256_packus_epi32
和 _mm256_packus_epi16
(或者 _mm_packus_epi32
和 _mm_packus_epi16
,如果您不想有 AVX2)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。