From c486faba64b7147c44b3e2befe77fcaa525b4ee2 Mon Sep 17 00:00:00 2001
From: Mattia Mancini <mancini@astron.nl>
Date: Thu, 6 Feb 2025 14:04:26 +0000
Subject: [PATCH] Test hermitian implementation

---
 .gitlab-ci.yml                          |   6 +-
 CMakeLists.txt                          |   8 +-
 benchmarks/hermitian_square.cpp         |  47 +++++++++
 benchmarks/kronecker_square.cpp         |  45 ++++++++
 benchmarks/main.cpp                     |   3 +
 benchmarks/matrix_multiplication.cpp    |   2 -
 ci/summarize-results.py                 |   2 +-
 cmake/precompile.cmake                  |   3 +-
 code/hermitian_naive_simd.cpp           | 114 ++++++++++++++++++++
 code/hermitian_naive_simd_v2.cpp        | 134 ++++++++++++++++++++++++
 code/hermitian_square.h                 |  26 +++++
 code/hermitian_square_reference.cpp     |  75 +++++++++++++
 code/kronecker_square.h                 |  27 +++++
 code/kronecker_square_fused.cpp         |  40 +++++++
 code/kroneker_square_reference.cpp      |   7 ++
 code/kroneker_square_reference_simd.cpp |   0
 code/matrix_multiplication.h            |   4 +-
 test/helpers.cpp                        |  98 +++++++++++++++++
 test/helpers.h                          |  59 ++++-------
 test/test_hermitian_square.cpp          |  28 +++++
 test/test_kronecker_square.cpp          |  26 +++++
 21 files changed, 705 insertions(+), 49 deletions(-)
 create mode 100644 benchmarks/hermitian_square.cpp
 create mode 100644 benchmarks/kronecker_square.cpp
 create mode 100644 benchmarks/main.cpp
 create mode 100644 code/hermitian_naive_simd.cpp
 create mode 100644 code/hermitian_naive_simd_v2.cpp
 create mode 100644 code/hermitian_square.h
 create mode 100644 code/hermitian_square_reference.cpp
 create mode 100644 code/kronecker_square.h
 create mode 100644 code/kronecker_square_fused.cpp
 create mode 100644 code/kroneker_square_reference.cpp
 create mode 100644 code/kroneker_square_reference_simd.cpp
 create mode 100644 test/helpers.cpp
 create mode 100644 test/test_hermitian_square.cpp
 create mode 100644 test/test_kronecker_square.cpp

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 42f5bb2..d11a266 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -84,6 +84,7 @@ performance-jetson:   # This job runs in the test stage.
     - sbatch --wait -p jetsonq -w orin02 -o output.txt -e error.txt ci/das6/compile_and_run_jetson.sh jetson_arm64 11.4.0
     - cat output.txt >&1
     - cat error.txt >&2
+  allow_failure: true
     
   artifacts:
     paths:
@@ -127,4 +128,7 @@ collect-performance:
       - ./*.tar
   script:
   - ls -la 
-  - python3 ci/summarize-results.py --filter MatrixMultiplication results*.json result-summary
+  - python3 ci/summarize-results.py --filter MatrixMultiplication results*.json result-summary-matrix-multiplication
+  - python3 ci/summarize-results.py --filter HermitianSquare results*.json result-summary-hermitian-square
+  - python3 ci/summarize-results.py --filter KroneckerSquare results*.json result-summary-kronecker-square
+  
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 20257d1..73d5509 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,7 @@ FetchContent_Declare(
 FetchContent_Populate(aocommon)
 
 set(COMPILER_FLAGS "-O3;-march=native;-ggdb;")
+
 # List all kernel code
 file(GLOB KERNEL_SOURCES "code/*.cpp")
 # Add the benchmark executable
@@ -59,12 +60,17 @@ add_executable(microbenchmarks ${BENCHMARK_SOURCES} ${KERNEL_SOURCES})
 file(GLOB TEST_SOURCES "test/*.cpp")
 add_executable(unittests ${TEST_SOURCES} ${KERNEL_SOURCES})
 
+find_package(OpenMP)
 
 # Link against Google Benchmark
-target_link_libraries(microbenchmarks benchmark::benchmark)
+target_link_libraries(microbenchmarks PRIVATE benchmark::benchmark)
 target_include_directories(microbenchmarks PRIVATE ${aocommon_SOURCE_DIR}/include)
 target_include_directories(microbenchmarks PRIVATE code)
 target_compile_options(microbenchmarks PUBLIC ${COMPILER_FLAGS})
+if(OpenMP_CXX_FOUND)
+    target_link_libraries(microbenchmarks PRIVATE OpenMP::OpenMP_CXX)
+    target_link_libraries(unittests PRIVATE OpenMP::OpenMP_CXX)
+endif()
 
 include(precompile)
 list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
diff --git a/benchmarks/hermitian_square.cpp b/benchmarks/hermitian_square.cpp
new file mode 100644
index 0000000..ffffa74
--- /dev/null
+++ b/benchmarks/hermitian_square.cpp
@@ -0,0 +1,61 @@
+#include <benchmark/benchmark.h>
+#include <hermitian_square.h>
+
+#include <iostream>
+#include <memory>
+
+namespace {
+
+// Fixture that allocates and fills the 4x4 input matrix before each benchmark
+// and releases it afterwards.
+class InitializeInput : public benchmark::Fixture {
+ public:
+  void SetUp(::benchmark::State& state) {
+    A = std::make_unique<aocommon::Matrix4x4>();
+    // Must be the free function Initialize() from hermitian_square.h. Inside
+    // this class `InitializeInput(*A);` parses as the declaration of a local
+    // pointer named A, so the matrix would never be filled.
+    Initialize(*A);
+  }
+  void TearDown(::benchmark::State& state) { A.reset(); }
+
+  std::unique_ptr<aocommon::Matrix4x4> A;
+};
+}  // namespace
+
+// Reference implementation (aocommon::Matrix4x4::HermitianSquare)
+BENCHMARK_F(InitializeInput, HermitianSquare)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    auto result = HermitianSquare(*A);
+    // Keep the result observable so the call cannot be optimized away.
+    benchmark::DoNotOptimize(result);
+  }
+}
+
+// Hand-expanded scalar implementation
+BENCHMARK_F(InitializeInput, HermitianSquareRefactored)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    auto result = HermitianSquareRefactored(*A);
+    benchmark::DoNotOptimize(result);
+  }
+}
+
+// `#pragma omp simd` implementation
+BENCHMARK_F(InitializeInput, HermitianSquareNaiveSIMD)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    auto result = HermitianSquareNaiveSIMD(*A);
+    benchmark::DoNotOptimize(result);
+  }
+}
+
+// SIMD implementation that sums with the reduce_single() helper
+BENCHMARK_F(InitializeInput, HermitianSquareNaiveSIMDv2)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    auto result = HermitianSquareNaiveSIMDv2(*A);
+    benchmark::DoNotOptimize(result);
+  }
+}
diff --git a/benchmarks/kronecker_square.cpp b/benchmarks/kronecker_square.cpp
new file mode 100644
index 0000000..f7e95a9
--- /dev/null
+++ b/benchmarks/kronecker_square.cpp
@@ -0,0 +1,58 @@
+#include <benchmark/benchmark.h>
+#include <kronecker_square.h>
+
+#include <cstddef>
+#include <iostream>
+#include <memory>
+
+namespace {
+
+// Number of kernel calls executed inside each timed benchmark iteration.
+constexpr std::size_t kCallsPerIteration = 100000000;
+
+// Fixture that allocates and fills both 2x2 input matrices before each
+// benchmark and releases them afterwards.
+class InitializeInput : public benchmark::Fixture {
+ public:
+  void SetUp(::benchmark::State& state) {
+    A = std::make_unique<aocommon::MC2x2>();
+    B = std::make_unique<aocommon::MC2x2>();
+    // Must be the free function Initialize() from kronecker_square.h. Inside
+    // this class `InitializeInput(*A);` parses as the declaration of a local
+    // pointer named A, so the matrices would never be filled.
+    Initialize(*A);
+    Initialize(*B);
+  }
+  void TearDown(::benchmark::State& state) {
+    A.reset();
+    B.reset();
+  }
+
+  std::unique_ptr<aocommon::MC2x2> A, B;
+};
+}  // namespace
+
+// Reference implementation
+BENCHMARK_F(InitializeInput, KroneckerSquareReference)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    aocommon::HMC4x4 r = aocommon::HMC4x4::Zero();
+    for (std::size_t i = 0; i < kCallsPerIteration; i++) {
+      r += KroneckerSquareReference(*A, *B);
+    }
+    // Keep the accumulated sum observable so the loop is not dead code.
+    benchmark::DoNotOptimize(r);
+  }
+}
+
+// Fused implementation
+BENCHMARK_F(InitializeInput, KroneckerSquareFused)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    aocommon::HMC4x4 r = aocommon::HMC4x4::Zero();
+    for (std::size_t i = 0; i < kCallsPerIteration; i++) {
+      r += KroneckerSquareFused(*A, *B);
+    }
+    benchmark::DoNotOptimize(r);
+  }
+}
diff --git a/benchmarks/main.cpp b/benchmarks/main.cpp
new file mode 100644
index 0000000..71fefa0
--- /dev/null
+++ b/benchmarks/main.cpp
@@ -0,0 +1,3 @@
+#include <benchmark/benchmark.h>
+
+BENCHMARK_MAIN();
diff --git a/benchmarks/matrix_multiplication.cpp b/benchmarks/matrix_multiplication.cpp
index fb0f0f7..75a058d 100644
--- a/benchmarks/matrix_multiplication.cpp
+++ b/benchmarks/matrix_multiplication.cpp
@@ -109,5 +109,3 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationSSE)
   }
 }
 #endif
-
-BENCHMARK_MAIN();
diff --git a/ci/summarize-results.py b/ci/summarize-results.py
index ab99375..549852a 100644
--- a/ci/summarize-results.py
+++ b/ci/summarize-results.py
@@ -5,7 +5,7 @@ import seaborn
 import pandas
 import matplotlib.pyplot as plt
 
-seaborn.set_theme(palette="flare", style="whitegrid")
+seaborn.set_theme(palette="bright", style="whitegrid")
 
 def parse_args():
     parser = ArgumentParser(description="Combine benchmark metrics from Google Benchmarks Framework")
diff --git a/cmake/precompile.cmake b/cmake/precompile.cmake
index 097e36d..cfc0ce6 100644
--- a/cmake/precompile.cmake
+++ b/cmake/precompile.cmake
@@ -8,9 +8,10 @@ macro (add_precompile_target)
     list(TRANSFORM apt_INCLUDE_DIRS PREPEND -I)
 
     add_custom_command(OUTPUT ${apt_TARGET_NAME}.s
-                    COMMAND ${CMAKE_CXX_COMPILER} -std=c++${CMAKE_CXX_STANDARD} ${CMAKE_CXX_COMPLERS_ARG} ${apt_COMPILER_FLAGS} ${apt_SOURCES} ${apt_INCLUDE_DIRS} -S -o ${apt_TARGET_NAME}.s
+                    COMMAND ${CMAKE_CXX_COMPILER} -std=c++${CMAKE_CXX_STANDARD} ${CMAKE_CXX_COMPLERS_ARG} ${apt_COMPILER_FLAGS} ${apt_SOURCES} ${apt_INCLUDE_DIRS} -S -fverbose-asm -Wa,aslh -o ${apt_TARGET_NAME}.s
                     DEPENDS ${apt_SOURCES}
                     )
+    
     add_custom_target(precompile_${apt_TARGET_NAME} ALL
                   DEPENDS ${apt_TARGET_NAME}.s
                   COMMENT "Adding precompile target ${KERNEL_NAME} with source ${KERNEL_SOURCE}")
diff --git a/code/hermitian_naive_simd.cpp b/code/hermitian_naive_simd.cpp
new file mode 100644
index 0000000..48d52d1
--- /dev/null
+++ b/code/hermitian_naive_simd.cpp
@@ -0,0 +1,129 @@
+#include "hermitian_square.h"
+
+// Computes the entries on and below the diagonal of mat^H * mat with plain
+// arrays and `#pragma omp simd` loops. Entry (i, j) with i > j is
+// sum_k conj(mat[4k + i]) * mat[4k + j]; entry (j, j) is sum_k |mat[4k + j]|^2.
+// The 16 returned values are listed row by row with 0.0 for every entry above
+// the diagonal.
+aocommon::HMatrix4x4 HermitianSquareNaiveSIMD(const aocommon::Matrix4x4& mat) {
+  // View the matrix as interleaved (real, imag) doubles.
+  // NOTE(review): assumes Matrix4x4 stores its 16 std::complex<double> entries
+  // contiguously - confirm against aocommon.
+  const double* mat_pointer = reinterpret_cast<const double*>(&mat[0]);
+  double tmp_real_col0[4];
+  double tmp_real_col1[4];
+  double tmp_real_col2[4];
+  double tmp_real_col3[4];
+
+  double tmp_imag_col0[4];
+  double tmp_imag_col1[4];
+  double tmp_imag_col2[4];
+  double tmp_imag_col3[4];
+
+  double c1c0_r[4], c2c1_r[4], c2c0_r[4], c3c2_r[4], c3c1_r[4], c3c0_r[4];
+  double c1c0_i[4], c2c1_i[4], c2c0_i[4], c3c2_i[4], c3c1_i[4], c3c0_i[4];
+  double tmp_real_uf[6] = {0, 0, 0, 0, 0, 0};
+  double tmp_imag_uf[6] = {0, 0, 0, 0, 0, 0};
+  double diag0 = 0.0;
+  double diag1 = 0.0;
+  double diag2 = 0.0;
+  double diag3 = 0.0;
+
+  // Split each column into separate real and imaginary arrays.
+#pragma omp simd
+  for (int k = 0; k < 4; k++) {
+    tmp_real_col0[k] = mat_pointer[2 * (4 * k + 0)];
+    tmp_real_col1[k] = mat_pointer[2 * (4 * k + 1)];
+    tmp_real_col2[k] = mat_pointer[2 * (4 * k + 2)];
+    tmp_real_col3[k] = mat_pointer[2 * (4 * k + 3)];
+
+    tmp_imag_col0[k] = mat_pointer[2 * (4 * k + 0) + 1];
+    tmp_imag_col1[k] = mat_pointer[2 * (4 * k + 1) + 1];
+    tmp_imag_col2[k] = mat_pointer[2 * (4 * k + 2) + 1];
+    tmp_imag_col3[k] = mat_pointer[2 * (4 * k + 3) + 1];
+  }
+
+  // Diagonal: squared magnitude of each column. The accumulators are carried
+  // across iterations, so they have to be declared as reductions for the simd
+  // construct to be valid.
+#pragma omp simd reduction(+ : diag0, diag1, diag2, diag3)
+  for (int k = 0; k < 4; k++) {
+    diag0 += ((tmp_real_col0[k] * tmp_real_col0[k]) +
+              (tmp_imag_col0[k] * tmp_imag_col0[k]));
+
+    diag1 += ((tmp_real_col1[k] * tmp_real_col1[k]) +
+              (tmp_imag_col1[k] * tmp_imag_col1[k]));
+
+    diag2 += ((tmp_real_col2[k] * tmp_real_col2[k]) +
+              (tmp_imag_col2[k] * tmp_imag_col2[k]));
+
+    diag3 += ((tmp_real_col3[k] * tmp_real_col3[k]) +
+              (tmp_imag_col3[k] * tmp_imag_col3[k]));
+  }
+
+  // Per-row products conj(col_i) * col_j for every pair i > j.
+#pragma omp simd
+  for (int k = 0; k < 4; k++) {
+    c1c0_r[k] = tmp_real_col1[k] * tmp_real_col0[k] +
+                tmp_imag_col1[k] * tmp_imag_col0[k];
+    c2c0_r[k] = tmp_real_col2[k] * tmp_real_col0[k] +
+                tmp_imag_col2[k] * tmp_imag_col0[k];
+    c2c1_r[k] = tmp_real_col2[k] * tmp_real_col1[k] +
+                tmp_imag_col2[k] * tmp_imag_col1[k];
+    c3c0_r[k] = tmp_real_col3[k] * tmp_real_col0[k] +
+                tmp_imag_col3[k] * tmp_imag_col0[k];
+    c3c1_r[k] = tmp_real_col3[k] * tmp_real_col1[k] +
+                tmp_imag_col3[k] * tmp_imag_col1[k];
+    c3c2_r[k] = tmp_real_col3[k] * tmp_real_col2[k] +
+                tmp_imag_col3[k] * tmp_imag_col2[k];
+
+    c1c0_i[k] = tmp_real_col1[k] * tmp_imag_col0[k] -
+                tmp_imag_col1[k] * tmp_real_col0[k];
+    c2c0_i[k] = tmp_real_col2[k] * tmp_imag_col0[k] -
+                tmp_imag_col2[k] * tmp_real_col0[k];
+    c2c1_i[k] = tmp_real_col2[k] * tmp_imag_col1[k] -
+                tmp_imag_col2[k] * tmp_real_col1[k];
+    c3c0_i[k] = tmp_real_col3[k] * tmp_imag_col0[k] -
+                tmp_imag_col3[k] * tmp_real_col0[k];
+    c3c1_i[k] = tmp_real_col3[k] * tmp_imag_col1[k] -
+                tmp_imag_col3[k] * tmp_real_col1[k];
+    c3c2_i[k] = tmp_real_col3[k] * tmp_imag_col2[k] -
+                tmp_imag_col3[k] * tmp_real_col2[k];
+  }
+
+  // Sum the four per-row products of each pair.
+  for (int k = 0; k < 4; k++) {
+    tmp_real_uf[0] += c1c0_r[k];
+    tmp_real_uf[1] += c2c0_r[k];
+    tmp_real_uf[2] += c2c1_r[k];
+    tmp_real_uf[3] += c3c0_r[k];
+    tmp_real_uf[4] += c3c1_r[k];
+    tmp_real_uf[5] += c3c2_r[k];
+
+    tmp_imag_uf[0] += c1c0_i[k];
+    tmp_imag_uf[1] += c2c0_i[k];
+    tmp_imag_uf[2] += c2c1_i[k];
+    tmp_imag_uf[3] += c3c0_i[k];
+    tmp_imag_uf[4] += c3c1_i[k];
+    tmp_imag_uf[5] += c3c2_i[k];
+  }
+
+  return {
+      diag0,
+      0.0,
+      0.0,
+      0.0,
+      {tmp_real_uf[0], tmp_imag_uf[0]},
+      diag1,
+      0.0,
+      0.0,
+      {tmp_real_uf[1], tmp_imag_uf[1]},
+      {tmp_real_uf[2], tmp_imag_uf[2]},
+      diag2,
+      0.0,
+      {tmp_real_uf[3], tmp_imag_uf[3]},
+      {tmp_real_uf[4], tmp_imag_uf[4]},
+      {tmp_real_uf[5], tmp_imag_uf[5]},
+      diag3,
+  };
+}
diff --git a/code/hermitian_naive_simd_v2.cpp b/code/hermitian_naive_simd_v2.cpp
new file mode 100644
index 0000000..46031dc
--- /dev/null
+++ b/code/hermitian_naive_simd_v2.cpp
@@ -0,0 +1,143 @@
+#include "hermitian_square.h"
+
+// Horizontal sum: returns col[0] + col[1] + col[2] + col[3].
+#if __AVX__
+#include <immintrin.h>
+inline double reduce_single(const double col[4]) {
+  // Unaligned load: the callers pass ordinary stack arrays, which are not
+  // guaranteed to have the 32-byte alignment _mm256_load_pd requires.
+  __m256d vec = _mm256_loadu_pd(col);
+  __m128d low = _mm256_extractf128_pd(vec, 0);
+  __m128d high = _mm256_extractf128_pd(vec, 1);
+  __m128d sum = _mm_add_pd(low, high);
+  sum = _mm_hadd_pd(sum, sum);
+  return _mm_cvtsd_f64(sum);
+}
+#else
+inline double reduce_single(const double col[4]) {
+  // Start from zero: seeding the sum with col[0] and then looping from index
+  // 0 would count the first element twice.
+  double sum = 0.0;
+  for (int i = 0; i < 4; i++) {
+    sum += col[i];
+  }
+  return sum;
+}
+#endif
+
+// Computes the entries on and below the diagonal of mat^H * mat: entry (i, j)
+// with i > j is sum_k conj(mat[4k + i]) * mat[4k + j], entry (j, j) is
+// sum_k |mat[4k + j]|^2. The four per-row products of each off-diagonal entry
+// are summed with reduce_single(). Entries above the diagonal are passed as 0.
+aocommon::HMatrix4x4 HermitianSquareNaiveSIMDv2(
+    const aocommon::Matrix4x4& mat) {
+  // View the matrix as interleaved (real, imag) doubles.
+  // NOTE(review): assumes Matrix4x4 stores its 16 std::complex<double> entries
+  // contiguously - confirm against aocommon.
+  const double* mat_pointer = reinterpret_cast<const double*>(&mat[0]);
+  double tmp_real_col0[4];
+  double tmp_real_col1[4];
+  double tmp_real_col2[4];
+  double tmp_real_col3[4];
+
+  double tmp_imag_col0[4];
+  double tmp_imag_col1[4];
+  double tmp_imag_col2[4];
+  double tmp_imag_col3[4];
+
+  double c1c0_r[4], c2c1_r[4], c2c0_r[4], c3c2_r[4], c3c1_r[4], c3c0_r[4];
+  double c1c0_i[4], c2c1_i[4], c2c0_i[4], c3c2_i[4], c3c1_i[4], c3c0_i[4];
+  double diag0 = 0.0;
+  double diag1 = 0.0;
+  double diag2 = 0.0;
+  double diag3 = 0.0;
+
+  // Split each column into separate real and imaginary arrays.
+#pragma omp simd
+  for (int k = 0; k < 4; k++) {
+    tmp_real_col0[k] = mat_pointer[2 * (4 * k + 0)];
+    tmp_real_col1[k] = mat_pointer[2 * (4 * k + 1)];
+    tmp_real_col2[k] = mat_pointer[2 * (4 * k + 2)];
+    tmp_real_col3[k] = mat_pointer[2 * (4 * k + 3)];
+
+    tmp_imag_col0[k] = mat_pointer[2 * (4 * k + 0) + 1];
+    tmp_imag_col1[k] = mat_pointer[2 * (4 * k + 1) + 1];
+    tmp_imag_col2[k] = mat_pointer[2 * (4 * k + 2) + 1];
+    tmp_imag_col3[k] = mat_pointer[2 * (4 * k + 3) + 1];
+  }
+
+  // Diagonal: squared magnitude of each column. The accumulators are carried
+  // across iterations, so they have to be declared as reductions for the simd
+  // construct to be valid.
+#pragma omp simd reduction(+ : diag0, diag1, diag2, diag3)
+  for (int k = 0; k < 4; k++) {
+    diag0 += ((tmp_real_col0[k] * tmp_real_col0[k]) +
+              (tmp_imag_col0[k] * tmp_imag_col0[k]));
+
+    diag1 += ((tmp_real_col1[k] * tmp_real_col1[k]) +
+              (tmp_imag_col1[k] * tmp_imag_col1[k]));
+
+    diag2 += ((tmp_real_col2[k] * tmp_real_col2[k]) +
+              (tmp_imag_col2[k] * tmp_imag_col2[k]));
+
+    diag3 += ((tmp_real_col3[k] * tmp_real_col3[k]) +
+              (tmp_imag_col3[k] * tmp_imag_col3[k]));
+  }
+
+  // Per-row products conj(col_i) * col_j for every pair i > j.
+#pragma omp simd
+  for (int k = 0; k < 4; k++) {
+    c1c0_r[k] = tmp_real_col1[k] * tmp_real_col0[k] +
+                tmp_imag_col1[k] * tmp_imag_col0[k];
+    c2c0_r[k] = tmp_real_col2[k] * tmp_real_col0[k] +
+                tmp_imag_col2[k] * tmp_imag_col0[k];
+    c2c1_r[k] = tmp_real_col2[k] * tmp_real_col1[k] +
+                tmp_imag_col2[k] * tmp_imag_col1[k];
+    c3c0_r[k] = tmp_real_col3[k] * tmp_real_col0[k] +
+                tmp_imag_col3[k] * tmp_imag_col0[k];
+    c3c1_r[k] = tmp_real_col3[k] * tmp_real_col1[k] +
+                tmp_imag_col3[k] * tmp_imag_col1[k];
+    c3c2_r[k] = tmp_real_col3[k] * tmp_real_col2[k] +
+                tmp_imag_col3[k] * tmp_imag_col2[k];
+
+    c1c0_i[k] = tmp_real_col1[k] * tmp_imag_col0[k] -
+                tmp_imag_col1[k] * tmp_real_col0[k];
+    c2c0_i[k] = tmp_real_col2[k] * tmp_imag_col0[k] -
+                tmp_imag_col2[k] * tmp_real_col0[k];
+    c2c1_i[k] = tmp_real_col2[k] * tmp_imag_col1[k] -
+                tmp_imag_col2[k] * tmp_real_col1[k];
+    c3c0_i[k] = tmp_real_col3[k] * tmp_imag_col0[k] -
+                tmp_imag_col3[k] * tmp_real_col0[k];
+    c3c1_i[k] = tmp_real_col3[k] * tmp_imag_col1[k] -
+                tmp_imag_col3[k] * tmp_real_col1[k];
+    c3c2_i[k] = tmp_real_col3[k] * tmp_imag_col2[k] -
+                tmp_imag_col3[k] * tmp_real_col2[k];
+  }
+
+  // Collapse the four per-row products of each pair into a single sum.
+  const double tmp_real_uf[6] = {
+      reduce_single(c1c0_r), reduce_single(c2c0_r), reduce_single(c2c1_r),
+      reduce_single(c3c0_r), reduce_single(c3c1_r), reduce_single(c3c2_r)};
+  const double tmp_imag_uf[6] = {
+      reduce_single(c1c0_i), reduce_single(c2c0_i), reduce_single(c2c1_i),
+      reduce_single(c3c0_i), reduce_single(c3c1_i), reduce_single(c3c2_i)};
+
+  return {
+      diag0,
+      0.0,
+      0.0,
+      0.0,
+      {tmp_real_uf[0], tmp_imag_uf[0]},
+      diag1,
+      0.0,
+      0.0,
+      {tmp_real_uf[1], tmp_imag_uf[1]},
+      {tmp_real_uf[2], tmp_imag_uf[2]},
+      diag2,
+      0.0,
+      {tmp_real_uf[3], tmp_imag_uf[3]},
+      {tmp_real_uf[4], tmp_imag_uf[4]},
+      {tmp_real_uf[5], tmp_imag_uf[5]},
+      diag3,
+  };
+}
\ No newline at end of file
diff --git a/code/hermitian_square.h b/code/hermitian_square.h
new file mode 100644
index 0000000..48e2cee
--- /dev/null
+++ b/code/hermitian_square.h
@@ -0,0 +1,34 @@
+#ifndef HERMITIAN_SQUARE_H
+#define HERMITIAN_SQUARE_H
+
+#include <aocommon/hmatrix4x4.h>
+#include <aocommon/matrix4x4.h>
+
+#include <complex>
+#include <iostream>
+#include <random>
+
+// Fills the 16 entries of `mat` with pseudo-random complex values whose real
+// and imaginary parts are drawn from [-1, 1). The generator is re-seeded with
+// the same constant on every call, so every call produces the same values.
+inline void Initialize(aocommon::Matrix4x4& mat) {
+  std::seed_seq seed({42});
+  std::mt19937 gen(seed);
+  std::uniform_real_distribution<double> dis(-1.0, 1.0);
+
+  for (int i = 0; i < 16; i++) {
+    // Draw the parts in separate statements: the evaluation order of the two
+    // arguments in `std::complex<double>(dis(gen), dis(gen))` is unspecified,
+    // so which draw becomes the real part could differ between compilers.
+    const double real = dis(gen);
+    const double imag = dis(gen);
+    mat[i] = std::complex<double>(real, imag);
+  }
+}
+
+aocommon::HMatrix4x4 HermitianSquare(const aocommon::Matrix4x4& mat);
+aocommon::HMatrix4x4 HermitianSquareRefactored(const aocommon::Matrix4x4& mat);
+aocommon::HMatrix4x4 HermitianSquareNaiveSIMD(const aocommon::Matrix4x4& mat);
+aocommon::HMatrix4x4 HermitianSquareNaiveSIMDv2(const aocommon::Matrix4x4& mat);
+
+#endif
\ No newline at end of file
diff --git a/code/hermitian_square_reference.cpp b/code/hermitian_square_reference.cpp
new file mode 100644
index 0000000..50f5140
--- /dev/null
+++ b/code/hermitian_square_reference.cpp
@@ -0,0 +1,78 @@
+#include "hermitian_square.h"
+
+// Baseline: forwards to aocommon's own Matrix4x4::HermitianSquare().
+aocommon::HMatrix4x4 HermitianSquare(const aocommon::Matrix4x4& mat) {
+  return mat.HermitianSquare();
+}
+
+template <typename T>
+using cmplx = std::complex<T>;
+
+// Hand-expanded version: writes out every entry on and below the diagonal of
+// mat^H * mat explicitly and passes zeros for the entries above it.
+aocommon::HMatrix4x4 HermitianSquareRefactored(const aocommon::Matrix4x4& mat) {
+  /*
+    Mat is a complex matrix indexed as
+     0  1  2  3
+     4  5  6  7
+     8  9 10 11
+    12 13 14 15
+
+    Real-operation counts used below:
+    norm = r*r + i*i -> 2 mult 1 sum -> #ops 3
+    conj = r + ji -> r - ji -> #ops 0
+    complex multiplication (CM) = (a + ib) * (c + id)
+    -> a*c - b*d + i (a*d + b*c) -> 4 mult 1 sum 1 sub -> #ops 6
+    complex conjugate multiplication (CCM) = (a - ib) * (c + id)
+    -> a*c + b*d + i (a*d - b*c) -> 4 mult 1 sum 1 sub -> #ops 6
+
+    Diagonal element: 4 norms + 3 additions -> #ops 15
+    4 diagonal elements -> #ops 60
+
+    Off-diagonal element: 4 CCM -> #ops 24
+    (the 3 complex additions that combine them, 6 more real additions, are
+    not included in this figure)
+    6 off-diagonal elements -> #ops 144
+
+    144 + 60 -> 204 operations per matrix
+  */
+
+  const cmplx<double> R00 = std::norm(mat[0]) + std::norm(mat[4]) +
+                            std::norm(mat[8]) + std::norm(mat[12]);
+  const cmplx<double> R01 = 0.0;
+  const cmplx<double> R02 = 0.0;
+  const cmplx<double> R03 = 0.0;
+
+  const cmplx<double> R10 =
+      std::conj(mat[1]) * mat[0] + std::conj(mat[5]) * mat[4] +
+      std::conj(mat[9]) * mat[8] + std::conj(mat[13]) * mat[12];
+  const cmplx<double> R11 = std::norm(mat[1]) + std::norm(mat[5]) +
+                            std::norm(mat[9]) + std::norm(mat[13]);
+  const cmplx<double> R12 = 0.0;
+  const cmplx<double> R13 = 0.0;
+
+  const cmplx<double> R20 =
+      std::conj(mat[2]) * mat[0] + std::conj(mat[6]) * mat[4] +
+      std::conj(mat[10]) * mat[8] + std::conj(mat[14]) * mat[12];
+  const cmplx<double> R21 =
+      std::conj(mat[2]) * mat[1] + std::conj(mat[6]) * mat[5] +
+      std::conj(mat[10]) * mat[9] + std::conj(mat[14]) * mat[13];
+  const cmplx<double> R22 = std::norm(mat[2]) + std::norm(mat[6]) +
+                            std::norm(mat[10]) + std::norm(mat[14]);
+  const cmplx<double> R23 = 0.0;
+
+  const cmplx<double> R30 =
+      std::conj(mat[3]) * mat[0] + std::conj(mat[7]) * mat[4] +
+      std::conj(mat[11]) * mat[8] + std::conj(mat[15]) * mat[12];
+  const cmplx<double> R31 =
+      std::conj(mat[3]) * mat[1] + std::conj(mat[7]) * mat[5] +
+      std::conj(mat[11]) * mat[9] + std::conj(mat[15]) * mat[13];
+  const cmplx<double> R32 =
+      std::conj(mat[3]) * mat[2] + std::conj(mat[7]) * mat[6] +
+      std::conj(mat[11]) * mat[10] + std::conj(mat[15]) * mat[14];
+  const cmplx<double> R33 = std::norm(mat[3]) + std::norm(mat[7]) +
+                            std::norm(mat[11]) + std::norm(mat[15]);
+
+  return {R00, R01, R02, R03, R10, R11, R12, R13,
+          R20, R21, R22, R23, R30, R31, R32, R33};
+}
diff --git a/code/kronecker_square.h b/code/kronecker_square.h
new file mode 100644
index 0000000..6bae773
--- /dev/null
+++ b/code/kronecker_square.h
@@ -0,0 +1,35 @@
+#ifndef KRONECKER_SQUARE_H
+#define KRONECKER_SQUARE_H
+
+#include <aocommon/hmatrix4x4.h>
+#include <aocommon/matrix4x4.h>
+
+#include <complex>
+#include <iostream>
+#include <random>
+
+// Fills the four entries of `a` with pseudo-random complex values whose real
+// and imaginary parts are drawn from [-1, 1). The generator is re-seeded with
+// the same constant on every call, so every call produces the same values.
+inline void Initialize(aocommon::MC2x2& a) {
+  std::seed_seq seed({42});
+  std::mt19937 gen(seed);
+  std::uniform_real_distribution<double> dis(-1.0, 1.0);
+
+  for (size_t i = 0; i < 4; i++) {
+    // Draw the parts in separate statements: the evaluation order of the two
+    // arguments in `std::complex<double>(dis(gen), dis(gen))` is unspecified,
+    // so which draw becomes the real part could differ between compilers.
+    const double real = dis(gen);
+    const double imag = dis(gen);
+    a.Set(i, std::complex<double>(real, imag));
+  }
+}
+
+aocommon::HMC4x4 KroneckerSquareReference(aocommon::MC2x2& left,
+                                          aocommon::MC2x2& right);
+
+aocommon::HMC4x4 KroneckerSquareFused(aocommon::MC2x2& left,
+                                      aocommon::MC2x2& right);
+
+#endif  // KRONECKER_SQUARE_H
\ No newline at end of file
diff --git a/code/kronecker_square_fused.cpp b/code/kronecker_square_fused.cpp
new file mode 100644
index 0000000..7acfa1f
--- /dev/null
+++ b/code/kronecker_square_fused.cpp
@@ -0,0 +1,40 @@
+
+#include "kronecker_square.h"
+
+// Computes the Kronecker product of (a^H a)^T and b^H b in a single pass,
+// producing only the entries on and below the diagonal. Code taken from
+// https://gitlab.com/aroffringa/wsclean/-/merge_requests/772
+aocommon::HMC4x4 KroneckerSquareFused(aocommon::MC2x2& a, aocommon::MC2x2& b) {
+  using T = const std::complex<double>;
+  using RT = const double;
+  using std::conj;
+  using std::norm;
+
+  // p = a^H a. It is Hermitian, so the real diagonal (p00, p11) and one
+  // off-diagonal entry (p01, with p10 = conj(p01)) describe it completely.
+  RT p00 = norm(a.Get(0)) + norm(a.Get(2));
+  T p01 = conj(a.Get(0)) * a.Get(1) + conj(a.Get(2)) * a.Get(3);
+  RT p11 = norm(a.Get(1)) + norm(a.Get(3));
+
+  // q = b^H b, stored the same way but keeping q10 instead of q01.
+  RT q00 = norm(b.Get(0)) + norm(b.Get(2));
+  T q10 = conj(b.Get(1)) * b.Get(0) + conj(b.Get(3)) * b.Get(2);  // = conj(q01)
+  RT q11 = norm(b.Get(1)) + norm(b.Get(3));
+
+  // m = p^T (x) q; only the lower triangle m_rc with r >= c is formed.
+  RT m00 = p00 * q00;
+  T m10 = p00 * q10;
+  RT m11 = p00 * q11;
+  T m20 = p01 * q00;
+  T m21 = p01 * conj(q10);
+  RT m22 = p11 * q00;
+  T m30 = p01 * q10;
+  T m31 = p01 * q11;
+  T m32 = p11 * q10;
+  RT m33 = p11 * q11;
+  // Row-by-row lower triangle: real diagonals, (real, imag) pairs otherwise.
+  return aocommon::HMC4x4::FromData(
+      {m00, m10.real(), m10.imag(), m11, m20.real(), m20.imag(), m21.real(),
+       m21.imag(), m22, m30.real(), m30.imag(), m31.real(), m31.imag(),
+       m32.real(), m32.imag(), m33});
+}
diff --git a/code/kroneker_square_reference.cpp b/code/kroneker_square_reference.cpp
new file mode 100644
index 0000000..f108022
--- /dev/null
+++ b/code/kroneker_square_reference.cpp
@@ -0,0 +1,7 @@
+#include "kronecker_square.h"
+// Reference: KroneckerProduct(HermitianSquare(left)^T, HermitianSquare(right)).
+aocommon::HMC4x4 KroneckerSquareReference(aocommon::MC2x2& left,
+                                          aocommon::MC2x2& right) {
+  return aocommon::HMC4x4::KroneckerProduct(left.HermitianSquare().Transpose(),
+                                            right.HermitianSquare());
+}
\ No newline at end of file
diff --git a/code/kroneker_square_reference_simd.cpp b/code/kroneker_square_reference_simd.cpp
new file mode 100644
index 0000000..e69de29
diff --git a/code/matrix_multiplication.h b/code/matrix_multiplication.h
index 3867383..4f51072 100644
--- a/code/matrix_multiplication.h
+++ b/code/matrix_multiplication.h
@@ -10,8 +10,8 @@
 
 inline void Initialize(std::complex<float>* a) {
   // Initialize matrices with random complex values
-  std::random_device rd;
-  std::mt19937 gen(rd());
+  std::seed_seq seed({42});
+  std::mt19937 gen(seed);
   std::uniform_real_distribution<float> dis(-1.0, 1.0);
 
   for (int i = 0; i < 4; i++) {
diff --git a/test/helpers.cpp b/test/helpers.cpp
new file mode 100644
index 0000000..60cabef
--- /dev/null
+++ b/test/helpers.cpp
@@ -0,0 +1,98 @@
+#include "helpers.h"
+
+#include <aocommon/matrix4x4.h>
+
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_floating_point.hpp>
+#include <iostream>
+
+#define COMPARE_ARRAYS(lhs, rhs, precision)                                    \
+  compareArrays(Catch::getResultCapture().getCurrentTestName(), __LINE__, lhs, \
+                rhs, precision)
+// Generic case: hands both vectors to Catch2's WithinAbs matcher.
+template <typename T>
+void compareSingle(const std::vector<T>& lv, const std::vector<T>& rv,
+                   float precision) {
+  REQUIRE_THAT(lv, Catch::Matchers::WithinAbs(rv, precision));
+}
+// complex<float> case: REQUIREs real and imaginary parts element by element.
+template <>
+void compareSingle(const std::vector<std::complex<float>>& lv,
+                   const std::vector<std::complex<float>>& rv,
+                   float precision) {
+  for (size_t idx = 0; idx < lv.size(); idx++) {
+    const auto le = lv[idx];
+    const auto re = rv[idx];
+
+    REQUIRE_THAT(le.real(), Catch::Matchers::WithinAbs(re.real(), precision));
+    REQUIRE_THAT(le.imag(), Catch::Matchers::WithinAbs(re.imag(), precision));
+  }
+}
+// Formats a value with std::to_string.
+template <typename T>
+std::string valueToString(const T& value) {
+  return std::to_string(value);
+}
+// Formats a complex value as "<real> <imag>j".
+std::string valueToString(const std::complex<float>& value) {
+  return std::to_string(value.real()) + " " + std::to_string(value.imag()) +
+         "j";
+}
+// Records both arrays as Catch2 INFO context, then compares them.
+template <typename T, size_t N>
+void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs,
+                   std::array<T, N> rhs, float precision) {
+  std::vector<T> lv(lhs.begin(), lhs.end());
+  std::vector<T> rv(rhs.begin(), rhs.end());
+  INFO("Test case [" << test << "] failed at line "
+                     << line);  // Reported only if REQUIRE fails
+
+  std::stringstream ss;
+  ss << "Expected : \n";
+  for (size_t idx = 0; idx < N; idx++) {
+    ss << valueToString(lhs[idx]) << "\t";
+  }
+
+  ss << "\nObtained : \n";
+  for (size_t idx = 0; idx < N; idx++) {
+    ss << valueToString(rhs[idx]) << "\t";
+  }
+  ss << "\n";
+  INFO("Reason: \n" << ss.str());
+  compareSingle(lv, rv, precision);
+}
+// Explicit instantiation for std::array<std::complex<float>, 4>.
+template void compareArrays(const std::string& test, unsigned line,
+                            std::array<std::complex<float>, 4ul> lhs,
+                            std::array<std::complex<float>, 4ul> rhs,
+                            float precision);
+// REQUIREs all 16 entries of a and b to agree within `precision` (absolute).
+void AssertEqual(const aocommon::Matrix4x4& a, const aocommon::Matrix4x4& b,
+                 float precision) {
+  for (size_t i = 0; i < 16; i++) {
+    REQUIRE_THAT(a[i].real(),
+                 Catch::Matchers::WithinAbs(b[i].real(), precision));
+    REQUIRE_THAT(a[i].imag(),
+                 Catch::Matchers::WithinAbs(b[i].imag(), precision));
+  }
+}
+// Same check for Hermitian matrices, using their operator[] for i in 0..15.
+void AssertEqual(const aocommon::HMatrix4x4& a, const aocommon::HMatrix4x4& b,
+                 float precision) {
+  for (size_t i = 0; i < 16; i++) {
+    REQUIRE_THAT(a[i].real(),
+                 Catch::Matchers::WithinAbs(b[i].real(), precision));
+    REQUIRE_THAT(a[i].imag(),
+                 Catch::Matchers::WithinAbs(b[i].imag(), precision));
+  }
+}
+void PrintMatrix(const aocommon::Matrix4x4& mat) {
+  for (size_t j = 0; j < 4; j++) {
+    for (size_t i = 0; i < 3; i++) std::cout << mat[i + j * 4] << "\t";
+    std::cout << mat[3 + j * 4] << std::endl;
+  }
+}
+// Prints a Hermitian matrix by first expanding it with ToMatrix().
+void PrintMatrix(const aocommon::HMatrix4x4& mat) {
+  PrintMatrix(mat.ToMatrix());
+}
\ No newline at end of file
diff --git a/test/helpers.h b/test/helpers.h
index 1939b9b..83f306f 100644
--- a/test/helpers.h
+++ b/test/helpers.h
@@ -1,9 +1,13 @@
-#ifndef HELPERS
+#ifndef HELPERS_H
 
-#define HELPERS
+#define HELPERS_H
+
+#include <aocommon/hmatrix4x4.h>
+#include <aocommon/matrix4x4.h>
 
 #include <array>
 #include <catch2/matchers/catch_matchers_floating_point.hpp>
+#include <complex>
 #include <string>
 #include <vector>
 
@@ -13,54 +17,27 @@
 
 template <typename T>
 void compareSingle(const std::vector<T>& lv, const std::vector<T>& rv,
-                   float precision) {
-  REQUIRE_THAT(lv, Catch::Matchers::WithinAbs(rv, precision));
-}
+                   float precision);
 
 template <>
 void compareSingle(const std::vector<std::complex<float>>& lv,
-                   const std::vector<std::complex<float>>& rv,
-                   float precision) {
-  for (size_t idx = 0; idx < lv.size(); idx++) {
-    const auto le = lv[idx];
-    const auto re = rv[idx];
-
-    REQUIRE_THAT(le.real(), Catch::Matchers::WithinAbs(re.real(), precision));
-    REQUIRE_THAT(le.imag(), Catch::Matchers::WithinAbs(re.imag(), precision));
-  }
-}
+                   const std::vector<std::complex<float>>& rv, float precision);
 
 template <typename T>
-std::string valueToString(const T& value) {
-  return std::to_string(value);
-}
+std::string valueToString(const T& value);
 
-std::string valueToString(const std::complex<float>& value) {
-  return std::to_string(value.real()) + " " + std::to_string(value.imag()) +
-         "j";
-}
+std::string valueToString(const std::complex<float>& value);
 
 template <typename T, size_t N>
 void compareArrays(const std::string& test, unsigned line, std::array<T, N> lhs,
-                   std::array<T, N> rhs, float precision) {
-  std::vector<T> lv(lhs.begin(), lhs.end());
-  std::vector<T> rv(rhs.begin(), rhs.end());
-  INFO("Test case [" << test << "] failed at line "
-                     << line);  // Reported only if REQUIRE fails
+                   std::array<T, N> rhs, float precision);
+
+void AssertEqual(const aocommon::Matrix4x4& a, const aocommon::Matrix4x4& b,
+                 float precision);
 
-  std::stringstream ss;
-  ss << "Expected : \n";
-  for (size_t idx = 0; idx < N; idx++) {
-    ss << valueToString(lhs[idx]) << "\t";
-  }
+void AssertEqual(const aocommon::HMatrix4x4& a, const aocommon::HMatrix4x4& b,
+                 float precision);
 
-  ss << "\nObtained : \n";
-  for (size_t idx = 0; idx < N; idx++) {
-    ss << valueToString(rhs[idx]) << "\t";
-  }
-  ss << "\n";
-  INFO("Reason: \n" << ss.str());
-  compareSingle(lv, rv, precision);
-}
+void PrintMatrix(const aocommon::Matrix4x4& mat);
 
-#endif
\ No newline at end of file
+#endif  // HELPERS_H
\ No newline at end of file
diff --git a/test/test_hermitian_square.cpp b/test/test_hermitian_square.cpp
new file mode 100644
index 0000000..459b616
--- /dev/null
+++ b/test/test_hermitian_square.cpp
@@ -0,0 +1,34 @@
+#include <hermitian_square.h>
+
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_floating_point.hpp>
+
+#include "helpers.h"
+
+TEST_CASE("test hermitian square", "[double]") {
+  // Catch2 re-runs this setup once for every SECTION below.
+  aocommon::Matrix4x4 A;
+  aocommon::HMatrix4x4 HA;
+  aocommon::HMatrix4x4 HA_expected;
+
+  Initialize(A);
+
+  HA_expected = HermitianSquare(A);
+
+  SECTION("test correctness of refactored implementation") {
+    HA = HermitianSquareRefactored(A);
+    CHECK(HA_expected == HA);
+  }
+
+  SECTION("test correctness of nasty implementation") {
+    HA = HermitianSquareNaiveSIMD(A);
+
+    AssertEqual(HA_expected, HA, 1.e-6);
+  }
+
+  SECTION("test correctness of naive SIMD v2 implementation") {
+    HA = HermitianSquareNaiveSIMDv2(A);
+
+    AssertEqual(HA_expected, HA, 1.e-6);
+  }
+}
\ No newline at end of file
diff --git a/test/test_kronecker_square.cpp b/test/test_kronecker_square.cpp
new file mode 100644
index 0000000..5ef6838
--- /dev/null
+++ b/test/test_kronecker_square.cpp
@@ -0,0 +1,36 @@
+#include <kronecker_square.h>
+
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_floating_point.hpp>
+#include <complex>
+
+#include "helpers.h"
+
+TEST_CASE("test kronecker square matrix multiplication", "[double]") {
+  // Catch2 re-runs this setup once for every SECTION below.
+  aocommon::MC2x2 A;
+  aocommon::MC2x2 B;
+  aocommon::Matrix4x4 C;
+
+  aocommon::Matrix4x4 C_expected;
+
+  Initialize(A);
+  Initialize(B);
+
+  // Initialize() re-seeds its generator with the same constant on every call,
+  // so A and B start out identical. Rescale every entry of B differently so
+  // that an implementation mixing up the two operands is detected.
+  const std::complex<double> scale[4] = {
+      {0.5, 0.0}, {0.0, 1.0}, {-1.0, 0.5}, {0.25, -0.75}};
+  for (size_t i = 0; i < 4; i++) {
+    B.Set(i, B.Get(i) * scale[i]);
+  }
+
+  C_expected = KroneckerSquareReference(A, B).ToMatrix();
+
+  SECTION("test correctness of fused implementation") {
+    C = KroneckerSquareFused(A, B).ToMatrix();
+
+    AssertEqual(C_expected, C, 1.e-6);
+  }
+}
\ No newline at end of file
-- 
GitLab