diff --git a/README.md b/README.md index a1ee014fe75b6d980bdec002227e6d7fb0660a59..6ecad7da010efb0083ecae0b21bd49235ff89139 100644 --- a/README.md +++ b/README.md @@ -32,36 +32,49 @@ how to use the CUDA driver API (wrappers); and OpenCL program. `test/CorrelatorTest/CorrelatorTest.cc` is a much more versatile, robust (and complex) example than `test/SimpleExample/SimpleExample.cu`. -Input and output data types are defined as follows: +The TCC accepts the following input data types: +- half precision floating point (a.k.a. fp16), starting from Volta (sm\_70) +- e4m3 and e5m2 (a.k.a. fp8), starting from Hopper (sm\_90) +- 8-bit integers (i8), starting from the Jetson Xavier (sm\_72) +- 4-bit integers (i4), only natively supported on Ampere and Ada ``` -#if NR_BITS == 4 +#if INPUT_FORMAT == FORMAT_I4 +#define NR_TIMES_PER_BLOCK 32 typedef complex_int4_t Sample; typedef std::complex<int32_t> Visibility; -#elif NR_BITS == 8 +#elif INPUT_FORMAT == FORMAT_I8 +#define NR_TIMES_PER_BLOCK 16 typedef std::complex<int8_t> Sample; typedef std::complex<int32_t> Visibility; -#elif NR_BITS == 16 +#elif INPUT_FORMAT == FORMAT_E4M3 +#define NR_TIMES_PER_BLOCK 16 +typedef std::complex<__nv_fp8_e4m3> Sample; +typedef std::complex<float> Visibility; +#elif INPUT_FORMAT == FORMAT_E5M2 +#define NR_TIMES_PER_BLOCK 16 +typedef std::complex<__nv_fp8_e5m2> Sample; +typedef std::complex<float> Visibility; +#elif INPUT_FORMAT == FORMAT_FP16 +#define NR_TIMES_PER_BLOCK 8 typedef std::complex<__half> Sample; typedef std::complex<float> Visibility; #endif -#define NR_TIMES_PER_BLOCK (128 / NR_BITS) typedef Sample Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS]; ``` -Note that in 4-bit and 8-bit mode, the input samples may not contain -8 or -128 +Note that with FORMAT\_I4 and FORMAT\_I8, the input samples may not contain -8 or -128 respectively, as these values cannot be conjugated properly. The input data type (`Samples`) is a weird format, but this seemed to be the only format that yields good performance (tensor cores are very unforgiving). Limitations: - `NR_POLARIZATIONS` must be 2 -- `NR_BITS` must be 4, 8, or 16 -- the amount of samples over which is integrated) must be a multiple of 128 / `NR_BITS` - (i.e., 32, 16, or 8 for 4-bit, 8-bit, or 16-bit input, respectively). +- the amount of samples over which is integrated must be a multiple of +NR\_TIMES\_PER\_BLOCK. ## Building, testing, and installation Clone the repository: diff --git a/libtcc/Correlator.cc b/libtcc/Correlator.cc index b95fe813132dc0c55ddf3e09e7888eb1da0df1f6..9d189d22ceb474d82858fe2d13277daec955aecb 100644 --- a/libtcc/Correlator.cc +++ b/libtcc/Correlator.cc @@ -33,7 +33,7 @@ std::string Correlator::findNVRTCincludePath() const Correlator::Correlator(const cu::Device &device, - unsigned nrBits, + Format inputFormat, unsigned nrReceivers, unsigned nrChannels, unsigned nrSamplesPerChannel, @@ -46,8 +46,37 @@ Correlator::Correlator(const cu::Device &device, return 10 * device.getAttribute<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>() + device.getAttribute<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(); } ()), nrReceiversPerBlock(nrReceiversPerBlock != 0 ? nrReceiversPerBlock : defaultNrReceiversPerBlock(nrReceivers)), - correlatorModule(compileModule(nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock, customStoreVisibility)), - correlatorKernel(correlatorModule, nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock) + correlatorModule(compileModule(inputFormat, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock, customStoreVisibility)), + correlatorKernel(correlatorModule, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, this->nrReceiversPerBlock) +{ +} + + +Correlator::Correlator(const cu::Device &device, + unsigned nrBits, + unsigned nrReceivers, + unsigned nrChannels, + unsigned nrSamplesPerChannel, + unsigned nrPolarizations, + unsigned nrReceiversPerBlock, + const std::string &customStoreVisibility + ) +: + Correlator(device, + [&] () -> Format { + switch (nrBits) { + case 4 : return Format::i4; + case 8 : return Format::i8; + case 16 : return Format::fp16; + default : throw std::invalid_argument("nrBits should be 4, 8, or 16"); + } + } (), + nrReceivers, + nrChannels, + nrSamplesPerChannel, + nrPolarizations, + nrReceiversPerBlock, + customStoreVisibility) { } @@ -64,7 +93,7 @@ unsigned Correlator::defaultNrReceiversPerBlock(unsigned nrReceivers) const } -cu::Module Correlator::compileModule(unsigned nrBits, +cu::Module Correlator::compileModule(Format inputFormat, unsigned nrReceivers, unsigned nrChannels, unsigned nrSamplesPerChannel, @@ -79,12 +108,17 @@ cu::Module Correlator::compileModule(unsigned nrBits, "-std=c++11", "-arch=compute_" + std::to_string(capability), "-lineinfo", - "-DNR_BITS=" + std::to_string(nrBits), + "-DINPUT_FORMAT=" + std::to_string(inputFormat), "-DNR_RECEIVERS=" + std::to_string(nrReceivers), "-DNR_CHANNELS=" + std::to_string(nrChannels), "-DNR_SAMPLES_PER_CHANNEL=" + std::to_string(nrSamplesPerChannel), "-DNR_POLARIZATIONS=" + std::to_string(nrPolarizations), "-DNR_RECEIVERS_PER_BLOCK=" + std::to_string(nrReceiversPerBlock), + "-DFORMAT_FP16=" + std::to_string(fp16), + "-DFORMAT_E4M3=" + std::to_string(e4m3), + "-DFORMAT_E5M2=" + std::to_string(e5m2), + "-DFORMAT_I8=" + std::to_string(i8), + "-DFORMAT_I4=" + std::to_string(i4), }; if (!customStoreVisibility.empty()) diff --git a/libtcc/Correlator.h b/libtcc/Correlator.h index b618baf73ec7ff3e2f0a8dc381b3177d57ccea65..5f705340e8a0679098f06d01972170c73c70b124 100644 --- a/libtcc/Correlator.h +++ b/libtcc/Correlator.h @@ -9,9 +9,21 @@ #include "libtcc/CorrelatorKernel.h" namespace tcc { + enum Format { fp16, e4m3, e5m2, i8, i4 }; + class Correlator { public: Correlator(const cu::Device &, + Format inputFormat, + unsigned nrReceivers, + unsigned nrChannels, + unsigned nrSamplesPerChannel, + unsigned nrPolarizations = 2, + unsigned nrReceiversPerBlock = 0, // 0: use a heuristic value that should work well + const std::string &customStoreVisibility = "" + ); // throw (cu::Error, nvrtc::Error) + + [[deprecated]] Correlator(const cu::Device &, unsigned nrBits, unsigned nrReceivers, unsigned nrChannels, @@ -29,7 +41,7 @@ namespace tcc { private: std::string findNVRTCincludePath() const; unsigned defaultNrReceiversPerBlock(unsigned nrReceivers) const; - cu::Module compileModule(unsigned nrBits, + cu::Module compileModule(Format inputFormat, unsigned nrReceivers, unsigned nrChannels, unsigned nrSamplesPerChannel, diff --git a/libtcc/CorrelatorKernel.cc b/libtcc/CorrelatorKernel.cc index 1571f6e13e46dfcdcfe460f5b4e76667333fc6c7..4229dd449876890186299feb89f24436dc34a267 100644 --- a/libtcc/CorrelatorKernel.cc +++ b/libtcc/CorrelatorKernel.cc @@ -4,7 +4,6 @@ namespace tcc { CorrelatorKernel::CorrelatorKernel(cu::Module &module, - unsigned nrBits, unsigned nrReceivers, unsigned nrChannels, unsigned nrSamplesPerChannel, @@ -13,7 +12,6 @@ CorrelatorKernel::CorrelatorKernel(cu::Module &module, ) : Kernel(module, "correlate"), - nrBits(nrBits), nrReceivers(nrReceivers), nrChannels(nrChannels), nrSamplesPerChannel(nrSamplesPerChannel), diff --git a/libtcc/CorrelatorKernel.h b/libtcc/CorrelatorKernel.h index d9c0a5425381d83500a898733e1aa9fdd729df7e..3186921f360974c3e346185645f5be67e7e1650c 100644 --- a/libtcc/CorrelatorKernel.h +++ b/libtcc/CorrelatorKernel.h @@ -9,7 +9,6 @@ namespace tcc { { public: CorrelatorKernel(cu::Module &module, - unsigned nrBits, unsigned nrReceivers, unsigned nrChannels, unsigned nrSamplesPerChannel, @@ -22,7 +21,6 @@ namespace tcc { virtual uint64_t FLOPS() const; private: - const unsigned nrBits; const unsigned nrReceivers; const unsigned nrChannels; const unsigned nrSamplesPerChannel; diff --git a/libtcc/kernel/TCCorrelator.cu b/libtcc/kernel/TCCorrelator.cu index 8caf22e6b80354fc82ae1059ae3e5dc8da063e42..2a88364515ad9187d187e3c540e1c19f4380aa22 100644 --- a/libtcc/kernel/TCCorrelator.cu +++ b/libtcc/kernel/TCCorrelator.cu @@ -7,18 +7,41 @@ #endif #include <mma.h> +#include <cuda_fp8.h> #define NR_BASELINES (NR_RECEIVERS * (NR_RECEIVERS + 1) / 2) #define ALIGN(A,N) (((A)+(N)-1)/(N)*(N)) -#define NR_TIMES_PER_BLOCK (128 / (NR_BITS)) -#define NR_RECEIVERS_PER_TCM_X ((NR_BITS) == 4 ? 2 : 4) +#if INPUT_FORMAT == FORMAT_FP16 +#define NR_BITS 16 +#define NR_TIMES_PER_BLOCK 8 +#define NR_RECEIVERS_PER_TCM_X 4 +#define MINIMUM_ARCHITECTURE 700 +#elif INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 +#define NR_BITS 8 +#define NR_TIMES_PER_BLOCK 16 +#define NR_RECEIVERS_PER_TCM_X 2 +#define MINIMUM_ARCHITECTURE 900 +#elif INPUT_FORMAT == FORMAT_I8 +#define NR_BITS 8 +#define NR_TIMES_PER_BLOCK 16 +#define NR_RECEIVERS_PER_TCM_X 4 +#define MINIMUM_ARCHITECTURE 720 +#elif INPUT_FORMAT == FORMAT_I4 +#define NR_BITS 4 +#define NR_TIMES_PER_BLOCK 32 +#define NR_RECEIVERS_PER_TCM_X 2 +#define MINIMUM_ARCHITECTURE 750 +#else +#error Unsupported input format +#endif + #define NR_RECEIVERS_PER_TCM_Y 8 #define NR_RECEIVERS_PER_BLOCK_X (NR_RECEIVERS_PER_BLOCK == 64 ? 32 : NR_RECEIVERS_PER_BLOCK) #define COMPLEX 2 -#if __CUDA_ARCH__ < (NR_BITS == 4 ? 730 : NR_BITS == 8 ? 720 : NR_BITS == 16 ? 700 : 0) +#if __CUDA_ARCH__ < MINIMUM_ARCHITECTURE #error this architecture has no suitable tensor cores #endif @@ -52,7 +75,7 @@ inline __device__ unsigned laneid() namespace nvcuda { namespace wmma { -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 730 +#if INPUT_FORMAT == FORMAT_I4 template<> class fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> : public __frag_base<experimental::precision::s4, 32, 4> {}; template<> class fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> : public __frag_base<experimental::precision::s4, 16, 2> {}; template<> class fragment<accumulator, 16, 8, 64, int> : public __frag_base<int, 4> {}; @@ -91,21 +114,80 @@ namespace nvcuda { ((int2 *) p)[ldm / 2 * (laneid() / 4 + 8) + laneid() % 4] = make_int2(d.x[2], d.x[3]); } #endif + +#if INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 + template<typename T> class fragment<matrix_a, 16, 8, 32, T, row_major> : public __frag_base<int, 4> {}; + template<typename T> class fragment<matrix_b, 16, 8, 32, T, col_major> : public __frag_base<int, 2> {}; + template<> class fragment<accumulator, 16, 8, 32, float> : public __frag_base<float, 4> {}; + + inline __device__ void mma_sync(fragment<accumulator, 16, 8, 32, float>& d, + const fragment<matrix_a, 16, 8, 32, __nv_fp8_e4m3, row_major>& a, + const fragment<matrix_b, 16, 8, 32, __nv_fp8_e4m3, col_major>& b, + const fragment<accumulator, 16, 8, 32, float>& c) + { + asm ("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" : + "=f" (d.x[0]), "=f" (d.x[1]), "=f" (d.x[2]), "=f" (d.x[3]) : + "r" (a.x[0]), "r" (a.x[1]), "r" (a.x[2]), "r" (a.x[3]), + "r" (b.x[0]), "r" (b.x[1]), + "f" (c.x[0]), "f" (c.x[1]), "f" (c.x[2]), "f" (c.x[3]) + ); + } + + inline __device__ void mma_sync(fragment<accumulator, 16, 8, 32, float>& d, + const fragment<matrix_a, 16, 8, 32, __nv_fp8_e5m2, row_major>& a, + const fragment<matrix_b, 16, 8, 32, __nv_fp8_e5m2, col_major>& b, + const fragment<accumulator, 16, 8, 32, float>& c) + { + asm ("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" : + "=f" (d.x[0]), "=f" (d.x[1]), "=f" (d.x[2]), "=f" (d.x[3]) : + "r" (a.x[0]), "r" (a.x[1]), "r" (a.x[2]), "r" (a.x[3]), + "r" (b.x[0]), "r" (b.x[1]), + "f" (c.x[0]), "f" (c.x[1]), "f" (c.x[2]), "f" (c.x[3]) + ); + } + + template <typename T> inline __device__ void load_matrix_sync(fragment<matrix_a, 16, 8, 32, T, row_major> &a, const void *p, unsigned ldm) + { + a.x[0] = ((const int *) p)[ldm / 4 * (laneid() / 4 ) + laneid() % 4 ]; + a.x[1] = ((const int *) p)[ldm / 4 * (laneid() / 4 + 8) + laneid() % 4 ]; + a.x[2] = ((const int *) p)[ldm / 4 * (laneid() / 4 ) + laneid() % 4 + 4]; + a.x[3] = ((const int *) p)[ldm / 4 * (laneid() / 4 + 8) + laneid() % 4 + 4]; + } + + template <typename T> inline __device__ void load_matrix_sync(fragment<matrix_b, 16, 8, 32, T, col_major> &b, const void *p, unsigned ldm) + { + b.x[0] = ((const int *) p)[ldm / 4 * (laneid() / 4) + laneid() % 4 ]; + b.x[1] = ((const int *) p)[ldm / 4 * (laneid() / 4) + laneid() % 4 + 4]; + } + + inline __device__ void store_matrix_sync(float *p, const fragment<accumulator, 16, 8, 32, float>& d, unsigned ldm, layout_t layout) + { + // FIXME: only row-major supported + ((float2 *) p)[ldm / 2 * (laneid() / 4 ) + laneid() % 4] = make_float2(d.x[0], d.x[1]); + ((float2 *) p)[ldm / 2 * (laneid() / 4 + 8) + laneid() % 4] = make_float2(d.x[2], d.x[3]); + } +#endif } } using namespace nvcuda::wmma; -#if NR_BITS == 4 -typedef char Sample; -typedef int2 Visibility; -#elif NR_BITS == 8 -typedef char2 Sample; -typedef int2 Visibility; -#elif NR_BITS == 16 -typedef __half2 Sample; -typedef float2 Visibility; +#if INPUT_FORMAT == FORMAT_I4 +typedef char Sample; +typedef int2 Visibility; +#elif INPUT_FORMAT == FORMAT_I8 +typedef char2 Sample; +typedef int2 Visibility; +#elif INPUT_FORMAT == FORMAT_E4M3 +typedef __nv_fp8x2_e4m3 Sample; +typedef float2 Visibility; +#elif INPUT_FORMAT == FORMAT_E5M2 +typedef __nv_fp8x2_e5m2 Sample; +typedef float2 Visibility; +#elif INPUT_FORMAT == FORMAT_FP16 +typedef __half2 Sample; +typedef float2 Visibility; #endif @@ -123,18 +205,26 @@ typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_ #endif -#if NR_BITS == 4 +#if INPUT_FORMAT == FORMAT_I4 typedef fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> Afrag; typedef fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> Bfrag; typedef fragment<accumulator, 16, 8, 64, int> Sum; -#elif NR_BITS == 8 -typedef fragment<matrix_a, 16, 16, 16, signed char, row_major> Afrag; -typedef fragment<matrix_b, 16, 16, 16, signed char, col_major> Bfrag; -typedef fragment<accumulator, 16, 16, 16, int> Sum; -#elif NR_BITS == 16 -typedef fragment<matrix_a, 16, 16, 16, __half, row_major> Afrag; -typedef fragment<matrix_b, 16, 16, 16, __half, col_major> Bfrag; -typedef fragment<accumulator, 16, 16, 16, float> Sum; +#elif INPUT_FORMAT == FORMAT_I8 +typedef fragment<matrix_a, 16, 16, 16, signed char, row_major> Afrag; +typedef fragment<matrix_b, 16, 16, 16, signed char, col_major> Bfrag; +typedef fragment<accumulator, 16, 16, 16, int> Sum; +#elif INPUT_FORMAT == FORMAT_E4M3 +typedef fragment<matrix_a, 16, 8, 32, __nv_fp8_e4m3, row_major> Afrag; +typedef fragment<matrix_b, 16, 8, 32, __nv_fp8_e4m3, col_major> Bfrag; +typedef fragment<accumulator, 16, 8, 32, float> Sum; +#elif INPUT_FORMAT == FORMAT_E5M2 +typedef fragment<matrix_a, 16, 8, 32, __nv_fp8_e5m2, row_major> Afrag; +typedef fragment<matrix_b, 16, 8, 32, __nv_fp8_e5m2, col_major> Bfrag; +typedef fragment<accumulator, 16, 8, 32, float> Sum; +#elif INPUT_FORMAT == FORMAT_FP16 +typedef fragment<matrix_a, 16, 16, 16, __half, row_major> Afrag; +typedef fragment<matrix_b, 16, 16, 16, __half, col_major> Bfrag; +typedef fragment<accumulator, 16, 16, 16, float> Sum; #endif @@ -143,13 +233,15 @@ typedef Visibility ScratchSpace[NR_RECEIVERS_PER_TCM_Y][NR_POLARIZATIONS][NR_REC __device__ inline int conj_perm(int v) { -#if NR_BITS == 4 +#if INPUT_FORMAT == FORMAT_I4 //return ((v & 0x0F0F0F0F) << 4) | (__vnegss4(v >> 4) & 0x0F0F0F0F); return ((v & 0x0F0F0F0F) << 4) | ((0xF0F0F0F0 - ((v >> 4) & 0x0F0F0F0F)) & 0x0F0F0F0F); -#elif NR_BITS == 8 +#elif INPUT_FORMAT == FORMAT_I8 //return __byte_perm(v, __vnegss4(v), 0x2705); return __byte_perm(v, 0x00FF00FF - (v & 0xFF00FF00), 0x2705); -#elif NR_BITS == 16 +#elif INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 + return __byte_perm(v ^ 0x80008000, v, 0x2301); +#elif INPUT_FORMAT == FORMAT_FP16 return __byte_perm(v ^ 0x80000000, v, 0x1032); #endif } @@ -178,15 +270,21 @@ __device__ inline int4 conj_perm(int4 v) template <unsigned nrReceiversPerBlock = NR_RECEIVERS_PER_BLOCK> struct SharedData { -#if NR_BITS == 4 - typedef char Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][1]; - typedef char Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 16][1]; -#elif NR_BITS == 8 - typedef signed char Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; - typedef signed char Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX]; -#elif NR_BITS == 16 - typedef __half Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; - typedef __half Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 4][COMPLEX]; +#if INPUT_FORMAT == FORMAT_I4 + typedef char Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][1]; + typedef char Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 16][1]; +#elif INPUT_FORMAT == FORMAT_I8 + typedef signed char Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; + typedef signed char Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX]; +#elif INPUT_FORMAT == FORMAT_E4M3 + typedef __nv_fp8_e4m3 Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; + typedef __nv_fp8_e4m3 Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX]; +#elif INPUT_FORMAT == FORMAT_E5M2 + typedef __nv_fp8_e5m2 Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; + typedef __nv_fp8_e5m2 Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 8][COMPLEX]; +#elif INPUT_FORMAT == FORMAT_FP16 + typedef __half Asamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK][COMPLEX]; + typedef __half Bsamples[NR_SHARED_BUFFERS][nrReceiversPerBlock][NR_POLARIZATIONS][COMPLEX][NR_TIMES_PER_BLOCK + 4][COMPLEX]; #endif }; @@ -301,11 +399,11 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil printf("firstY=%u firstX=%u warp=%u y=%u x=%u _y=%u pol_y=%u _x=%u pol_x=%u val=(%f,%f)\n", firstReceiverY, firstReceiverX, warp, y, x, _y, pol_y, _x, pol_x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].y); #endif -#if NR_BITS == 4 +#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 unsigned _y = threadIdx.x >> 2; unsigned _x = (threadIdx.x >> 1) & 1; unsigned polY = threadIdx.x & 1; -#elif NR_BITS == 8 || NR_BITS == 16 +#elif INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16 unsigned _y = threadIdx.x >> 2; unsigned _x = threadIdx.x & 3; #endif @@ -315,16 +413,16 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil unsigned baseline = (recvX * (recvX + 1) / 2) + recvY; if ((skipCheckY || recvY <= recvX) && (skipCheckX || recvX < NR_RECEIVERS)) -#if NR_BITS == 4 +#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 for (unsigned polX = 0; polX < NR_POLARIZATIONS; polX ++) visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX]; -#elif NR_BITS == 8 || NR_BITS == 16 +#elif INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16 for (unsigned polY = 0; polY < NR_POLARIZATIONS; polY ++) for (unsigned polX = 0; polX < NR_POLARIZATIONS; polX ++) visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX]; #endif #else -#if __CUDA_ARCH__ == 700 || (__CUDA_ARCH__ == 720 && NR_BITS == 16) +#if __CUDA_ARCH__ == 700 || (__CUDA_ARCH__ == 720 && INPUT_FORMAT == FORMAT_FP16) unsigned polY = threadIdx.x & 1; unsigned polX = (threadIdx.x >> 1) & 1; unsigned recvY = firstReceiverY + NR_RECEIVERS_PER_TCM_Y * y + ((threadIdx.x >> 3) & 2) + (threadIdx.x & 4); @@ -334,18 +432,18 @@ template <bool add>__device__ inline void storeVisibilities(Visibilities visibil storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 1, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]); storeVisibility<add>(visibilities, channel, recvY + 1, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]); storeVisibility<add>(visibilities, channel, recvY + 1, recvX + 1, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]); -#elif (__CUDA_ARCH__ == 720 && NR_BITS == 8) || __CUDA_ARCH__ == 750 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1200 +#elif (__CUDA_ARCH__ == 720 && INPUT_FORMAT == FORMAT_I8) || __CUDA_ARCH__ == 750 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1200 unsigned polY = (threadIdx.x >> 2) & 1; unsigned polX = threadIdx.x & 1; unsigned recvY = firstReceiverY + NR_RECEIVERS_PER_TCM_Y * y + ((threadIdx.x >> 3) & 3); unsigned recvX = firstReceiverX + NR_RECEIVERS_PER_TCM_X * x + ((threadIdx.x >> 1) & 1); storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[0], sum.x[1]); -#if NR_BITS == 8 || NR_BITS == 16 +#if INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16 storeVisibility<add>(visibilities, channel, recvY + 0, recvX + 2, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]); #endif storeVisibility<add>(visibilities, channel, recvY + 4, recvX + 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]); -#if NR_BITS == 8 || NR_BITS == 16 +#if INPUT_FORMAT == FORMAT_I8 || INPUT_FORMAT == FORMAT_FP16 storeVisibility<add>(visibilities, channel, recvY + 4, recvX + 2, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]); #endif #endif @@ -430,8 +528,10 @@ template <bool add, bool fullTriangle> __device__ void doCorrelateTriangle(Visib __syncthreads(); + constexpr unsigned minorTimeIncrement = INPUT_FORMAT == FORMAT_I8 ? 8 : NR_TIMES_PER_BLOCK; + #pragma unroll - for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) { + for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += minorTimeIncrement) { Afrag aFrag; Bfrag bFrag[nrFragmentsX]; @@ -468,7 +568,7 @@ template <bool add, bool fullTriangle> __device__ void doCorrelateTriangle(Visib if (warp != 0) for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++) for (unsigned x = 0; x < nrFragmentsX; x ++, i ++) - storeVisibilities<add>(visibilities, channel, firstReceiver + recvYoffset, firstReceiver + recvXoffset, y, x, y < 2 || x > (NR_BITS == 4 ? 4 : 2), fullTriangle, sum[i], scratchSpace, warp); + storeVisibilities<add>(visibilities, channel, firstReceiver + recvYoffset, firstReceiver + recvXoffset, y, x, y < 2 || x > (INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_E4M3 || INPUT_FORMAT == FORMAT_E5M2 ? 4 : 2), fullTriangle, sum[i], scratchSpace, warp); else for (unsigned z = 0, i = 0; z < 3; z ++) for (unsigned y = 0; y < 2; y ++) @@ -583,8 +683,10 @@ template <bool add, unsigned nrFragmentsY, unsigned nrFragmentsX, bool skipLoadY __syncthreads(); + constexpr unsigned minorTimeIncrement = INPUT_FORMAT == FORMAT_I8 ? 8 : NR_TIMES_PER_BLOCK; + #pragma unroll - for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) { + for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += minorTimeIncrement) { Afrag aFrag; Bfrag bFrag[nrFragmentsX]; @@ -605,7 +707,7 @@ template <bool add, unsigned nrFragmentsY, unsigned nrFragmentsX, bool skipLoadY for (unsigned x = 0; x < nrFragmentsX; x ++) for (unsigned i = 0; i < sum[0][0].num_storage_elements; i ++) if (sum[y][x].x[i] != 0) -#if NR_BITS == 4 || NR_BITS == 8 +#if INPUT_FORMAT == FORMAT_I4 || INPUT_FORMAT == FORMAT_I8 printf("blockIdx=(%d,%d,%d) tid=%u y=%u x=%u i=%u v=%d\n", blockIdx.x, blockIdx.y, blockIdx.z, tid, y, x, i, sum[y][x].x[i]); #else printf("blockIdx=(%d,%d,%d) tid=%u y=%u x=%u i=%u v=%f\n", blockIdx.x, blockIdx.y, blockIdx.z, tid, y, x, i, sum[y][x].x[i]); diff --git a/scripts/load-modules.sh b/scripts/load-modules.sh index d5e2a8039d5cb25d5e2c4bc85c4267031e298bfc..039e8069e26e8ada54e35ef9e4b8c08f48f1f7ae 100644 --- a/scripts/load-modules.sh +++ b/scripts/load-modules.sh @@ -1,2 +1,2 @@ -module load spack/9.4.0 -module load cuda/12.2.1 +module load spack +module load cuda diff --git a/test/Benchmark/Benchmark.cc b/test/Benchmark/Benchmark.cc index a26c470a37f61503918f12450eee93fe781c674b..54b3792effe4e9961432f6aa5828be41d3b455b9 100644 --- a/test/Benchmark/Benchmark.cc +++ b/test/Benchmark/Benchmark.cc @@ -8,6 +8,8 @@ #include <cstring> #include <iostream> +#include <cuda_fp8.h> + #include <cudawrappers/nvrtc.hpp> #define GNU_SOURCE @@ -26,40 +28,50 @@ Benchmark::Benchmark() ep([&] () { context.setCurrent(); - for (unsigned nrBits: {8, 16, 4}) + using Format = tcc::Format; + + for (Format format : { Format::fp16, Format::e4m3, Format::i8, Format::i4 }) // e5m2 not tested separately, as it performs equal to e4m3 #pragma omp for collapse(2) schedule(dynamic) ordered for (unsigned nrReceivers = 1; nrReceivers <= 576; nrReceivers ++) for (unsigned nrReceiversPerBlock = 32; nrReceiversPerBlock <= 64; nrReceiversPerBlock += 16) - if (nrBits == 4 && (capability >= 73 && capability <= 89) || - nrBits == 8 && capability >= 72 || - nrBits == 16 && capability >= 70) - switch (nrBits) { - case 4 : doTest<complex_int4_t, std::complex<int32_t>>(4, nrReceiversPerBlock, nrReceivers); - break; + switch (format) { + case Format::i4 : if (capability >= 73 && capability <= 89) + doTest<complex_int4_t, std::complex<int32_t>>(format, nrReceiversPerBlock, nrReceivers); + + break; + + case Format::i8 : if (capability >= 72) + doTest<std::complex<int8_t>, std::complex<int32_t>>(format, nrReceiversPerBlock, nrReceivers); + + break; + + case Format::e4m3 : if (capability >= 90) + doTest<std::complex<__nv_fp8_e4m3>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers); + + break; - case 8 : doTest<std::complex<int8_t>, std::complex<int32_t>>(8, nrReceiversPerBlock, nrReceivers); - break; + case Format::fp16 : if (capability >= 70) + doTest<std::complex<__half>, std::complex<float>>(format, nrReceiversPerBlock, nrReceivers); - case 16 : doTest<std::complex<__half>, std::complex<float>>(16, nrReceiversPerBlock, nrReceivers); - break; - } + break; + } stream.synchronize(); }); } -template <typename SampleType, typename VisibilityType> void Benchmark::doTest(unsigned nrBits, unsigned nrReceiversPerBlock, unsigned nrReceivers) +template <typename SampleType, typename VisibilityType> void Benchmark::doTest(tcc::Format inputFormat, unsigned nrReceiversPerBlock, unsigned nrReceivers) { constexpr double measureTime = 3; // seconds unsigned nrChannels = 4 * device.getAttribute<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(); // provide enough parallelism constexpr unsigned nrPolarizations = 2; constexpr unsigned nrSamplesPerChannel = 3072; constexpr bool addVisibilities = false; - unsigned nrTimesPerBlock = 128 / nrBits; + unsigned nrTimesPerBlock = (unsigned []) {8, 16, 16, 16, 32}[inputFormat]; unsigned nrBaselines = nrReceivers * (nrReceivers + 1) / 2; - tcc::Correlator correlator(device, nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock); + tcc::Correlator correlator(device, inputFormat, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock); unsigned repeatCount; multi_array::extent<5> samplesExtent(multi_array::extents[nrChannels][nrSamplesPerChannel / nrTimesPerBlock][nrReceivers][nrPolarizations][nrTimesPerBlock]); @@ -107,7 +119,7 @@ template <typename SampleType, typename VisibilityType> void Benchmark::doTest(u stream.synchronize(); char msg[64]; - sprintf(msg, "bits=%u recv/blk=%u recv=%u cnt=%u", nrBits, nrReceiversPerBlock, nrReceivers, repeatCount); + sprintf(msg, "%s recv/blk=%u recv=%u cnt=%u", (const char * []) {"fp16", "e4m3", "e5m2", "i8", "i4"}[inputFormat], nrReceiversPerBlock, nrReceivers, repeatCount); #pragma omp ordered report(msg, computeRecordStart, computeRecordStop, repeatCount * correlator.FLOPS()); } diff --git a/test/Benchmark/Benchmark.h b/test/Benchmark/Benchmark.h index 32bb0aae4308e8a87c2d395218866132fd5cfcca..50d38f1b48da0673d4d2e67e4f9e7cfca5c397e5 100644 --- a/test/Benchmark/Benchmark.h +++ b/test/Benchmark/Benchmark.h @@ -14,7 +14,7 @@ class Benchmark : public UnitTest Benchmark(); private: - template<typename SampleType, typename VisibilityType> void doTest(unsigned nrBits, unsigned nrReceiversPerBlock, unsigned nrReceivers); + template<typename SampleType, typename VisibilityType> void doTest(tcc::Format inputFormat, unsigned nrReceiversPerBlock, unsigned nrReceivers); template<typename SampleType> void setTestPattern(const multi_array::array_ref<SampleType, 5> &samples); template<typename SampleType> static SampleType randomValue(); @@ -37,11 +37,12 @@ template<> std::complex<int8_t> Benchmark::randomValue<std::complex<int8_t>>() } -template<> std::complex<__half> Benchmark::randomValue<std::complex<__half>>() +template<typename SampleType> SampleType Benchmark::randomValue() { - return std::complex<__half>(drand48() - .5, drand48() - .5); + return SampleType((typename SampleType::value_type) (drand48() - .5), (typename SampleType::value_type) (drand48() - .5)); } + template <typename VisibilityType> bool Benchmark::approximates(const VisibilityType &a, const VisibilityType &b) const { return a == b; diff --git a/test/Common/ComplexInt4.h b/test/Common/ComplexInt4.h index 2cb9c50fb3b56a8fcb4bcf83947d166672d77ad6..b3355d70c1edcbe6607cd8c3b7da09d3694afd26 100644 --- a/test/Common/ComplexInt4.h +++ b/test/Common/ComplexInt4.h @@ -7,6 +7,8 @@ class complex_int4_t { public: + typedef int value_type; + complex_int4_t() {} complex_int4_t(int real, int imag) { value = (imag << 4) | (real & 0xF); } complex_int4_t operator = (const complex_int4_t &other) { value = other.value; return *this; } diff --git a/test/CorrelatorTest/CMakeLists.txt b/test/CorrelatorTest/CMakeLists.txt index 737418b9eae240d9aad6d8d6ffc1afc1fe38b3e7..17d9c30012b4b1062dd95183996f492e6c289fc7 100644 --- a/test/CorrelatorTest/CMakeLists.txt +++ b/test/CorrelatorTest/CMakeLists.txt @@ -18,10 +18,10 @@ install(TARGETS CorrelatorTest RUNTIME DESTINATION bin) # R: outerRepeatCount # t: nrSamplesPerChannel must be a multiple of (128 / nrBits) # V: verifyOutput -set(ARGS0 -b 4 -c 1 -n 1 -N 32 -r 1 -R 1 -t 32) -set(ARGS1 -b 8 -c 1 -n 1 -N 48 -r 1 -R 1 -t 16) -set(ARGS2 -b 16 -c 1 -n 1 -N 64 -r 1 -R 1 -t 8) -set(ARGS3 -b 16 -c 2 -n 3 -N 32 -r 4 -R 5 -t 64) +set(ARGS0 -I i4 -c 1 -n 1 -N 32 -r 1 -R 1 -t 32) +set(ARGS1 -I i8 -c 1 -n 1 -N 48 -r 1 -R 1 -t 16) +set(ARGS2 -I fp16 -c 1 -n 1 -N 64 -r 1 -R 1 -t 8) +set(ARGS3 -I fp16 -c 2 -n 3 -N 32 -r 4 -R 5 -t 64) foreach(idx RANGE 3) add_test(NAME CorrelatorTest${idx} COMMAND CorrelatorTest ${ARGS${idx}}) @@ -41,6 +41,6 @@ foreach(idx 647 653 659 661 673 677 683 691 701 709 719 727 733 739 743 751 757 761) add_test(NAME CorrelatorTest-nrReceivers-${idx} - COMMAND CorrelatorTest -b 16 -c 1 -n ${idx} -N 32 -r 1 -R 1 -t 8 + COMMAND CorrelatorTest -I fp16 -c 1 -n ${idx} -N 32 -r 1 -R 1 -t 8 ) endforeach() diff --git a/test/CorrelatorTest/CorrelatorTest.cc b/test/CorrelatorTest/CorrelatorTest.cc index 09c52dcac3b6fbbef6fd18801ed79e5586311358..302623fa106323b06c20682a18c32c064d4eff91 100644 --- a/test/CorrelatorTest/CorrelatorTest.cc +++ b/test/CorrelatorTest/CorrelatorTest.cc @@ -20,7 +20,7 @@ CorrelatorTest::CorrelatorTest(const Options &options) UnitTest(options.deviceNumber), options(options), hasIntegratedMemory(device.getAttribute(CU_DEVICE_ATTRIBUTE_INTEGRATED)), - correlator(device, options.nrBits, options.nrReceivers, options.nrChannels, options.nrSamplesPerChannel, options.nrPolarizations, options.nrReceiversPerBlock) + correlator(device, options.inputFormat, options.nrReceivers, options.nrChannels, options.nrSamplesPerChannel, options.nrPolarizations, options.nrReceiversPerBlock) { #if defined MEASURE_POWER Record start(*powerSensor), stop(*powerSensor); @@ -30,14 +30,20 @@ CorrelatorTest::CorrelatorTest(const Options &options) start.enqueue(stream); - switch (options.nrBits) { - case 4 : doTest<complex_int4_t, std::complex<int32_t>>(); + switch (options.inputFormat) { + case tcc::Format::i4 : doTest<complex_int4_t, std::complex<int32_t>>(); break; - case 8 : doTest<std::complex<int8_t>, std::complex<int32_t>>(); + case tcc::Format::i8 : doTest<std::complex<int8_t>, std::complex<int32_t>>(); break; - case 16 : doTest<std::complex<__half>, std::complex<float>>(); + case tcc::Format::e4m3 : doTest<std::complex<__nv_fp8_e4m3>, std::complex<float>>(); + break; + + case tcc::Format::e5m2 : doTest<std::complex<__nv_fp8_e5m2>, std::complex<float>>(); + break; + + case tcc::Format::fp16 : doTest<std::complex<__half>, std::complex<float>>(); break; } @@ -140,8 +146,8 @@ template<typename SampleType> void CorrelatorTest::setTestPattern(const multi_ar unsigned recv0 = options.nrReceivers > 174 ? 174 : options.nrReceivers / 3; unsigned recv1 = options.nrReceivers > 418 ? 418 : options.nrReceivers / 2; - samples[channel][time / options.nrTimesPerBlock][recv0][POL_X][time % options.nrTimesPerBlock] = SampleType(2.0, 3.0); - samples[channel][time / options.nrTimesPerBlock][recv1][POL_X][time % options.nrTimesPerBlock] = SampleType(4.0, 5.0); + samples[channel][time / options.nrTimesPerBlock][recv0][POL_X][time % options.nrTimesPerBlock] = SampleType((typename SampleType::value_type) 2.0, (typename SampleType::value_type) 3.0); + samples[channel][time / options.nrTimesPerBlock][recv1][POL_X][time % options.nrTimesPerBlock] = SampleType((typename SampleType::value_type) 4.0, (typename SampleType::value_type) 5.0); #else SampleType randomValues[7777]; // use a limited set of random numbers to save time @@ -182,7 +188,7 @@ template<typename SampleType, typename VisibilityType> void CorrelatorTest::veri SampleType sampleY = ref[recvY][polY][minor_time]; SampleType sampleX = ref[recvX][polX][minor_time]; - sum[baseline][polY][polX] += VisibilityType(sampleY.real(), sampleY.imag()) * conj(VisibilityType(sampleX.real(), sampleX.imag())); + sum[baseline][polY][polX] += VisibilityType((typename VisibilityType::value_type) sampleY.real(), (typename VisibilityType::value_type) sampleY.imag()) * conj(VisibilityType((typename VisibilityType::value_type) sampleX.real(), (typename VisibilityType::value_type) sampleX.imag())); } } diff --git a/test/CorrelatorTest/CorrelatorTest.h b/test/CorrelatorTest/CorrelatorTest.h index 5ff3189144e70797dd6c54a279dccf0eb1fb33bc..e9944bb4730306b8ce098938bc706be73007e2cb 100644 --- a/test/CorrelatorTest/CorrelatorTest.h +++ b/test/CorrelatorTest/CorrelatorTest.h @@ -7,6 +7,7 @@ #include "libtcc/Correlator.h" #include "util/multi_array.h" +#include <cuda_fp8.h> #include <cuda_fp16.h> @@ -41,11 +42,12 @@ template<> std::complex<int8_t> CorrelatorTest::randomValue<std::complex<int8_t> } -template<> std::complex<__half> CorrelatorTest::randomValue<std::complex<__half>>() +template<typename SampleType> SampleType CorrelatorTest::randomValue() { - return std::complex<__half>(drand48() - .5, drand48() - .5); + return SampleType((typename SampleType::value_type) (drand48() - .5), (typename SampleType::value_type) (drand48() - .5)); } + template <typename VisibilityType> bool CorrelatorTest::approximates(const VisibilityType &a, const VisibilityType &b) const { return a == b; diff --git a/test/CorrelatorTest/Options.cc b/test/CorrelatorTest/Options.cc index 4f49757060cbb595b750da39096b2e409c56fbdb..706e60b82e03abe3aeb231210b08971c748adb5b 100644 --- a/test/CorrelatorTest/Options.cc +++ b/test/CorrelatorTest/Options.cc @@ -8,12 +8,11 @@ Options::Options(int argc, char *argv[]) : - nrBits(8), + inputFormat(tcc::Format::i8), nrChannels(480), nrReceivers(576), nrReceiversPerBlock(64), nrSamplesPerChannel(3072), - nrTimesPerBlock(128 / nrBits), innerRepeatCount(1), outerRepeatCount(1), deviceNumber(0), verifyOutput(true), @@ -21,14 +20,11 @@ Options::Options(int argc, char *argv[]) { opterr = 0; - for (int opt; (opt = getopt(argc, argv, "a:b:c:d:hn:N:r:R:t:V:")) >= 0;) + for (int opt; (opt = getopt(argc, argv, "a:c:d:hI:n:N:r:R:t:V:")) >= 0;) switch (opt) { case 'a' : add = atoi(optarg); break; - case 'b' : nrBits = atoi(optarg); - break; - case 'c' : nrChannels = atoi(optarg); break; @@ -38,6 +34,9 @@ Options::Options(int argc, char *argv[]) case 'h' : std::cout << usage(argv[0]) << std::endl; exit(0); + case 'I' : inputFormat = toFormat(optarg); + break; + case 'n' : nrReceivers = atoi(optarg); break; @@ -59,9 +58,6 @@ Options::Options(int argc, char *argv[]) default : throw Error(usage(argv[0])); } - if (nrBits != 4 && nrBits != 8 && nrBits != 16) - throw Error("nrBits must be 4, 8, or 16"); - if (nrChannels == 0) throw Error("nrChannels must be > 0"); @@ -74,16 +70,41 @@ Options::Options(int argc, char *argv[]) if (nrSamplesPerChannel == 0) throw Error("nrSamplesPerChannel must be > 0"); - nrTimesPerBlock = 128 / nrBits; + nrTimesPerBlock = (unsigned []) {8, 16, 16, 16, 32}[inputFormat]; if (nrSamplesPerChannel % nrTimesPerBlock != 0) throw Error("nrSamplesPerChannel must be a multiple of " + std::to_string(nrTimesPerBlock)); } +tcc::Format Options::toFormat(const std::string &format) +{ + /* if (format == "fp32" || format == "float") + return tcc::Format::fp32; + else */ if (format == "fp16" || format == "half" || format == "float16") + return tcc::Format::fp16; + /* else if (format == "bfloat16") + return tcc::Format::bf16; */ + else if (format == "fp8" || format == "e4m3") + return tcc::Format::e4m3; + else if (format == "e5m2") + return tcc::Format::e5m2; + /* else if (format == "i32" || format == "int" || format == "int32_t") + return tcc::Format::i32; + else if (format == "i16" || format == "short" || format == "int16_t") + return tcc::Format::i16; */ + else if (format == "i8" || format == "char" || format == "int8_t") + return tcc::Format::i8; + else if (format == "i4" || format == "int4_t") + return tcc::Format::i4; + else + throw Error("format \"" + format + "\" not recognized"); +} + + std::string Options::usage(const std::string &execName) { - return "usage: " + execName + " [-b nrBits] [-c nrChannels] [-n nrReceivers] [-N nrReceiversPerBlock] [-r innerRepeatCount] [-R outerRepeatCount] [-t nrSamplesPerChannel] [-V verifyOutput]"; + return "usage: " + execName + " [-I i4|i8|e4m3|e5m2|fp16] [-c nrChannels] [-n nrReceivers] [-N nrReceiversPerBlock] [-r innerRepeatCount] [-R outerRepeatCount] [-t nrSamplesPerChannel] [-V verifyOutput]"; } diff --git a/test/CorrelatorTest/Options.h b/test/CorrelatorTest/Options.h index 8d34ac53c42e773171bb217ce4c91e6fc47f0a99..116a1ed4c7f92b301020a58a8227259e31747f40 100644 --- a/test/CorrelatorTest/Options.h +++ b/test/CorrelatorTest/Options.h @@ -4,6 +4,8 @@ #include <exception> #include <string> +#include "libtcc/Correlator.h" + class Options { @@ -26,7 +28,7 @@ class Options unsigned nrBaselines() const { return nrReceivers * (nrReceivers + 1) / 2; } - unsigned nrBits; + tcc::Format inputFormat; unsigned nrChannels; unsigned nrReceivers; unsigned nrReceiversPerBlock; @@ -40,6 +42,7 @@ class Options static const unsigned nrPolarizations = 2; private: + static tcc::Format toFormat(const std::string &); static std::string usage(const std::string &execName); }; diff --git a/test/SimpleExample/SimpleExample.cu b/test/SimpleExample/SimpleExample.cu index 6ad4ed52b7d9016677c9c42b6a9414b60d4a65ac..c52ccbea992b71fcd43324a46fd751eed2becf55 100644 --- a/test/SimpleExample/SimpleExample.cu +++ b/test/SimpleExample/SimpleExample.cu @@ -32,12 +32,15 @@ inline void checkCudaCall(cudaError_t error) #if NR_BITS == 4 typedef complex_int4_t Sample; typedef std::complex<int32_t> Visibility; +constexpr tcc::Format inputFormat = tcc::Format::i4; #elif NR_BITS == 8 typedef std::complex<int8_t> Sample; typedef std::complex<int32_t> Visibility; +constexpr tcc::Format inputFormat = tcc::Format::i8; #elif NR_BITS == 16 typedef std::complex<__half> Sample; typedef std::complex<float> Visibility; +constexpr tcc::Format inputFormat = tcc::Format::fp16; #endif typedef Sample Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; @@ -51,7 +54,7 @@ int main() checkCudaCall(cudaSetDevice(0)); // combine the CUDA runtime API and CUDA driver API checkCudaCall(cudaFree(0)); - tcc::Correlator correlator(cu::Device(0), NR_BITS, NR_RECEIVERS, NR_CHANNELS, NR_SAMPLES_PER_CHANNEL, NR_POLARIZATIONS, NR_RECEIVERS_PER_BLOCK); + tcc::Correlator correlator(cu::Device(0), inputFormat, NR_RECEIVERS, NR_CHANNELS, NR_SAMPLES_PER_CHANNEL, NR_POLARIZATIONS, NR_RECEIVERS_PER_BLOCK); cudaStream_t stream; Samples *samples;