Skip to content
Snippets Groups Projects

Add batch approach for matrix multiplication

Open Mattia Mancini requested to merge multi-matrix-multiply into main

Files

@@ -26,6 +26,34 @@ class InitializeInput : public benchmark::Fixture {
std::complex<float>* C;
};
class InitializeInputBatch : public benchmark::Fixture {
public:
void SetUp(::benchmark::State& state) {
size_t n_matrices = state.range(0);
A = static_cast<std::complex<float>*>(
std::aligned_alloc(32, n_matrices * 4 * sizeof(std::complex<float>)));
B = static_cast<std::complex<float>*>(
std::aligned_alloc(32, n_matrices * 4 * sizeof(std::complex<float>)));
C = static_cast<std::complex<float>*>(
std::aligned_alloc(32, n_matrices * 4 * sizeof(std::complex<float>)));
// Initialize matrices with random complex values
for (size_t s = 0; s < n_matrices; s++) {
Initialize(A + 4 * s);
Initialize(B + 4 * s);
}
}
void TearDown(::benchmark::State& state) {
// Free the allocated memory
std::free(A);
std::free(B);
std::free(C);
}
std::complex<float>* A;
std::complex<float>* B;
std::complex<float>* C;
};
// Reference standard
BENCHMARK_F(InitializeInput, MatrixMultiplicationReference)
(benchmark::State& state) {
@@ -100,4 +128,38 @@ BENCHMARK_F(InitializeInput, MatrixMultiplicationNEONb)
}
#endif
// Batched benchmark of matrixMultiplyReference: the kernel is called once per
// matrix, walking the A, B and C buffers in steps of 4 elements.
BENCHMARK_DEFINE_F(InitializeInputBatch, BatchMatrixMultiplicationReference)
(benchmark::State& state) {
  // Read the batch size once, outside the timed loop, and as size_t so the
  // inner loop condition does not compare signed with unsigned values.
  const auto n_matrices = static_cast<size_t>(state.range(0));
  for (auto _ : state) {
    for (size_t s = 0; s < n_matrices; ++s) {
      const size_t offset = s * 4;  // 4 complex elements per matrix
      matrixMultiplyReference(A + offset, B + offset, C + offset);
    }
  }
}
// Batched benchmark of matrixMultiplyAoCommon: the kernel is called once per
// matrix, walking the A, B and C buffers in steps of 4 elements.
BENCHMARK_DEFINE_F(InitializeInputBatch, BatchMatrixMultiplicationAOCommon)
(benchmark::State& state) {
  // Read the batch size once, outside the timed loop, and as size_t so the
  // inner loop condition does not compare signed with unsigned values.
  const auto n_matrices = static_cast<size_t>(state.range(0));
  for (auto _ : state) {
    for (size_t s = 0; s < n_matrices; ++s) {
      const size_t offset = s * 4;  // 4 complex elements per matrix
      matrixMultiplyAoCommon(A + offset, B + offset, C + offset);
    }
  }
}
// Batched benchmark of matrixMultiplyRealComplex: unlike the per-matrix
// benchmarks, the whole batch is handed to the kernel in a single call
// together with the number of matrices.
BENCHMARK_DEFINE_F(InitializeInputBatch, BatchMatrixMultiplicationRealComplex)
(benchmark::State& state) {
  // Batch size as provided by the registered benchmark argument.
  const auto batch_size = state.range(0);
  for (auto _ : state) {
    matrixMultiplyRealComplex(A, B, C, batch_size);
  }
}
// Register the three batched benchmarks. Each one receives the number of
// matrices in the batch as state.range(0), swept over the range [8, 512].
// NOTE(review): with Google Benchmark's default range multiplier (8) this
// should produce batch sizes 8, 64 and 512 -- confirm that no other
// multiplier is configured.
BENCHMARK_REGISTER_F(InitializeInputBatch, BatchMatrixMultiplicationReference)
->Range(8, 512);
BENCHMARK_REGISTER_F(InitializeInputBatch, BatchMatrixMultiplicationAOCommon)
->Range(8, 512);
BENCHMARK_REGISTER_F(InitializeInputBatch, BatchMatrixMultiplicationRealComplex)
->Range(8, 512);
// Google Benchmark macro providing the program entry point that runs the
// registered benchmarks.
BENCHMARK_MAIN();
Loading