diff --git a/benchmarks/matrix_multiplication.cpp b/benchmarks/matrix_multiplication.cpp index 3da7a6ea074db9eafe34bd65d40d57bfc8f5e976..393b0b389e678386eeeb074f360884e4a88a083e 100644 --- a/benchmarks/matrix_multiplication.cpp +++ b/benchmarks/matrix_multiplication.cpp @@ -56,4 +56,12 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationAvx2) matrixMultiplyAVX2(A, B, C); } } + +// Using direct avx2b implementation +BENCHMARK_F(InitializeInput, MatrixMultiplicationAvx2b) +(benchmark::State& state) { + for (auto _ : state) { + matrixMultiplyAVX2b(A, B, C); + } +} BENCHMARK_MAIN(); diff --git a/code/matrix_multiplication.h b/code/matrix_multiplication.h index acca62471c9c1608af3e7613fca5ed2db6236175..1fa41dc72d5eb7b24bc4e0fdb28fb43e40963d93 100644 --- a/code/matrix_multiplication.h +++ b/code/matrix_multiplication.h @@ -109,4 +109,30 @@ void matrixMultiplyAVX2(const std::complex<float>* a, __m256 c_p2 = _mm256_fmaddsub_ps(a_2, b_2, _mm256_mul_ps(a_4, b_4)); __m256 c_m = _mm256_add_ps(c_p1, c_p2); _mm256_store_ps(c_ptr, c_m); +} + + +void matrixMultiplyAVX2b(const std::complex<float>* a, + const std::complex<float>* b, std::complex<float>* c) { + float* a_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(a)); + float* b_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(b)); + float* c_ptr = reinterpret_cast<float*>(c); + __m256 a_m = _mm256_load_ps(a_ptr); + __m256 b_m = _mm256_load_ps(b_ptr); + + __m256 a_1 = _mm256_permute_ps(a_m, _MM_SHUFFLE(0, 0, 0, 0)); + __m256 a_2 = _mm256_permute_ps(a_m, _MM_SHUFFLE(2, 2, 2, 2)); + __m256 a_3 = _mm256_permute_ps(a_m, _MM_SHUFFLE(1, 1, 1, 1)); + __m256 a_4 = _mm256_permute_ps(a_m, _MM_SHUFFLE(3, 3, 3, 3)); + + __m256 b_1 = _mm256_permutevar8x32_ps(b_m, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0)); + __m256 b_2 = _mm256_permutevar8x32_ps(b_m, _mm256_set_epi32(7, 6, 5, 4, 7, 6, 5, 4)); + __m256 b_3 = _mm256_permute_ps(b_1, _MM_SHUFFLE(2, 3, 0, 1)); + __m256 b_4 = _mm256_permute_ps(b_2, _MM_SHUFFLE(2, 3, 0, 1)); + + __m256 c_m = _mm256_mul_ps(a_1, b_1); + c_m = _mm256_fmadd_ps(a_2, b_2, c_m); + c_m = _mm256_addsub_ps(c_m, _mm256_mul_ps(a_3, b_3)); + c_m = _mm256_addsub_ps(c_m, _mm256_mul_ps(a_4, b_4)); + _mm256_store_ps(c_ptr, c_m); } \ No newline at end of file diff --git a/test/test_matrix_multiplication.cpp b/test/test_matrix_multiplication.cpp index 44da47e53fc1cd9caf6895bd06dc11b8b38f2010..f2dd60e6f87ccf87abc4f17c3e3f7c8609256f33 100644 --- a/test/test_matrix_multiplication.cpp +++ b/test/test_matrix_multiplication.cpp @@ -34,4 +34,10 @@ TEST_CASE("test complex matrix multiplication", "[float]") { COMPARE_ARRAYS(C_expected, C, 1.e-6); } + + SECTION("test correctness of avx2b implementation") { + matrixMultiplyAVX2b(A.data(), B.data(), C.data()); + + COMPARE_ARRAYS(C_expected, C, 1.e-6); + } } \ No newline at end of file