Skip to content
Snippets Groups Projects
Commit 79c7e3f3 authored by Mattia Mancini's avatar Mattia Mancini
Browse files

Merge branch 'add_sse_implementation' into 'main'

Add SSE implementation

See merge request !12
parents fd4ee762 ba41dd2f
No related branches found
No related tags found
1 merge request!12Add SSE implementation
Pipeline #97912 passed
...@@ -100,4 +100,14 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationNEONb) ...@@ -100,4 +100,14 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationNEONb)
} }
#endif #endif
#if defined(__SSE__)
// Using the direct SSE implementation (matrixMultiplySSE), not NEON.
// NOTE(review): A, B and C are presumably members of the InitializeInput
// fixture - confirm against the fixture definition.
BENCHMARK_F(InitializeInput, MatrixMultiplicationSSE)
(benchmark::State& state) {
for (auto _ : state) {
matrixMultiplySSE(A, B, C);
}
}
#endif
BENCHMARK_MAIN(); BENCHMARK_MAIN();
...@@ -14,17 +14,9 @@ inline void Initialize(std::complex<float>* a) { ...@@ -14,17 +14,9 @@ inline void Initialize(std::complex<float>* a) {
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(-1.0, 1.0); std::uniform_real_distribution<float> dis(-1.0, 1.0);
a[0] = std::complex<float>(-0.129469, 0.407216);
a[1] = std::complex<float>(0.237004, 0.379525);
a[2] = std::complex<float>(-0.506045, -0.692210);
a[3] = std::complex<float>(0.259458, 0.150587);
/*
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
a[i] = std::complex<float>(dis(gen), dis(gen)); a[i] = std::complex<float>(dis(gen), dis(gen));
} }
*/
} }
// Function to perform matrix multiplication for 2x2 complex matrices // Function to perform matrix multiplication for 2x2 complex matrices
...@@ -55,4 +47,7 @@ void matrixMultiplyNEONa(const std::complex<float>* a, ...@@ -55,4 +47,7 @@ void matrixMultiplyNEONa(const std::complex<float>* a,
void matrixMultiplyNEONb(const std::complex<float>* a, void matrixMultiplyNEONb(const std::complex<float>* a,
const std::complex<float>* b, std::complex<float>* c); const std::complex<float>* b, std::complex<float>* c);
void matrixMultiplySSE(const std::complex<float>* a,
const std::complex<float>* b, std::complex<float>* c);
#endif #endif
...@@ -7,8 +7,8 @@ void matrixMultiplyAoCommon(const std::complex<float>* a, ...@@ -7,8 +7,8 @@ void matrixMultiplyAoCommon(const std::complex<float>* a,
const aocommon::MC2x2F B(b[0], b[1], b[2], b[3]); const aocommon::MC2x2F B(b[0], b[1], b[2], b[3]);
const aocommon::MC2x2F C = A * B; const aocommon::MC2x2F C = A * B;
c[0] = C[0]; c[0] = C.Get(0);
c[1] = C[1]; c[1] = C.Get(1);
c[2] = C[2]; c[2] = C.Get(2);
c[3] = C[3]; c[3] = C.Get(3);
} }
#ifdef __SSE__
#include <immintrin.h>
#include "matrix_multiplication.h"
/// Real-valued 2x2 matrix stored row-major: {m00, m01, m10, m11}.
struct Matrix2x2 {
  float m[4];
};

/// Returns the real 2x2 product a * b, with each operand held row-major in a
/// single SSE register.
///
/// Only SSE1 shuffles are used so the code builds whenever __SSE__ is
/// defined; the previous _mm_moveldup_ps/_mm_movehdup_ps calls are SSE3
/// intrinsics and fail to compile on targets that only enable SSE/SSE2.
static inline __m128 sse_2x2_product(__m128 ma, __m128 mb) {
  // Broadcast the elements of A so each lane holds the factor it needs:
  //   a_col0 = [a00, a00, a10, a10]   a_col1 = [a01, a01, a11, a11]
  const __m128 a_col0 = _mm_shuffle_ps(ma, ma, _MM_SHUFFLE(2, 2, 0, 0));
  const __m128 a_col1 = _mm_shuffle_ps(ma, ma, _MM_SHUFFLE(3, 3, 1, 1));
  // Repeat the rows of B:
  //   b_row0 = [b00, b01, b00, b01]   b_row1 = [b10, b11, b10, b11]
  const __m128 b_row0 = _mm_shuffle_ps(mb, mb, _MM_SHUFFLE(1, 0, 1, 0));
  const __m128 b_row1 = _mm_shuffle_ps(mb, mb, _MM_SHUFFLE(3, 2, 3, 2));
  // Lane-wise (vertical) add of the two partial products gives A * B.
  return _mm_add_ps(_mm_mul_ps(a_col0, b_row0), _mm_mul_ps(a_col1, b_row1));
}

/// Accumulates a real 2x2 product into c: c += a * b.
///
/// @param a Left operand (row-major).
/// @param b Right operand (row-major).
/// @param c Accumulator, read and updated in place.
void sse_2x2_multiply_add_v1(const Matrix2x2 a, const Matrix2x2 b,
                             Matrix2x2& c) {
  const __m128 product = sse_2x2_product(_mm_loadu_ps(a.m), _mm_loadu_ps(b.m));
  _mm_storeu_ps(c.m, _mm_add_ps(_mm_loadu_ps(c.m), product));
}

/// Subtracts a real 2x2 product from c: c -= a * b.
///
/// @param a Left operand (row-major).
/// @param b Right operand (row-major).
/// @param c Accumulator, read and updated in place.
void sse_2x2_multiply_sub_v1(const Matrix2x2 a, const Matrix2x2 b,
                             Matrix2x2& c) {
  const __m128 product = sse_2x2_product(_mm_loadu_ps(a.m), _mm_loadu_ps(b.m));
  _mm_storeu_ps(c.m, _mm_sub_ps(_mm_loadu_ps(c.m), product));
}

/// Multiplies two 2x2 complex matrices: c = a * b.
///
/// The inputs are split into real and imaginary planes and combined as
///   Re(c) = Re(a) Re(b) - Im(a) Im(b)
///   Im(c) = Re(a) Im(b) + Im(a) Re(b)
/// using four real 2x2 products.
///
/// @param a Left operand, 4 elements, row-major.
/// @param b Right operand, 4 elements, row-major.
/// @param c Output, 4 elements, row-major. All of a and b are copied into
///          locals before c is written, so c may alias a or b.
void matrixMultiplySSE(const std::complex<float>* a,
                       const std::complex<float>* b, std::complex<float>* c) {
  const Matrix2x2 a_real{a[0].real(), a[1].real(), a[2].real(), a[3].real()};
  const Matrix2x2 b_real{b[0].real(), b[1].real(), b[2].real(), b[3].real()};
  const Matrix2x2 a_imag{a[0].imag(), a[1].imag(), a[2].imag(), a[3].imag()};
  const Matrix2x2 b_imag{b[0].imag(), b[1].imag(), b[2].imag(), b[3].imag()};

  Matrix2x2 c_real{0.f, 0.f, 0.f, 0.f};
  Matrix2x2 c_imag{0.f, 0.f, 0.f, 0.f};

  sse_2x2_multiply_add_v1(a_real, b_real, c_real);
  sse_2x2_multiply_sub_v1(a_imag, b_imag, c_real);
  sse_2x2_multiply_add_v1(a_real, b_imag, c_imag);
  sse_2x2_multiply_add_v1(a_imag, b_real, c_imag);

  for (int i = 0; i < 4; ++i) {
    c[i] = std::complex<float>(c_real.m[i], c_imag.m[i]);
  }
}
#endif
\ No newline at end of file
...@@ -29,6 +29,14 @@ TEST_CASE("test complex matrix multiplication", "[float]") { ...@@ -29,6 +29,14 @@ TEST_CASE("test complex matrix multiplication", "[float]") {
COMPARE_ARRAYS(C_expected, C, 1.e-6); COMPARE_ARRAYS(C_expected, C, 1.e-6);
} }
#if defined(__SSE__)
SECTION("test correctness of SSE implementation") {
matrixMultiplySSE(A.data(), B.data(), C.data());
COMPARE_ARRAYS(C_expected, C, 1.e-6);
}
#endif
#if defined(__AVX__) #if defined(__AVX__)
SECTION("test correctness of avx implementation") { SECTION("test correctness of avx implementation") {
matrixMultiplyAVX(A.data(), B.data(), C.data()); matrixMultiplyAVX(A.data(), B.data(), C.data());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment