diff --git a/README.md b/README.md
index a1ee014fe75b6d980bdec002227e6d7fb0660a59..6ecad7da010efb0083ecae0b21bd49235ff89139 100644
--- a/README.md
+++ b/README.md
@@ -32,36 +32,49 @@ how to use the CUDA driver API (wrappers); and
 OpenCL program.  `test/CorrelatorTest/CorrelatorTest.cc` is a much more versatile,
 robust (and complex) example than `test/SimpleExample/SimpleExample.cu`.
 
-Input and output data types are defined as follows:
+The TCC accepts the following input data types:
+- half-precision floating point (fp16), starting from Volta (sm\_70)
+- 8-bit floating point (fp8) in the e4m3 and e5m2 encodings, starting from Hopper (sm\_90)
+- 8-bit integers (i8), starting from the Jetson Xavier (sm\_72)
+- 4-bit integers (i4), only natively supported on Ampere and Ada
 
 ```
-#if NR_BITS == 4
+#if INPUT_FORMAT == FORMAT_I4
+#define NR_TIMES_PER_BLOCK    32
 typedef complex_int4_t        Sample;
 typedef std::complex<int32_t> Visibility;
-#elif NR_BITS == 8
+#elif INPUT_FORMAT == FORMAT_I8
+#define NR_TIMES_PER_BLOCK    16
 typedef std::complex<int8_t>  Sample;
 typedef std::complex<int32_t> Visibility;
-#elif NR_BITS == 16
+#elif INPUT_FORMAT == FORMAT_E4M3
+#define NR_TIMES_PER_BLOCK    16
+typedef std::complex<__nv_fp8_e4m3>  Sample;
+typedef std::complex<float>   Visibility;
+#elif INPUT_FORMAT == FORMAT_E5M2
+#define NR_TIMES_PER_BLOCK    16
+typedef std::complex<__nv_fp8_e5m2>  Sample;
+typedef std::complex<float>   Visibility;
+#elif INPUT_FORMAT == FORMAT_FP16
+#define NR_TIMES_PER_BLOCK    8
 typedef std::complex<__half>  Sample;
 typedef std::complex<float>   Visibility;
 #endif
 
-#define NR_TIMES_PER_BLOCK (128 / NR_BITS)
 
 typedef Sample Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK];
 typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS];
 ```
 
-Note that in 4-bit and 8-bit mode, the input samples may not contain -8 or -128
+Note that with FORMAT\_I4 and FORMAT\_I8, the input samples must not contain -8 or -128,
 respectively, as these values cannot be conjugated properly.
 The input data type (`Samples`) is a weird format, but this seemed to be the only
 format that yields good performance (tensor cores are very unforgiving).
 
 Limitations:
 - `NR_POLARIZATIONS` must be 2
-- `NR_BITS` must be 4, 8, or 16
-- the amount of samples over which is integrated) must be a multiple of 128 / `NR_BITS`
-  (i.e., 32, 16, or 8 for 4-bit, 8-bit, or 16-bit input, respectively).
+- the number of samples that is integrated over (`NR_SAMPLES_PER_CHANNEL`) must be a
+  multiple of `NR_TIMES_PER_BLOCK`.
 
 ## Building, testing, and installation
 Clone the repository:
diff --git a/libtcc/Correlator.cc b/libtcc/Correlator.cc
index b95fe813132dc0c55ddf3e09e7888eb1da0df1f6..9d189d22ceb474d82858fe2d13277daec955aecb 100644
--- a/libtcc/Correlator.cc
+++ b/libtcc/Correlator.cc
@@ -33,7 +33,7 @@ std::string Correlator::findNVRTCincludePath() const
 
 
 Correlator::Correlator(const cu::Device &device,
-		       unsigned nrBits,
+		       Format   inputFormat,
 		       unsigned nrReceivers,
 		       unsigned nrChannels,
 		       unsigned nrSamplesPerChannel,
@@ -46,8 +46,37 @@ Correlator::Correlator(const cu::Device &device,
     return 10 * device.getAttribute<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>() + device.getAttribute<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>();
   } ()),
   nrReceiversPerBlock(nrReceiversPerBlock != 0 ? nrReceiversPerBlock : defaultNrReceiversPerBlock(nrReceivers)),
-  correlatorModule(compileModule(nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock, customStoreVisibility)),
-  correlatorKernel(correlatorModule, nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock)
+  correlatorModule(compileModule(inputFormat, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock, customStoreVisibility)),
+  correlatorKernel(correlatorModule, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock)
+{
+}
+
+
+Correlator::Correlator(const cu::Device &device,
+		       unsigned nrBits,
+		       unsigned nrReceivers,
+		       unsigned nrChannels,
+		       unsigned nrSamplesPerChannel,
+		       unsigned nrPolarizations,
+		       unsigned nrReceiversPerBlock,
+		       const std::string &customStoreVisibility
+		      )
+:
+  Correlator(device, 
+	     [&] () -> Format {
+	       switch (nrBits) {
+		    case  4 : return Format::i4;
+		    case  8 : return Format::i8;
+		    case 16 : return Format::fp16;
+		    default : throw std::invalid_argument("nrBits should be 4, 8, or 16");
+	       }
+	     } (),
+	     nrReceivers,
+	     nrChannels,
+	     nrSamplesPerChannel,
+	     nrPolarizations,
+	     nrReceiversPerBlock,
+	     customStoreVisibility)
 {
 }
 
@@ -64,7 +93,7 @@ unsigned Correlator::defaultNrReceiversPerBlock(unsigned nrReceivers) const
 }
 
 
-cu::Module Correlator::compileModule(unsigned nrBits,
+cu::Module Correlator::compileModule(Format   inputFormat,
 				     unsigned nrReceivers,
 				     unsigned nrChannels,
 				     unsigned nrSamplesPerChannel,
@@ -79,12 +108,17 @@ cu::Module Correlator::compileModule(unsigned nrBits,
     "-std=c++11",
     "-arch=compute_" + std::to_string(capability),
     "-lineinfo",
-    "-DNR_BITS=" + std::to_string(nrBits),
+    "-DINPUT_FORMAT=" + std::to_string(inputFormat),
     "-DNR_RECEIVERS=" + std::to_string(nrReceivers),
     "-DNR_CHANNELS=" + std::to_string(nrChannels),
     "-DNR_SAMPLES_PER_CHANNEL=" + std::to_string(nrSamplesPerChannel),
     "-DNR_POLARIZATIONS=" + std::to_string(nrPolarizations),
     "-DNR_RECEIVERS_PER_BLOCK=" + std::to_string(nrReceiversPerBlock),
+    "-DFORMAT_FP16=" + std::to_string(fp16),
+    "-DFORMAT_E4M3=" + std::to_string(e4m3),
+    "-DFORMAT_E5M2=" + std::to_string(e5m2),
+    "-DFORMAT_I8=" + std::to_string(i8),
+    "-DFORMAT_I4=" + std::to_string(i4),
   };
 
   if (!customStoreVisibility.empty())
diff --git a/libtcc/Correlator.h b/libtcc/Correlator.h
index b618baf73ec7ff3e2f0a8dc381b3177d57ccea65..5f705340e8a0679098f06d01972170c73c70b124 100644
--- a/libtcc/Correlator.h
+++ b/libtcc/Correlator.h
@@ -9,9 +9,21 @@
 #include "libtcc/CorrelatorKernel.h"
 
 namespace tcc {
+  enum Format { fp16, e4m3, e5m2, i8, i4 };
+
   class Correlator {
     public:
       Correlator(const cu::Device &,
+		 Format   inputFormat,
+		 unsigned nrReceivers,
+		 unsigned nrChannels,
+		 unsigned nrSamplesPerChannel,
+		 unsigned nrPolarizations = 2,
+		 unsigned nrReceiversPerBlock = 0, // 0: use a heuristic value that should work well
+		 const std::string &customStoreVisibility = ""
+		); // throw (cu::Error, nvrtc::Error)
+
+      [[deprecated]] Correlator(const cu::Device &,
 		 unsigned nrBits,
 		 unsigned nrReceivers,
 		 unsigned nrChannels,
@@ -29,7 +41,7 @@ namespace tcc {
     private:
       std::string      findNVRTCincludePath() const;
       unsigned         defaultNrReceiversPerBlock(unsigned nrReceivers) const;
-      cu::Module       compileModule(unsigned nrBits,
+      cu::Module       compileModule(Format   inputFormat,
 				     unsigned nrReceivers,
 				     unsigned nrChannels,
 				     unsigned nrSamplesPerChannel,
diff --git a/libtcc/CorrelatorKernel.cc b/libtcc/CorrelatorKernel.cc
index 1571f6e13e46dfcdcfe460f5b4e76667333fc6c7..4229dd449876890186299feb89f24436dc34a267 100644
--- a/libtcc/CorrelatorKernel.cc
+++ b/libtcc/CorrelatorKernel.cc
@@ -4,7 +4,6 @@
 namespace tcc {
 
 CorrelatorKernel::CorrelatorKernel(cu::Module &module,
-				   unsigned nrBits,
 				   unsigned nrReceivers,
 				   unsigned nrChannels,
 				   unsigned nrSamplesPerChannel,
@@ -13,7 +12,6 @@ CorrelatorKernel::CorrelatorKernel(cu::Module &module,
 				  )
 :
   Kernel(module, "correlate"),
-  nrBits(nrBits),
   nrReceivers(nrReceivers),
   nrChannels(nrChannels),
   nrSamplesPerChannel(nrSamplesPerChannel),
diff --git a/libtcc/CorrelatorKernel.h b/libtcc/CorrelatorKernel.h
index d9c0a5425381d83500a898733e1aa9fdd729df7e..3186921f360974c3e346185645f5be67e7e1650c 100644
--- a/libtcc/CorrelatorKernel.h
+++ b/libtcc/CorrelatorKernel.h
@@ -9,7 +9,6 @@ namespace tcc {
   {
     public:
       CorrelatorKernel(cu::Module &module,
-		       unsigned nrBits,
 		       unsigned nrReceivers,
 		       unsigned nrChannels,
 		       unsigned nrSamplesPerChannel,
@@ -22,7 +21,6 @@ namespace tcc {
       virtual uint64_t FLOPS() const;
 
     private:
-      const unsigned nrBits;
       const unsigned nrReceivers;
       const unsigned nrChannels;
       const unsigned nrSamplesPerChannel;
diff --git a/libtcc/kernel/TCCorrelator.cu b/libtcc/kernel/TCCorrelator.cu
index 8caf22e6b80354fc82ae1059ae3e5dc8da063e42..2a88364515ad9187d187e3c540e1c19f4380aa22 100644
--- a/libtcc/kernel/TCCorrelator.cu
+++ b/libtcc/kernel/TCCorrelator.cu
@@ -7,18 +7,41 @@
 #endif
 
 #include <mma.h>
+#include <cuda_fp8.h>
 
 #define NR_BASELINES		 (NR_RECEIVERS * (NR_RECEIVERS + 1) / 2)
 #define ALIGN(A,N)		 (((A)+(N)-1)/(N)*(N))
 
-#define NR_TIMES_PER_BLOCK	 (128 / (NR_BITS))
-#define NR_RECEIVERS_PER_TCM_X	 ((NR_BITS) == 4 ? 2 : 4)
+#if INPUT_FORMAT == FORMAT_FP16
+#define NR_BITS			 16
+#define NR_TIMES_PER_BLOCK	 8
+#define NR_RECEIVERS_PER_TCM_X	 4
+#define MINIMUM_ARCHITECTURE	 700
+#elif INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2
+#define NR_BITS			 8
+#define NR_TIMES_PER_BLOCK	 16
+#define NR_RECEIVERS_PER_TCM_X	 2
+#define MINIMUM_ARCHITECTURE	 900
+#elif INPUT_FORMAT == FORMAT_I8
+#define NR_BITS			 8
+#define NR_TIMES_PER_BLOCK	 16
+#define NR_RECEIVERS_PER_TCM_X	 4
+#define MINIMUM_ARCHITECTURE	 720
+#elif INPUT_FORMAT == FORMAT_I4
+#define NR_BITS			 4
+#define NR_TIMES_PER_BLOCK	 32
+#define NR_RECEIVERS_PER_TCM_X	 2
+#define MINIMUM_ARCHITECTURE	 750
+#else
+#error Unsupported input format
+#endif
+
 #define NR_RECEIVERS_PER_TCM_Y	 8
 #define NR_RECEIVERS_PER_BLOCK_X (NR_RECEIVERS_PER_BLOCK == 64 ? 32 : NR_RECEIVERS_PER_BLOCK)
 
 #define COMPLEX			 2
 
-#if __CUDA_ARCH__ < (NR_BITS == 4 ? 730 : NR_BITS == 8 ? 720 : NR_BITS == 16 ? 700 : 0)
+#if __CUDA_ARCH__ < MINIMUM_ARCHITECTURE
 #error this architecture has no suitable tensor cores
 #endif
 
@@ -52,7 +75,7 @@ inline __device__ unsigned laneid()
 
 namespace nvcuda {
   namespace wmma {
-#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 730
+#if INPUT_FORMAT == FORMAT_I4
     template<> class fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> : public __frag_base<experimental::precision::s4, 32, 4> {};
     template<> class fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> : public __frag_base<experimental::precision::s4, 16, 2> {};
     template<> class fragment<accumulator, 16, 8, 64, int> : public __frag_base<int, 4> {};
@@ -91,21 +114,80 @@ namespace nvcuda {
       ((int2 *) p)[ldm / 2 * (laneid() / 4 + 8) + laneid() % 4] = make_int2(d.x[2], d.x[3]);
     }
 #endif
+
+#if INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2
+    template<typename T> class fragment<matrix_a, 16, 8, 32, T, row_major> : public __frag_base<int, 4> {};
+    template<typename T> class fragment<matrix_b, 16, 8, 32, T, col_major> : public __frag_base<int, 2> {};
+    template<> class fragment<accumulator, 16, 8, 32, float> : public __frag_base<float, 4> {};
+
+    inline __device__ void mma_sync(fragment<accumulator, 16, 8, 32, float>& d,
+				    const fragment<matrix_a, 16, 8, 32, __nv_fp8_e4m3, row_major>& a,
+				    const fragment<matrix_b, 16, 8, 32, __nv_fp8_e4m3, col_major>& b,
+				    const fragment<accumulator, 16, 8, 32, float>& c)
+    {
+      asm ("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" :
+	   "=f" (d.x[0]), "=f" (d.x[1]), "=f" (d.x[2]), "=f" (d.x[3]) :
+	   "r" (a.x[0]), "r" (a.x[1]), "r" (a.x[2]), "r" (a.x[3]),
+	   "r" (b.x[0]), "r" (b.x[1]),
+	   "f" (c.x[0]), "f" (c.x[1]), "f" (c.x[2]), "f" (c.x[3])
+	  );
+    }
+
+    inline __device__ void mma_sync(fragment<accumulator, 16, 8, 32, float>& d,
+				    const fragment<matrix_a, 16, 8, 32, __nv_fp8_e5m2, row_major>& a,
+				    const fragment<matrix_b, 16, 8, 32, __nv_fp8_e5m2, col_major>& b,
+				    const fragment<accumulator, 16, 8, 32, float>& c)
+    {
+      asm ("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" :
+	   "=f" (d.x[0]), "=f" (d.x[1]), "=f" (d.x[2]), "=f" (d.x[3]) :
+	   "r" (a.x[0]), "r" (a.x[1]), "r" (a.x[2]), "r" (a.x[3]),
+	   "r" (b.x[0]), "r" (b.x[1]),
+	   "f" (c.x[0]), "f" (c.x[1]), "f" (c.x[2]), "f" (c.x[3])
+	  );
+    }
+
+    template <typename T> inline __device__ void load_matrix_sync(fragment<matrix_a, 16, 8, 32, T, row_major> &a, const void *p, unsigned ldm)
+    {
+      a.x[0] = ((const int *) p)[ldm / 4 * (laneid() / 4    ) + laneid() % 4    ];
+      a.x[1] = ((const int *) p)[ldm / 4 * (laneid() / 4 + 8) + laneid() % 4    ];
+      a.x[2] = ((const int *) p)[ldm / 4 * (laneid() / 4    ) + laneid() % 4 + 4];
+      a.x[3] = ((const int *) p)[ldm / 4 * (laneid() / 4 + 8) + laneid() % 4 + 4];
+    }
+
+    template <typename T> inline __device__ void load_matrix_sync(fragment<matrix_b, 16, 8, 32, T, col_major> &b, const void *p, unsigned ldm)
+    {
+      b.x[0] = ((const int *) p)[ldm / 4 * (laneid() / 4) + laneid() % 4    ];
+      b.x[1] = ((const int *) p)[ldm / 4 * (laneid() / 4) + laneid() % 4 + 4];
+    }
+
+    inline __device__ void store_matrix_sync(float *p, const fragment<accumulator, 16, 8, 32, float>& d, unsigned ldm, layout_t layout)
+    {
+      // FIXME: only row-major supported 
+      ((float2 *) p)[ldm / 2 * (laneid() / 4    ) + laneid() % 4] = make_float2(d.x[0], d.x[1]);
+      ((float2 *) p)[ldm / 2 * (laneid() / 4 + 8) + laneid() % 4] = make_float2(d.x[2], d.x[3]);
+    }
+#endif
   }
 }
 
 
 using namespace nvcuda::wmma;
 
-#if NR_BITS == 4
-typedef char    Sample;
-typedef int2    Visibility;
-#elif NR_BITS == 8
-typedef char2   Sample;
-typedef int2    Visibility;
-#elif NR_BITS == 16
-typedef __half2 Sample;
-typedef float2  Visibility;
+#if INPUT_FORMAT == FORMAT_I4
+typedef char            Sample;
+typedef int2            Visibility;
+#elif INPUT_FORMAT == FORMAT_I8
+typedef char2           Sample;
+typedef int2            Visibility;
+#elif INPUT_FORMAT == FORMAT_E4M3
+typedef __nv_fp8x2_e4m3 Sample;
+typedef float2          Visibility;
+#elif INPUT_FORMAT == FORMAT_E5M2
+typedef __nv_fp8x2_e5m2 Sample;
+typedef float2          Visibility;
+#elif INPUT_FORMAT == FORMAT_FP16
+typedef __half2         Sample;
+typedef float2          Visibility;
 #endif
 
 
@@ -123,18 +205,26 @@ typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_
 #endif
 
 
-#if NR_BITS == 4
+#if INPUT_FORMAT == FORMAT_I4
 typedef fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> Afrag;
 typedef fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> Bfrag;
 typedef fragment<accumulator, 16, 8, 64, int>                                 Sum;
-#elif NR_BITS == 8
-typedef fragment<matrix_a, 16, 16, 16, signed char, row_major>               Afrag;
-typedef fragment<matrix_b, 16, 16, 16, signed char, col_major>               Bfrag;
-typedef fragment<accumulator, 16, 16, 16, int>                               Sum;
-#elif NR_BITS == 16
-typedef fragment<matrix_a, 16, 16, 16, __half, row_major>                    Afrag;
-typedef fragment<matrix_b, 16, 16, 16, __half, col_major>                    Bfrag;
-typedef fragment<accumulator, 16, 16, 16, float>                             Sum;
+#elif INPUT_FORMAT == FORMAT_I8
+typedef fragment<matrix_a, 16, 16, 16, signed char, row_major>                Afrag;
+typedef fragment<matrix_b, 16, 16, 16, signed char, col_major>                Bfrag;
+typedef fragment<accumulator, 16, 16, 16, int>                                Sum;
+#elif INPUT_FORMAT == FORMAT_E4M3
+typedef fragment<matrix_a, 16, 8, 32, __nv_fp8_e4m3, row_major>               Afrag;
+typedef fragment<matrix_b, 16, 8, 32, __nv_fp8_e4m3, col_major>               Bfrag;
+typedef fragment<accumulator, 16, 8, 32, float>                               Sum;
+#elif INPUT_FORMAT == FORMAT_E5M2
+typedef fragment<matrix_a, 16, 8, 32, __nv_fp8_e5m2, row_major>               Afrag;
+typedef fragment<matrix_b, 16, 8, 32, __nv_fp8_e5m2, col_major>               Bfrag;
+typedef fragment<accumulator, 16, 8, 32, float>                               Sum;
+#elif INPUT_FORMAT == FORMAT_FP16
+typedef fragment<matrix_a, 16, 16, 16, __half, row_major>                     Afrag;
+typedef fragment<matrix_b, 16, 16, 16, __half, col_major>                     Bfrag;
+typedef fragment<accumulator, 16, 16, 16, float>                              Sum;
 #endif
 
 
@@ -143,13 +233,15 @@ typedef Visibility ScratchSpace[NR_RECEIVERS_PER_TCM_Y][NR_POLARIZATIONS][NR_REC
 
 __device__ inline int conj_perm(int v)
 {
-#if NR_BITS == 4
+#if INPUT_FORMAT == FORMAT_I4
   //return ((v & 0x0F0F0F0F) << 4) | (__vnegss4(v >> 4) & 0x0F0F0F0F);
   return ((v & 0x0F0F0F0F) << 4) | ((0xF0F0F0F0 - ((v >> 4) & 0x0F0F0F0F)) & 0x0F0F0F0F);
-#elif NR_BITS == 8
+#elif INPUT_FORMAT == FORMAT_I8
   //return __byte_perm(v, __vnegss4(v), 0x2705);
   return __byte_perm(v, 0x00FF00FF - (v & 0xFF00FF00), 0x2705);
-#elif NR_BITS == 16
+#elif INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2
+  return __byte_perm(v ^ 0x80008000, v, 0x2301);
+#elif INPUT_FORMAT == FORMAT_FP16
   return __byte_perm(v ^ 0x80000000, v, 0x1032);
 #endif
 }
@@ -178,15 +270,21 @@ __device__ inline int4 conj_perm(int4 v)
 
 template <unsigned nrReceiversPerBlock = NR_RECEIVERS_PER_BLOCK> struct SharedData
 {
-#if NR_BITS == 4
-  typedef char        Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][1];
-  typedef char        Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 16][1];
-#elif NR_BITS == 8
-  typedef signed char Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
-  typedef signed char Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX];
-#elif NR_BITS == 16
-  typedef __half      Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
-  typedef __half      Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 4][COMPLEX];
+#if INPUT_FORMAT == FORMAT_I4
+  typedef char          Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][1];
+  typedef char          Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 16][1];
+#elif INPUT_FORMAT == FORMAT_I8
+  typedef signed char   Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
+  typedef signed char   Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX];
+#elif INPUT_FORMAT == FORMAT_E4M3
+  typedef __nv_fp8_e4m3 Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
+  typedef __nv_fp8_e4m3 Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX];
+#elif INPUT_FORMAT == FORMAT_E5M2
+  typedef __nv_fp8_e5m2 Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
+  typedef __nv_fp8_e5m2 Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX];
+#elif INPUT_FORMAT == FORMAT_FP16
+  typedef __half        Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX];
+  typedef __half        Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 4][COMPLEX];
 #endif
 };
 
@@ -301,11 +399,11 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil
               printf("firstY=%u firstX=%u warp=%u y=%u x=%u _y=%u pol_y=%u _x=%u pol_x=%u val=(%f,%f)\n", firstReceiverY, firstReceiverX, warp, y, x, _y, pol_y, _x, pol_x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].y);
 #endif
 
-#if NR_BITS == 4
+#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3  || INPUT_FORMAT == FORMAT_E5M2
   unsigned _y       = threadIdx.x >> 2;
   unsigned _x       = (threadIdx.x >> 1) & 1;
   unsigned polY     = threadIdx.x & 1;
-#elif NR_BITS == 8 || NR_BITS == 16
+#elif INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16
   unsigned _y       = threadIdx.x >> 2;
   unsigned _x       = threadIdx.x & 3;
 #endif
@@ -315,16 +413,16 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil
   unsigned baseline = (recvX * (recvX + 1) / 2) + recvY;
 
   if ((skipCheckY || recvY <= recvX) && (skipCheckX || recvX < NR_RECEIVERS))
-#if NR_BITS == 4
+#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2
     for (unsigned polX = 0; polX < NR_POLARIZATIONS; polX ++)
       visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX];
-#elif NR_BITS == 8 || NR_BITS == 16
+#elif INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16
     for (unsigned polY = 0; polY < NR_POLARIZATIONS; polY ++)
       for (unsigned polX = 0; polX < NR_POLARIZATIONS; polX ++)
         visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX];
 #endif
 #else
-#if __CUDA_ARCH__ == 700 || (__CUDA_ARCH__ == 720 && NR_BITS == 16)
+#if __CUDA_ARCH__ == 700 || (__CUDA_ARCH__ == 720 && INPUT_FORMAT == FORMAT_FP16)
   unsigned polY  = threadIdx.x & 1;
   unsigned polX  = (threadIdx.x >> 1) & 1;
   unsigned recvY = firstReceiverY + NR_RECEIVERS_PER_TCM_Y * y + ((threadIdx.x >> 3) & 2) + (threadIdx.x & 4);
@@ -334,18 +432,18 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil
   storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 1, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]);
   storeVisibility<add>(visibilities, channel, recvY + 1, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]);
   storeVisibility<add>(visibilities, channel, recvY + 1, recvX + 1, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]);
-#elif (__CUDA_ARCH__ == 720 && NR_BITS == 8) || __CUDA_ARCH__ == 750 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1200
+#elif (__CUDA_ARCH__ == 720 && INPUT_FORMAT == FORMAT_I8) || __CUDA_ARCH__ == 750 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1200
   unsigned polY  = (threadIdx.x >> 2) & 1;
   unsigned polX  = threadIdx.x & 1;
   unsigned recvY = firstReceiverY + NR_RECEIVERS_PER_TCM_Y * y + ((threadIdx.x >> 3) & 3);
   unsigned recvX = firstReceiverX + NR_RECEIVERS_PER_TCM_X * x + ((threadIdx.x >> 1) & 1);
 
   storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[0], sum.x[1]);
-#if NR_BITS == 8 || NR_BITS == 16
+#if INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16
   storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 2, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]);
 #endif
   storeVisibility<add>(visibilities, channel, recvY + 4, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]);
-#if NR_BITS == 8 || NR_BITS == 16
+#if INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16
   storeVisibility<add>(visibilities, channel, recvY + 4, recvX + 2, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]);
 #endif
 #endif
@@ -430,8 +528,10 @@ template <bool add, bool fullTriangle> __device__ void doCorrelateTriangle(Visib
 
     __syncthreads();
 
+    constexpr unsigned minorTimeIncrement = INPUT_FORMAT == FORMAT_I8 ? 8 : NR_TIMES_PER_BLOCK;
+
 #pragma unroll
-    for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) {
+    for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += minorTimeIncrement) {
       Afrag aFrag;
       Bfrag bFrag[nrFragmentsX];
 
@@ -468,7 +568,7 @@ template <bool add, bool fullTriangle> __device__ void doCorrelateTriangle(Visib
   if (warp != 0)
     for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++)
       for (unsigned x = 0; x < nrFragmentsX; x ++, i ++)
-	storeVisibilities<add>(visibilities, channel, firstReceiver + recvYoffset, firstReceiver + recvXoffset, y, x, y < 2 || x > (NR_BITS == 4 ? 4 : 2), fullTriangle, sum[i], scratchSpace, warp);
+	storeVisibilities<add>(visibilities, channel, firstReceiver + recvYoffset, firstReceiver + recvXoffset, y, x, y < 2 || x > (INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 ? 4 : 2), fullTriangle, sum[i], scratchSpace, warp);
   else
     for (unsigned z = 0, i = 0; z < 3; z ++)
       for (unsigned y = 0; y < 2; y ++)
@@ -583,8 +683,10 @@ template <bool add, unsigned nrFragmentsY, unsigned nrFragmentsX, bool skipLoadY
 
     __syncthreads();
 
+    constexpr unsigned minorTimeIncrement = INPUT_FORMAT == FORMAT_I8 ? 8 : NR_TIMES_PER_BLOCK;
+
 #pragma unroll
-    for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) {
+    for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += minorTimeIncrement) {
       Afrag aFrag;
       Bfrag bFrag[nrFragmentsX];
 
@@ -605,7 +707,7 @@ template <bool add, unsigned nrFragmentsY, unsigned nrFragmentsX, bool skipLoadY
     for (unsigned x = 0; x < nrFragmentsX; x ++)
       for (unsigned i = 0; i < sum[0][0].num_storage_elements; i ++)
 	if (sum[y][x].x[i] != 0)
-#if NR_BITS == 4 || NR_BITS == 8
+#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_I8
 	  printf("blockIdx=(%d,%d,%d) tid=%u y=%u x=%u i=%u v=%d\n", blockIdx.x, blockIdx.y, blockIdx.z, tid, y, x, i, sum[y][x].x[i]);
 #else
 	  printf("blockIdx=(%d,%d,%d) tid=%u y=%u x=%u i=%u v=%f\n", blockIdx.x, blockIdx.y, blockIdx.z, tid, y, x, i, sum[y][x].x[i]);
diff --git a/scripts/load-modules.sh b/scripts/load-modules.sh
index d5e2a8039d5cb25d5e2c4bc85c4267031e298bfc..039e8069e26e8ada54e35ef9e4b8c08f48f1f7ae 100644
--- a/scripts/load-modules.sh
+++ b/scripts/load-modules.sh
@@ -1,2 +1,2 @@
-module load spack/9.4.0
-module load cuda/12.2.1
+module load spack
+module load cuda
diff --git a/test/Benchmark/Benchmark.cc b/test/Benchmark/Benchmark.cc
index a26c470a37f61503918f12450eee93fe781c674b..54b3792effe4e9961432f6aa5828be41d3b455b9 100644
--- a/test/Benchmark/Benchmark.cc
+++ b/test/Benchmark/Benchmark.cc
@@ -8,6 +8,8 @@
 #include <cstring>
 #include <iostream>
 
+#include <cuda_fp8.h>
+
 #include <cudawrappers/nvrtc.hpp>
 
 #define GNU_SOURCE
@@ -26,40 +28,50 @@ Benchmark::Benchmark()
   ep([&] () {
     context.setCurrent();
 
-    for (unsigned nrBits: {8, 16, 4})
+    using Format = tcc::Format;
+
+    for (Format format : { Format::fp16, Format::e4m3, Format::i8, Format::i4 }) // e5m2 not tested separately, as it performs equal to e4m3
 #pragma omp for collapse(2) schedule(dynamic) ordered
       for (unsigned nrReceivers = 1; nrReceivers <= 576; nrReceivers ++)
 	for (unsigned nrReceiversPerBlock = 32; nrReceiversPerBlock <= 64; nrReceiversPerBlock += 16)
-	  if (nrBits ==  4 && (capability >= 73 && capability <= 89) ||
-	      nrBits ==  8 && capability >= 72 ||
-	      nrBits == 16 && capability >= 70)
-	    switch (nrBits) {
-	      case  4 : doTest<complex_int4_t, std::complex<int32_t>>(4, nrReceiversPerBlock, nrReceivers);
-			break;
+	  switch (format) {
+	    case  Format::i4  : if (capability >= 73 && capability <= 89)
+				  doTest<complex_int4_t, std::complex<int32_t>>(format, nrReceiversPerBlock, nrReceivers);
+
+				break;
+
+	    case Format::i8   : if (capability >= 72)
+				  doTest<std::complex<int8_t>, std::complex<int32_t>>(format, nrReceiversPerBlock, nrReceivers);
+
+				break;
+
+	    case Format::e4m3 : if (capability >= 90)
+				  doTest<std::complex<__nv_fp8_e4m3>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers);
+
+				break;
 
-	      case  8 : doTest<std::complex<int8_t>, std::complex<int32_t>>(8, nrReceiversPerBlock, nrReceivers);
-			break;
+	    case Format::fp16 : if (capability >= 70)
+				  doTest<std::complex<__half>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers);
 
-	      case 16 : doTest<std::complex<__half>, std::complex<float>>(16, nrReceiversPerBlock, nrReceivers);
-			break;
-	    }
+				break;
+	  }
 
     stream.synchronize();
   });
 }
 
 
-template <typename SampleType, typename VisibilityType> void Benchmark::doTest(unsigned nrBits, unsigned nrReceiversPerBlock, unsigned nrReceivers)
+template <typename SampleType, typename VisibilityType> void Benchmark::doTest(tcc::Format inputFormat, unsigned nrReceiversPerBlock, unsigned nrReceivers)
 {
   constexpr double   measureTime         = 3; // seconds
 	    unsigned nrChannels          = 4 * device.getAttribute<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(); // provide enough parallelism
   constexpr unsigned nrPolarizations     = 2;
   constexpr unsigned nrSamplesPerChannel = 3072;
   constexpr bool     addVisibilities     = false;
-            unsigned nrTimesPerBlock     = 128 / nrBits;
+            unsigned nrTimesPerBlock     = (unsigned []) {8, 16, 16, 16, 32}[inputFormat];
 	    unsigned nrBaselines         = nrReceivers * (nrReceivers + 1) / 2;
 
-  tcc::Correlator correlator(device, nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock);
+  tcc::Correlator correlator(device, inputFormat, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock);
   unsigned repeatCount;
 
   multi_array::extent<5> samplesExtent(multi_array::extents[nrChannels][nrSamplesPerChannel / nrTimesPerBlock][nrReceivers][nrPolarizations][nrTimesPerBlock]);
@@ -107,7 +119,7 @@ template <typename SampleType, typename VisibilityType> void Benchmark::doTest(u
   stream.synchronize();
 
   char msg[64];
-  sprintf(msg, "bits=%u recv/blk=%u recv=%u cnt=%u", nrBits, nrReceiversPerBlock, nrReceivers, repeatCount);
+  sprintf(msg, "%s recv/blk=%u recv=%u cnt=%u", (const char * []) {"fp16", "e4m3", "e5m2", "i8", "i4"}[inputFormat], nrReceiversPerBlock, nrReceivers, repeatCount);
 #pragma omp ordered
   report(msg, computeRecordStart, computeRecordStop, repeatCount * correlator.FLOPS());
 }
diff --git a/test/Benchmark/Benchmark.h b/test/Benchmark/Benchmark.h
index 32bb0aae4308e8a87c2d395218866132fd5cfcca..50d38f1b48da0673d4d2e67e4f9e7cfca5c397e5 100644
--- a/test/Benchmark/Benchmark.h
+++ b/test/Benchmark/Benchmark.h
@@ -14,7 +14,7 @@ class Benchmark : public UnitTest
     Benchmark();
 
   private:
-    template<typename SampleType, typename VisibilityType> void doTest(unsigned nrBits, unsigned nrReceiversPerBlock, unsigned nrReceivers);
+    template<typename SampleType, typename VisibilityType> void doTest(tcc::Format inputFormat, unsigned nrReceiversPerBlock, unsigned nrReceivers);
     template<typename SampleType>                          void setTestPattern(const multi_array::array_ref<SampleType, 5> &samples);
 
     template<typename SampleType> static SampleType randomValue();
@@ -37,11 +37,12 @@ template<> std::complex<int8_t> Benchmark::randomValue<std::complex<int8_t>>()
 }
 
 
-template<> std::complex<__half> Benchmark::randomValue<std::complex<__half>>()
+template<typename SampleType> SampleType Benchmark::randomValue()
 {
-  return std::complex<__half>(drand48() - .5, drand48() - .5);
+  return SampleType((typename SampleType::value_type) (drand48() - .5), (typename SampleType::value_type) (drand48() - .5));
 }
 
+
 template <typename VisibilityType> bool Benchmark::approximates(const VisibilityType &a, const VisibilityType &b) const
 {
   return a == b;
diff --git a/test/Common/ComplexInt4.h b/test/Common/ComplexInt4.h
index 2cb9c50fb3b56a8fcb4bcf83947d166672d77ad6..b3355d70c1edcbe6607cd8c3b7da09d3694afd26 100644
--- a/test/Common/ComplexInt4.h
+++ b/test/Common/ComplexInt4.h
@@ -7,6 +7,8 @@
 class complex_int4_t
 {
   public:
+    typedef int value_type;
+
     complex_int4_t() {}
     complex_int4_t(int real, int imag) { value = (imag << 4) | (real & 0xF); }
     complex_int4_t operator = (const complex_int4_t &other) { value = other.value; return *this; }
diff --git a/test/CorrelatorTest/CMakeLists.txt b/test/CorrelatorTest/CMakeLists.txt
index 737418b9eae240d9aad6d8d6ffc1afc1fe38b3e7..17d9c30012b4b1062dd95183996f492e6c289fc7 100644
--- a/test/CorrelatorTest/CMakeLists.txt
+++ b/test/CorrelatorTest/CMakeLists.txt
@@ -18,10 +18,10 @@ install(TARGETS CorrelatorTest RUNTIME DESTINATION bin)
 # R: outerRepeatCount
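-# t: nrSamplesPerChannel must be a multiple of (128 / nrBits)
+# t: nrSamplesPerChannel must be a multiple of nrTimesPerBlock (32 for i4; 16 for i8, e4m3, e5m2; 8 for fp16)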
 # V: verifyOutput
-set(ARGS0 -b  4 -c 1 -n  1 -N 32 -r 1 -R 1 -t 32)
-set(ARGS1 -b  8 -c 1 -n  1 -N 48 -r 1 -R 1 -t 16)
-set(ARGS2 -b 16 -c 1 -n  1 -N 64 -r 1 -R 1 -t  8)
-set(ARGS3 -b 16 -c 2 -n  3 -N 32 -r 4 -R 5 -t 64)
+set(ARGS0 -I   i4 -c 1 -n  1 -N 32 -r 1 -R 1 -t 32)
+set(ARGS1 -I   i8 -c 1 -n  1 -N 48 -r 1 -R 1 -t 16)
+set(ARGS2 -I fp16 -c 1 -n  1 -N 64 -r 1 -R 1 -t  8)
+set(ARGS3 -I fp16 -c 2 -n  3 -N 32 -r 4 -R 5 -t 64)
 
 foreach(idx RANGE 3)
   add_test(NAME CorrelatorTest${idx} COMMAND CorrelatorTest ${ARGS${idx}})
@@ -41,6 +41,6 @@ foreach(idx
 647   653   659   661   673   677   683   691   701   709   719   727   733
 739   743   751   757   761)
   add_test(NAME CorrelatorTest-nrReceivers-${idx}
-           COMMAND CorrelatorTest -b 16 -c 1 -n ${idx} -N 32 -r 1 -R 1 -t 8
+           COMMAND CorrelatorTest -I fp16 -c 1 -n ${idx} -N 32 -r 1 -R 1 -t 8
   )
 endforeach()
diff --git a/test/CorrelatorTest/CorrelatorTest.cc b/test/CorrelatorTest/CorrelatorTest.cc
index 09c52dcac3b6fbbef6fd18801ed79e5586311358..302623fa106323b06c20682a18c32c064d4eff91 100644
--- a/test/CorrelatorTest/CorrelatorTest.cc
+++ b/test/CorrelatorTest/CorrelatorTest.cc
@@ -20,7 +20,7 @@ CorrelatorTest::CorrelatorTest(const Options &options)
   UnitTest(options.deviceNumber),
   options(options),
   hasIntegratedMemory(device.getAttribute(CU_DEVICE_ATTRIBUTE_INTEGRATED)),
-  correlator(device, options.nrBits, options.nrReceivers, options.nrChannels, options.nrSamplesPerChannel, options.nrPolarizations, options.nrReceiversPerBlock)
+  correlator(device, options.inputFormat, options.nrReceivers, options.nrChannels, options.nrSamplesPerChannel, options.nrPolarizations, options.nrReceiversPerBlock)
 {
 #if defined MEASURE_POWER
   Record start(*powerSensor), stop(*powerSensor);
@@ -30,14 +30,20 @@ CorrelatorTest::CorrelatorTest(const Options &options)
 
   start.enqueue(stream);
 
-  switch (options.nrBits) {
-    case  4 : doTest<complex_int4_t, std::complex<int32_t>>();
+  switch (options.inputFormat) {
+    case tcc::Format::i4   : doTest<complex_int4_t, std::complex<int32_t>>();
 	      break;
 
-    case  8 : doTest<std::complex<int8_t>, std::complex<int32_t>>();
+    case tcc::Format::i8   : doTest<std::complex<int8_t>, std::complex<int32_t>>();
 	      break;
 
-    case 16 : doTest<std::complex<__half>, std::complex<float>>();
+    case tcc::Format::e4m3 : doTest<std::complex<__nv_fp8_e4m3>, std::complex<float>>();
+	      break;
+
+    case tcc::Format::e5m2 : doTest<std::complex<__nv_fp8_e5m2>, std::complex<float>>();
+	      break;
+
+    case tcc::Format::fp16 : doTest<std::complex<__half>, std::complex<float>>();
 	      break;
   }
 
@@ -140,8 +146,8 @@ template<typename SampleType> void CorrelatorTest::setTestPattern(const multi_ar
   unsigned recv0   = options.nrReceivers > 174 ? 174 : options.nrReceivers / 3;
   unsigned recv1   = options.nrReceivers > 418 ? 418 : options.nrReceivers / 2;
 
-  samples[channel][time / options.nrTimesPerBlock][recv0][POL_X][time % options.nrTimesPerBlock] = SampleType(2.0, 3.0);
-  samples[channel][time / options.nrTimesPerBlock][recv1][POL_X][time % options.nrTimesPerBlock] = SampleType(4.0, 5.0);
+  samples[channel][time / options.nrTimesPerBlock][recv0][POL_X][time % options.nrTimesPerBlock] = SampleType((typename SampleType::value_type) 2.0, (typename SampleType::value_type) 3.0);
+  samples[channel][time / options.nrTimesPerBlock][recv1][POL_X][time % options.nrTimesPerBlock] = SampleType((typename SampleType::value_type) 4.0, (typename SampleType::value_type) 5.0);
 #else
   SampleType randomValues[7777]; // use a limited set of random numbers to save time
 
@@ -182,7 +188,7 @@ template<typename SampleType, typename VisibilityType> void CorrelatorTest::veri
 		  SampleType sampleY = ref[recvY][polY][minor_time];
 		  SampleType sampleX = ref[recvX][polX][minor_time];
 
-		  sum[baseline][polY][polX] += VisibilityType(sampleY.real(), sampleY.imag()) * conj(VisibilityType(sampleX.real(), sampleX.imag()));
+		  sum[baseline][polY][polX] += VisibilityType((typename VisibilityType::value_type) sampleY.real(), (typename VisibilityType::value_type) sampleY.imag()) * conj(VisibilityType((typename VisibilityType::value_type) sampleX.real(), (typename VisibilityType::value_type) sampleX.imag()));
 		}
       }
 
diff --git a/test/CorrelatorTest/CorrelatorTest.h b/test/CorrelatorTest/CorrelatorTest.h
index 5ff3189144e70797dd6c54a279dccf0eb1fb33bc..e9944bb4730306b8ce098938bc706be73007e2cb 100644
--- a/test/CorrelatorTest/CorrelatorTest.h
+++ b/test/CorrelatorTest/CorrelatorTest.h
@@ -7,6 +7,7 @@
 #include "libtcc/Correlator.h"
 #include "util/multi_array.h"
 
+#include <cuda_fp8.h>
 #include <cuda_fp16.h>
 
 
@@ -41,11 +42,12 @@ template<> std::complex<int8_t> CorrelatorTest::randomValue<std::complex<int8_t>
 }
 
 
-template<> std::complex<__half> CorrelatorTest::randomValue<std::complex<__half>>()
+template<typename SampleType> SampleType CorrelatorTest::randomValue() // generic definition, replacing the former std::complex<__half>-only specialization; std::complex<int8_t> keeps its own specialization above
 {
-  return std::complex<__half>(drand48() - .5, drand48() - .5);
+  return SampleType((typename SampleType::value_type) (drand48() - .5), (typename SampleType::value_type) (drand48() - .5)); // real and imaginary part each come from drand48() - .5; the explicit casts are presumably required for value types without an implicit conversion from double (e.g. fp8) -- TODO confirm
 }
 
+
 template <typename VisibilityType> bool CorrelatorTest::approximates(const VisibilityType &a, const VisibilityType &b) const
 {
   return a == b;
diff --git a/test/CorrelatorTest/Options.cc b/test/CorrelatorTest/Options.cc
index 4f49757060cbb595b750da39096b2e409c56fbdb..706e60b82e03abe3aeb231210b08971c748adb5b 100644
--- a/test/CorrelatorTest/Options.cc
+++ b/test/CorrelatorTest/Options.cc
@@ -8,12 +8,11 @@
 
 Options::Options(int argc, char *argv[])
 :
-  nrBits(8),
+  inputFormat(tcc::Format::i8),
   nrChannels(480),
   nrReceivers(576),
   nrReceiversPerBlock(64),
   nrSamplesPerChannel(3072),
-  nrTimesPerBlock(128 / nrBits),
   innerRepeatCount(1), outerRepeatCount(1),
   deviceNumber(0),
   verifyOutput(true),
@@ -21,14 +20,11 @@ Options::Options(int argc, char *argv[])
 {
   opterr = 0;
 
-  for (int opt; (opt = getopt(argc, argv, "a:b:c:d:hn:N:r:R:t:V:")) >= 0;)
+  for (int opt; (opt = getopt(argc, argv, "a:c:d:hI:n:N:r:R:t:V:")) >= 0;)
     switch (opt) {
       case 'a' : add = atoi(optarg);
 		 break;
 
-      case 'b' : nrBits = atoi(optarg);
-		 break;
-
       case 'c' : nrChannels = atoi(optarg);
 		 break;
 
@@ -38,6 +34,9 @@ Options::Options(int argc, char *argv[])
       case 'h' : std::cout << usage(argv[0]) << std::endl;
 		 exit(0);
 
+      case 'I' : inputFormat = toFormat(optarg); // -I <name>: input sample format; toFormat() throws Error on an unknown name
+		 break;
+
       case 'n' : nrReceivers = atoi(optarg);
 		 break;
 
@@ -59,9 +58,6 @@ Options::Options(int argc, char *argv[])
       default  : throw Error(usage(argv[0]));
     }
 
-  if (nrBits != 4 && nrBits != 8 && nrBits != 16)
-    throw Error("nrBits must be 4, 8, or 16");
-
   if (nrChannels == 0)
     throw Error("nrChannels must be > 0");
 
@@ -74,16 +70,41 @@ Options::Options(int argc, char *argv[])
   if (nrSamplesPerChannel == 0)
     throw Error("nrSamplesPerChannel must be > 0");
 
-  nrTimesPerBlock = 128 / nrBits;
+  nrTimesPerBlock = inputFormat == tcc::Format::i4 ? 32 : (inputFormat == tcc::Format::fp16 ? 8 : 16); // replaces 128 / nrBits: i4 -> 32, fp16 -> 8, the 8-bit formats (i8, e4m3, e5m2) -> 16; written without a compound literal (GNU extension) and independent of the enumerator order
 
   if (nrSamplesPerChannel % nrTimesPerBlock != 0)
     throw Error("nrSamplesPerChannel must be a multiple of " + std::to_string(nrTimesPerBlock));
 }
 
 
+tcc::Format Options::toFormat(const std::string &format) // translates the -I command-line argument into a tcc::Format; throws Error when the name is not recognized
+{ // NOTE(review): the commented-out branches below look like placeholders for formats that are not enabled (fp32, bf16, i32, i16) -- confirm
+  /* if (format == "fp32" || format == "float")
+    return tcc::Format::fp32;
+  else */ if (format == "fp16" || format == "half" || format == "float16")
+    return tcc::Format::fp16;
+  /* else if (format == "bfloat16")
+    return tcc::Format::bf16; */
+  else if (format == "fp8" || format == "e4m3") // a bare "fp8" selects e4m3, not e5m2
+    return tcc::Format::e4m3;
+  else if (format == "e5m2")
+    return tcc::Format::e5m2;
+  /* else if (format == "i32" || format == "int" || format == "int32_t")
+    return tcc::Format::i32;
+  else if (format == "i16" || format == "short" || format == "int16_t")
+    return tcc::Format::i16; */
+  else if (format == "i8" || format == "char" || format == "int8_t")
+    return tcc::Format::i8;
+  else if (format == "i4" || format == "int4_t")
+    return tcc::Format::i4;
+  else
+    throw Error("format \"" + format + "\" not recognized"); // matching is exact and case-sensitive (plain std::string ==)
+}
+
+
 std::string Options::usage(const std::string &execName)
 {
-  return "usage: " + execName + " [-b nrBits] [-c nrChannels] [-n nrReceivers] [-N nrReceiversPerBlock] [-r innerRepeatCount] [-R outerRepeatCount] [-t nrSamplesPerChannel] [-V verifyOutput]";
+  return "usage: " + execName + " [-I i4|i8|e4m3|e5m2|fp16] [-c nrChannels] [-n nrReceivers] [-N nrReceiversPerBlock] [-r innerRepeatCount] [-R outerRepeatCount] [-t nrSamplesPerChannel] [-V verifyOutput]"; // NOTE(review): the getopt() string in the constructor also accepts -a, -d and -h, which are not listed here; toFormat() additionally accepts aliases such as "half" or "fp8"
 }
 
 
diff --git a/test/CorrelatorTest/Options.h b/test/CorrelatorTest/Options.h
index 8d34ac53c42e773171bb217ce4c91e6fc47f0a99..116a1ed4c7f92b301020a58a8227259e31747f40 100644
--- a/test/CorrelatorTest/Options.h
+++ b/test/CorrelatorTest/Options.h
@@ -4,6 +4,8 @@
 #include <exception>
 #include <string>
 
+#include "libtcc/Correlator.h"
+
 
 class Options
 {
@@ -26,7 +28,7 @@ class Options
 
     unsigned nrBaselines() const { return nrReceivers * (nrReceivers + 1) / 2; }
 
-    unsigned nrBits;
+    tcc::Format inputFormat;
     unsigned nrChannels;
     unsigned nrReceivers;
     unsigned nrReceiversPerBlock;
@@ -40,6 +42,7 @@ class Options
     static const unsigned nrPolarizations = 2;
 
   private:
+    static tcc::Format toFormat(const std::string &);
     static std::string usage(const std::string &execName);
 };
 
diff --git a/test/SimpleExample/SimpleExample.cu b/test/SimpleExample/SimpleExample.cu
index 6ad4ed52b7d9016677c9c42b6a9414b60d4a65ac..c52ccbea992b71fcd43324a46fd751eed2becf55 100644
--- a/test/SimpleExample/SimpleExample.cu
+++ b/test/SimpleExample/SimpleExample.cu
@@ -32,12 +32,15 @@ inline void checkCudaCall(cudaError_t error)
 #if NR_BITS == 4
 typedef complex_int4_t	      Sample;
 typedef std::complex<int32_t> Visibility;
+constexpr tcc::Format inputFormat = tcc::Format::i4;
 #elif NR_BITS == 8
 typedef std::complex<int8_t>  Sample;
 typedef std::complex<int32_t> Visibility;
+constexpr tcc::Format inputFormat = tcc::Format::i8;
 #elif NR_BITS == 16
 typedef std::complex<__half>  Sample;
 typedef std::complex<float>   Visibility;
+constexpr tcc::Format inputFormat = tcc::Format::fp16;
 #endif
 
 typedef Sample Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK];
@@ -51,7 +54,7 @@ int main()
     checkCudaCall(cudaSetDevice(0)); // combine the CUDA runtime API and CUDA driver API
     checkCudaCall(cudaFree(0));
 
-    tcc::Correlator correlator(cu::Device(0), NR_BITS, NR_RECEIVERS, NR_CHANNELS, NR_SAMPLES_PER_CHANNEL, NR_POLARIZATIONS, NR_RECEIVERS_PER_BLOCK);
+    tcc::Correlator correlator(cu::Device(0), inputFormat, NR_RECEIVERS, NR_CHANNELS, NR_SAMPLES_PER_CHANNEL, NR_POLARIZATIONS, NR_RECEIVERS_PER_BLOCK);
 
     cudaStream_t stream;
     Samples *samples;