Skip to content
Snippets Groups Projects
Commit 8f4c79f2 authored by Mattia Mancini's avatar Mattia Mancini
Browse files

Add testing of AVX2 and implementation

parent 025ad86c
No related branches found
No related tags found
1 merge request!3Add test for matrix
Pipeline #81724 passed
...@@ -48,4 +48,12 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationAOAvx) ...@@ -48,4 +48,12 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationAOAvx)
matrixMultiplyAoCommon(A, B, C); matrixMultiplyAoCommon(A, B, C);
} }
} }
// Using direct avx2 implementation
BENCHMARK_F(InitializeInput, MatrixMultiplicationAvx2)
(benchmark::State& state) {
for (auto _ : state) {
matrixMultiplyAVX2(A, B, C);
}
}
BENCHMARK_MAIN(); BENCHMARK_MAIN();
#include <aocommon/avx/MatrixComplexFloat2x2.h> #include <aocommon/avx/MatrixComplexFloat2x2.h>
#include <complex> #include <complex>
#include <iomanip>
#include <iostream>
#include <new> // For std::align_val_t #include <new> // For std::align_val_t
#include <random> #include <random>
...@@ -69,3 +71,45 @@ void matrixMultiplyAoCommon(const std::complex<float>* a, ...@@ -69,3 +71,45 @@ void matrixMultiplyAoCommon(const std::complex<float>* a,
c[2] = C[2]; c[2] = C[2];
c[3] = C[3]; c[3] = C[3];
} }
void matrixMultiplyAVX2(const std::complex<float>* a,
const std::complex<float>* b, std::complex<float>* c) {
float* a_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(a));
float* b_ptr = reinterpret_cast<float*>(const_cast<std::complex<float>*>(b));
float* c_ptr = reinterpret_cast<float*>(c);
__m256 a_m = _mm256_load_ps(a_ptr);
__m256 b_m = _mm256_load_ps(b_ptr);
__m256i a_1_ind = _mm256_set_epi32(4, 4, 4, 4, 0, 0, 0, 0);
__m256i b_1_ind = _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0);
__m256i a_2_ind = _mm256_set_epi32(6, 6, 6, 6, 2, 2, 2, 2);
__m256i b_2_ind = _mm256_set_epi32(7, 6, 5, 4, 7, 6, 5, 4);
__m256i a_3_ind = _mm256_set_epi32(5, 5, 5, 5, 1, 1, 1, 1);
__m256i b_3_ind = _mm256_set_epi32(2, 3, 0, 1, 2, 3, 0, 1);
__m256i a_4_ind = _mm256_set_epi32(7, 7, 7, 7, 3, 3, 3, 3);
__m256i b_4_ind = _mm256_set_epi32(6, 7, 4, 5, 6, 7, 4, 5);
__m256 inv = _mm256_set_ps(1., -1., 1., -1., 1., -1., 1., -1.);
__m256 a_1 = _mm256_permutevar8x32_ps(a_m, a_1_ind);
__m256 b_1 = _mm256_permutevar8x32_ps(b_m, b_1_ind);
__m256 a_2 = _mm256_permutevar8x32_ps(a_m, a_2_ind);
__m256 b_2 = _mm256_permutevar8x32_ps(b_m, b_2_ind);
__m256 a_3 = _mm256_permutevar8x32_ps(a_m, a_3_ind);
__m256 a_3_inv = _mm256_mul_ps(a_3, inv);
__m256 b_3 = _mm256_permutevar8x32_ps(b_m, b_3_ind);
__m256 a_4 = _mm256_permutevar8x32_ps(a_m, a_4_ind);
__m256 a_4_inv = _mm256_mul_ps(a_4, inv);
__m256 b_4 = _mm256_permutevar8x32_ps(b_m, b_4_ind);
__m256 c_m = _mm256_mul_ps(a_1, b_1);
c_m = _mm256_fmadd_ps(a_2, b_2, c_m);
c_m = _mm256_fmadd_ps(a_3_inv, b_3, c_m);
c_m = _mm256_fmadd_ps(a_4_inv, b_4, c_m);
_mm256_store_ps(c_ptr, c_m);
}
\ No newline at end of file
#ifndef HELPERS
#define HELPERS
#include <array> #include <array>
#include <catch2/matchers/catch_matchers_floating_point.hpp> #include <catch2/matchers/catch_matchers_floating_point.hpp>
...@@ -27,6 +30,16 @@ void compareSingle(const std::vector<std::complex<float>>& lv, ...@@ -27,6 +30,16 @@ void compareSingle(const std::vector<std::complex<float>>& lv,
} }
} }
template <typename T>
std::string valueToString(const T& value) {
return std::to_string(value);
}
std::string valueToString(const std::complex<float>& value) {
return std::to_string(value.real()) + ", " + std::to_string(value.imag()) +
"j";
}
template <typename T, size_t N> template <typename T, size_t N>
void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs, void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs,
std::array<T, N> rhs, float precision) { std::array<T, N> rhs, float precision) {
...@@ -34,5 +47,20 @@ void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs, ...@@ -34,5 +47,20 @@ void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs,
std::vector<T> rv(rhs.begin(), rhs.end()); std::vector<T> rv(rhs.begin(), rhs.end());
INFO("Test case [" << test << "] failed at line " INFO("Test case [" << test << "] failed at line "
<< line); // Reported only if REQUIRE fails << line); // Reported only if REQUIRE fails
std::stringstream ss;
ss << "Expected : \n";
for (size_t idx = 0; idx < N; idx++) {
ss << valueToString(lhs[idx]) << "\t";
}
ss << "\nObtained : \n";
for (size_t idx = 0; idx < N; idx++) {
ss << valueToString(rhs[idx]) << "\t";
}
ss << "\n";
INFO("Reason: \n" << ss.str());
compareSingle(lv, rv, precision); compareSingle(lv, rv, precision);
} }
#endif
\ No newline at end of file
...@@ -28,4 +28,10 @@ TEST_CASE("test complex matrix multiplication", "[float]") { ...@@ -28,4 +28,10 @@ TEST_CASE("test complex matrix multiplication", "[float]") {
COMPARE_ARRAYS(C_expected, C, 1.e-6); COMPARE_ARRAYS(C_expected, C, 1.e-6);
} }
SECTION("test correctness of avx2 implementation") {
matrixMultiplyAVX2(A.data(), B.data(), C.data());
COMPARE_ARRAYS(C_expected, C, 1.e-6);
}
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment