Skip to content
Snippets Groups Projects
Commit 815ae398 authored by Mattia Mancini's avatar Mattia Mancini
Browse files

Merge branch 'second-avx2-implementation' into 'main'

Add matrixMultiplyAVX2b

See merge request !4
parents a418699e e1a55db4
No related branches found
No related tags found
1 merge request!4Add matrixMultiplyAVX2b
Pipeline #81804 passed
......@@ -56,4 +56,12 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationAvx2)
matrixMultiplyAVX2(A, B, C);
}
}
// Using direct avx2b implementation
BENCHMARK_F(InitializeInput, MatrixMultiplicationAvx2b)
(benchmark::State& state) {
for (auto _ : state) {
matrixMultiplyAVX2b(A, B, C);
}
}
BENCHMARK_MAIN();
......@@ -110,3 +110,29 @@ void matrixMultiplyAVX2(const std::complex<float>* a,
__m256 c_m = _mm256_add_ps(c_p1, c_p2);
_mm256_store_ps(c_ptr, c_m);
}
void matrixMultiplyAVX2b(const std::complex<float>* a,
const std::complex<float>* b, std::complex<float>* c) {
float* a_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(a));
float* b_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(b));
float* c_ptr = reinterpret_cast<float*>(c);
__m256 a_m = _mm256_load_ps(a_ptr);
__m256 b_m = _mm256_load_ps(b_ptr);
__m256 a_1 = _mm256_permute_ps(a_m, _MM_SHUFFLE(0, 0, 0, 0));
__m256 a_2 = _mm256_permute_ps(a_m, _MM_SHUFFLE(2, 2, 2, 2));
__m256 a_3 = _mm256_permute_ps(a_m, _MM_SHUFFLE(1, 1, 1, 1));
__m256 a_4 = _mm256_permute_ps(a_m, _MM_SHUFFLE(3, 3, 3, 3));
__m256 b_1 = _mm256_permutevar8x32_ps(b_m, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0));
__m256 b_2 = _mm256_permutevar8x32_ps(b_m, _mm256_set_epi32(7, 6, 5, 4, 7, 6, 5, 4));
__m256 b_3 = _mm256_permute_ps(b_1, _MM_SHUFFLE(2, 3, 0, 1));
__m256 b_4 = _mm256_permute_ps(b_2, _MM_SHUFFLE(2, 3, 0, 1));
__m256 c_m = _mm256_mul_ps(a_1, b_1);
c_m = _mm256_fmadd_ps(a_2, b_2, c_m);
c_m = _mm256_addsub_ps(c_m, _mm256_mul_ps(a_3, b_3));
c_m = _mm256_addsub_ps(c_m, _mm256_mul_ps(a_4, b_4));
_mm256_store_ps(c_ptr, c_m);
}
\ No newline at end of file
......@@ -34,4 +34,10 @@ TEST_CASE("test complex matrix multiplication", "[float]") {
COMPARE_ARRAYS(C_expected, C, 1.e-6);
}
SECTION("test correctness of avx2b implementation") {
matrixMultiplyAVX2b(A.data(), B.data(), C.data());
COMPARE_ARRAYS(C_expected, C, 1.e-6);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment