From 52ac727c0b5ff52ee4b1e55b91e9026c5617c6b5 Mon Sep 17 00:00:00 2001
From: John Romein <romein@astron.nl>
Date: Tue, 15 Apr 2025 15:17:12 +0200
Subject: [PATCH] Added e4m3 benchmark.

---
 test/Benchmark/Benchmark.cc | 9 ++++++++-
 test/Benchmark/Benchmark.h  | 5 +++--
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/test/Benchmark/Benchmark.cc b/test/Benchmark/Benchmark.cc
index 6798dbc..54b3792 100644
--- a/test/Benchmark/Benchmark.cc
+++ b/test/Benchmark/Benchmark.cc
@@ -8,6 +8,8 @@
 #include <cstring>
 #include <iostream>
 
+#include <cuda_fp8.h>
+
 #include <cudawrappers/nvrtc.hpp>
 
 #define GNU_SOURCE
@@ -28,7 +30,7 @@ Benchmark::Benchmark()
 
     using Format = tcc::Format;
 
-    for (Format format : { Format::fp16, Format::e4m3, Format::e5m2, Format::i8, Format::i4 })
+    for (Format format : { Format::fp16, Format::e4m3, Format::i8, Format::i4 }) // e5m2 is not tested separately, as it performs the same as e4m3
 #pragma omp for collapse(2) schedule(dynamic) ordered
       for (unsigned nrReceivers = 1; nrReceivers <= 576; nrReceivers ++)
 	for (unsigned nrReceiversPerBlock = 32; nrReceiversPerBlock <= 64; nrReceiversPerBlock += 16)
@@ -43,6 +45,11 @@ Benchmark::Benchmark()
 
 				break;
 
+	    case Format::e4m3 : if (capability >= 90)
+				  doTest<std::complex<__nv_fp8_e4m3>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers);
+
+				break;
+
 	    case Format::fp16 : if (capability >= 70)
 				  doTest<std::complex<__half>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers);
 
diff --git a/test/Benchmark/Benchmark.h b/test/Benchmark/Benchmark.h
index caac2a6..50d38f1 100644
--- a/test/Benchmark/Benchmark.h
+++ b/test/Benchmark/Benchmark.h
@@ -37,11 +37,12 @@ template<> std::complex<int8_t> Benchmark::randomValue<std::complex<int8_t>>()
 }
 
 
-template<> std::complex<__half> Benchmark::randomValue<std::complex<__half>>()
+template<typename SampleType> SampleType Benchmark::randomValue()
 {
-  return std::complex<__half>(drand48() - .5, drand48() - .5);
+  return SampleType((typename SampleType::value_type) (drand48() - .5), (typename SampleType::value_type) (drand48() - .5));
 }
 
+
 template <typename VisibilityType> bool Benchmark::approximates(const VisibilityType &a, const VisibilityType &b) const
 {
   return a == b;
-- 
GitLab