Skip to content
Snippets Groups Projects
Commit 049a43a3 authored by John Romein's avatar John Romein
Browse files

Use m16n8k32 tensor ops with inline ptx in 4-bit mode, yielding 45% better performance.

parent dc033d5a
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#define NR_TIMES_PER_BLOCK (128 / (NR_BITS)) #define NR_TIMES_PER_BLOCK (128 / (NR_BITS))
#define NR_RECEIVERS_PER_TCM_X ((NR_BITS) == 4 ? 2 : 4) #define NR_RECEIVERS_PER_TCM_X ((NR_BITS) == 4 ? 2 : 4)
#define NR_RECEIVERS_PER_TCM_Y ((NR_BITS) == 4 ? 4 : 8) #define NR_RECEIVERS_PER_TCM_Y 8
#define COMPLEX 2 #define COMPLEX 2
...@@ -36,6 +36,64 @@ ...@@ -36,6 +36,64 @@
#define MIN(A,B) ((A)<(B)?(A):(B)) #define MIN(A,B) ((A)<(B)?(A):(B))
inline __device__ unsigned laneid()
{
#if 0
unsigned laneid;
asm ("mov.u32 %0, %%laneid;" : "=r" (laneid));
return laneid;
#else
return threadIdx.x;
#endif
}
namespace nvcuda {
namespace wmma {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 730
template<> class fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> : public __frag_base<experimental::precision::s4, 32, 4> {};
template<> class fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> : public __frag_base<experimental::precision::s4, 16, 2> {};
template<> class fragment<accumulator, 16, 8, 64, int> : public __frag_base<int, 4> {};
inline __device__ void mma_sync(fragment<accumulator, 16, 8, 64, int>& d,
const fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major>& a,
const fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major>& b,
const fragment<accumulator, 16, 8, 64, int>& c)
{
asm ("mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};" :
"=r" (d.x[0]), "=r" (d.x[1]), "=r" (d.x[2]), "=r" (d.x[3]) :
"r" (a.x[0]), "r" (a.x[1]), "r" (a.x[2]), "r" (a.x[3]),
"r" (b.x[0]), "r" (b.x[1]),
"r" (c.x[0]), "r" (c.x[1]), "r" (c.x[2]), "r" (c.x[3])
);
}
inline __device__ void load_matrix_sync(fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> &a, const void *p, unsigned ldm)
{
a.x[0] = ((const int *) p)[ldm / 8 * (laneid() / 4 ) + laneid() % 4 ];
a.x[1] = ((const int *) p)[ldm / 8 * (laneid() / 4 + 8) + laneid() % 4 ];
a.x[2] = ((const int *) p)[ldm / 8 * (laneid() / 4 ) + laneid() % 4 + 4];
a.x[3] = ((const int *) p)[ldm / 8 * (laneid() / 4 + 8) + laneid() % 4 + 4];
}
inline __device__ void load_matrix_sync(fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> &b, const void *p, unsigned ldm)
{
b.x[0] = ((const int *) p)[ldm / 8 * (laneid() / 4) + laneid() % 4 ];
b.x[1] = ((const int *) p)[ldm / 8 * (laneid() / 4) + laneid() % 4 + 4];
}
inline __device__ void store_matrix_sync(int *p, const fragment<accumulator, 16, 8, 64, int>& d, unsigned ldm, layout_t layout)
{
// FIXME: only row-major supported
((int2 *) p)[ldm / 2 * (laneid() / 4 ) + laneid() % 4] = make_int2(d.x[0], d.x[1]);
((int2 *) p)[ldm / 2 * (laneid() / 4 + 8) + laneid() % 4] = make_int2(d.x[2], d.x[3]);
}
#endif
}
}
using namespace nvcuda::wmma; using namespace nvcuda::wmma;
#if NR_BITS == 4 #if NR_BITS == 4
...@@ -58,9 +116,9 @@ typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_ ...@@ -58,9 +116,9 @@ typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_
#if NR_BITS == 4 #if NR_BITS == 4
typedef fragment<matrix_a, 8, 8, 32, experimental::precision::s4, row_major> Afrag; typedef fragment<matrix_a, 16, 8, 64, experimental::precision::s4, row_major> Afrag;
typedef fragment<matrix_b, 8, 8, 32, experimental::precision::s4, col_major> Bfrag; typedef fragment<matrix_b, 16, 8, 64, experimental::precision::s4, col_major> Bfrag;
typedef fragment<accumulator, 8, 8, 32, int> Sum; typedef fragment<accumulator, 16, 8, 64, int> Sum;
#elif NR_BITS == 8 #elif NR_BITS == 8
typedef fragment<matrix_a, 16, 16, 16, signed char, row_major> Afrag; typedef fragment<matrix_a, 16, 16, 16, signed char, row_major> Afrag;
typedef fragment<matrix_b, 16, 16, 16, signed char, col_major> Bfrag; typedef fragment<matrix_b, 16, 16, 16, signed char, col_major> Bfrag;
...@@ -72,13 +130,7 @@ typedef fragment<accumulator, 16, 16, 16, float> Sum ...@@ -72,13 +130,7 @@ typedef fragment<accumulator, 16, 16, 16, float> Sum
#endif #endif
#if NR_BITS == 4 typedef Visibility ScratchSpace[NR_RECEIVERS_PER_TCM_Y][NR_POLARIZATIONS][NR_RECEIVERS_PER_TCM_X][NR_POLARIZATIONS];
typedef int2 ScratchSpace[4][NR_POLARIZATIONS][2][NR_POLARIZATIONS];
#elif NR_BITS == 8
typedef int2 ScratchSpace[8][NR_POLARIZATIONS][4][NR_POLARIZATIONS];
#elif NR_BITS == 16
typedef float2 ScratchSpace[8][NR_POLARIZATIONS][4][NR_POLARIZATIONS];
#endif
__device__ inline int conj_perm(int v) __device__ inline int conj_perm(int v)
...@@ -223,24 +275,23 @@ template <typename T> __device__ inline void storeVisibility(Visibilities visibi ...@@ -223,24 +275,23 @@ template <typename T> __device__ inline void storeVisibility(Visibilities visibi
__device__ inline void storeVisibilities(Visibilities visibilities, unsigned channel, unsigned firstReceiverY, unsigned firstReceiverX, unsigned recvYoffset, unsigned recvXoffset, unsigned y, unsigned x, bool skipCheckY, bool skipCheckX, const Sum &sum, ScratchSpace scratchSpace[], unsigned warp) __device__ inline void storeVisibilities(Visibilities visibilities, unsigned channel, unsigned firstReceiverY, unsigned firstReceiverX, unsigned recvYoffset, unsigned recvXoffset, unsigned y, unsigned x, bool skipCheckY, bool skipCheckX, const Sum &sum, ScratchSpace scratchSpace[], unsigned warp)
{ {
#if defined PORTABLE #if defined PORTABLE
store_matrix_sync(&scratchSpace[warp][0][0][0][0].x, sum, NR_BITS == 4 ? 8 : 16, mem_row_major); store_matrix_sync(&scratchSpace[warp][0][0][0][0].x, sum, NR_RECEIVERS_PER_TCM_X * NR_POLARIZATIONS * COMPLEX, mem_row_major);
__syncwarp(); __syncwarp();
#if 0 #if 0
if (threadIdx.x == 0) if (threadIdx.x == 0)
for (unsigned _y = 0; _y < 8; _y ++) for (unsigned _y = 0; _y < NR_RECEIVERS_PER_TCM_Y; _y ++)
for (unsigned pol_y = 0; pol_y < NR_POLARIZATIONS; pol_y ++) for (unsigned pol_y = 0; pol_y < NR_POLARIZATIONS; pol_y ++)
for (unsigned _x = 0; _x < 4; _x ++) for (unsigned _x = 0; _x < NR_RECEIVERS_PER_TCM_X; _x ++)
for (unsigned pol_x = 0; pol_x < NR_POLARIZATIONS; pol_x ++) for (unsigned pol_x = 0; pol_x < NR_POLARIZATIONS; pol_x ++)
if (scratchSpace[warp][_y][pol_y][_x][pol_x],x != 0 || scratchSpace[warp][_y][pol_y][_x][pol_x].y != 0) if (scratchSpace[warp][_y][pol_y][_x][pol_x].x != 0 || scratchSpace[warp][_y][pol_y][_x][pol_x].y != 0)
printf("firstY=%u firstX=%u warp=%u y=%u x=%u _y=%u pol_y=%u _x=%u pol_x=%u val=(%f,%f)\n", firstReceiverY, firstReceiverX, warp, y, x, _y, pol_y, _x, pol_x, scratchSpace[warp][_y][pol_y][_x][pol_x].x, scratchSpace[warp][_y][pol_y][_x][pol_x].y); printf("firstY=%u firstX=%u warp=%u y=%u x=%u _y=%u pol_y=%u _x=%u pol_x=%u val=(%f,%f)\n", firstReceiverY, firstReceiverX, warp, y, x, _y, pol_y, _x, pol_x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].x, (float) scratchSpace[warp][_y][pol_y][_x][pol_x].y);
#endif #endif
#if NR_BITS == 4 #if NR_BITS == 4
unsigned _y = threadIdx.x >> 3; unsigned _y = threadIdx.x >> 2;
unsigned _x = (threadIdx.x >> 2) & 1; unsigned _x = (threadIdx.x >> 1) & 1;
unsigned polY = (threadIdx.x >> 1) & 1; unsigned polY = threadIdx.x & 1;
unsigned polX = threadIdx.x & 1;
#elif NR_BITS == 8 || NR_BITS == 16 #elif NR_BITS == 8 || NR_BITS == 16
unsigned _y = threadIdx.x >> 2; unsigned _y = threadIdx.x >> 2;
unsigned _x = threadIdx.x & 3; unsigned _x = threadIdx.x & 3;
...@@ -252,6 +303,7 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha ...@@ -252,6 +303,7 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha
if ((skipCheckX || recvX <= recvY) && (skipCheckY || recvY < NR_RECEIVERS)) if ((skipCheckX || recvX <= recvY) && (skipCheckY || recvY < NR_RECEIVERS))
#if NR_BITS == 4 #if NR_BITS == 4
for (unsigned polX = 0; polX < NR_POLARIZATIONS; polX ++)
visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX]; visibilities[channel][baseline][polY][polX] = scratchSpace[warp][_y][polY][_x][polX];
#elif NR_BITS == 8 || NR_BITS == 16 #elif NR_BITS == 8 || NR_BITS == 16
for (unsigned polY = 0; polY < NR_POLARIZATIONS; polY ++) for (unsigned polY = 0; polY < NR_POLARIZATIONS; polY ++)
...@@ -282,7 +334,9 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha ...@@ -282,7 +334,9 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha
storeVisibility(visibilities, channel, baseline, recvY, recvX, 0, 0, polY, polX, skipCheckY, skipCheckX, sum.x[0], sum.x[1]); storeVisibility(visibilities, channel, baseline, recvY, recvX, 0, 0, polY, polX, skipCheckY, skipCheckX, sum.x[0], sum.x[1]);
#if NR_BITS == 8 || NR_BITS == 16 #if NR_BITS == 8 || NR_BITS == 16
storeVisibility(visibilities, channel, baseline, recvY, recvX, 0, 2, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]); storeVisibility(visibilities, channel, baseline, recvY, recvX, 0, 2, polY, polX, skipCheckY, skipCheckX, sum.x[4], sum.x[5]);
#endif
storeVisibility(visibilities, channel, baseline, recvY, recvX, 4, 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]); storeVisibility(visibilities, channel, baseline, recvY, recvX, 4, 0, polY, polX, skipCheckY, skipCheckX, sum.x[2], sum.x[3]);
#if NR_BITS == 8 || NR_BITS == 16
storeVisibility(visibilities, channel, baseline, recvY, recvX, 4, 2, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]); storeVisibility(visibilities, channel, baseline, recvY, recvX, 4, 2, polY, polX, skipCheckY, skipCheckX, sum.x[6], sum.x[7]);
#endif #endif
#endif #endif
...@@ -296,8 +350,8 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha ...@@ -296,8 +350,8 @@ __device__ inline void storeVisibilities(Visibilities visibilities, unsigned cha
template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities visibilities, const Samples samples, unsigned firstReceiver, unsigned warp, unsigned tid, SharedData<>::Bsamples &bSamples, ScratchSpace scratchSpace[NR_WARPS]) template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities visibilities, const Samples samples, unsigned firstReceiver, unsigned warp, unsigned tid, SharedData<>::Bsamples &bSamples, ScratchSpace scratchSpace[NR_WARPS])
{ {
const unsigned nrFragmentsX = NR_BITS == 4 ? 12 : 6; const unsigned nrFragmentsX = 24 / NR_RECEIVERS_PER_TCM_X;
const unsigned nrFragmentsY = nrFragmentsX / 2; const unsigned nrFragmentsY = 24 / NR_RECEIVERS_PER_TCM_Y;
Sum sum[nrFragmentsX * nrFragmentsY]; Sum sum[nrFragmentsX * nrFragmentsY];
for (auto &s : sum) for (auto &s : sum)
...@@ -368,31 +422,31 @@ template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities vi ...@@ -368,31 +422,31 @@ template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities vi
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 16 : 8)) { for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) {
Afrag aFrag[nrFragmentsY]; Afrag aFrag;
Bfrag bFrag[nrFragmentsX]; Bfrag bFrag[nrFragmentsX];
if (warp != 0) { if (warp != 0) {
for (unsigned y = 0; y < nrFragmentsY; y ++)
load_matrix_sync(aFrag[y], &bSamples[buffer][recvYoffset + NR_RECEIVERS_PER_TCM_Y * y][0][0][minorTime][0], sizeof(bSamples[0][0][0]) * 8 / NR_BITS);
for (unsigned x = 0; x < nrFragmentsX; x ++) for (unsigned x = 0; x < nrFragmentsX; x ++)
load_matrix_sync(bFrag[x], &bSamples[buffer][recvXoffset + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS); load_matrix_sync(bFrag[x], &bSamples[buffer][recvXoffset + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS);
for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++) for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++) {
load_matrix_sync(aFrag, &bSamples[buffer][recvYoffset + NR_RECEIVERS_PER_TCM_Y * y][0][0][minorTime][0], sizeof(bSamples[0][0][0]) * 8 / NR_BITS);
for (unsigned x = 0; x < nrFragmentsX; x ++, i ++) for (unsigned x = 0; x < nrFragmentsX; x ++, i ++)
mma_sync(sum[i], aFrag[y], bFrag[x], sum[i]); mma_sync(sum[i], aFrag, bFrag[x], sum[i]);
}
} else { } else {
for (unsigned z = 0, i = 0; z < 3; z ++) { for (unsigned z = 0, i = 0; z < 3; z ++) {
for (unsigned y = 0; y < (NR_BITS == 4 ? 4 : 2); y ++) for (unsigned x = 0; x < nrFragmentsX; x ++)
load_matrix_sync(aFrag[y], &bSamples[buffer][/*recvYoffset*/ 24 * z + NR_RECEIVERS_PER_TCM_Y * y][0][0][minorTime][0], sizeof(bSamples[0][0][0]) * 8 / NR_BITS);
for (unsigned x = 0; x < (NR_BITS == 4 ? 8 : 4); x ++)
load_matrix_sync(bFrag[x], &bSamples[buffer][/*recvXoffset*/ 24 * z + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS); load_matrix_sync(bFrag[x], &bSamples[buffer][/*recvXoffset*/ 24 * z + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS);
for (unsigned y = 0; y < (NR_BITS == 4 ? 4 : 2); y ++) for (unsigned y = 0; y < 2; y ++) {
for (unsigned x = 0; x < 2 + 2 * y; x ++, i ++) load_matrix_sync(aFrag, &bSamples[buffer][/*recvYoffset*/ 24 * z + NR_RECEIVERS_PER_TCM_Y * y][0][0][minorTime][0], sizeof(bSamples[0][0][0]) * 8 / NR_BITS);
mma_sync(sum[i], aFrag[y], bFrag[x], sum[i]);
for (unsigned x = 0; x < (NR_BITS == 4 ? 4 : 2) * (y + 1); x ++, i ++)
mma_sync(sum[i], aFrag, bFrag[x], sum[i]);
}
} }
} }
} }
...@@ -405,12 +459,12 @@ template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities vi ...@@ -405,12 +459,12 @@ template <bool fullTriangle> __device__ void doCorrelateTriangle(Visibilities vi
if (warp != 0) if (warp != 0)
for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++) for (unsigned y = 0, i = 0; y < nrFragmentsY; y ++)
for (unsigned x = 0; x < nrFragmentsX; x ++, i ++) for (unsigned x = 0; x < nrFragmentsX; x ++, i ++)
storeVisibilities(visibilities, channel, firstReceiver, firstReceiver, recvYoffset, recvXoffset, y, x, fullTriangle, x < 2 * y + (NR_BITS == 4 ? 8 : 4), sum[i], scratchSpace, warp); storeVisibilities(visibilities, channel, firstReceiver, firstReceiver, recvYoffset, recvXoffset, y, x, fullTriangle, y > 0 || x < (NR_BITS == 4 ? 8 : 4), sum[i], scratchSpace, warp);
else else
for (unsigned z = 0, i = 0; z < 3; z ++) for (unsigned z = 0, i = 0; z < 3; z ++)
for (unsigned y = 0; y < (NR_BITS == 4 ? 4 : 2); y ++) for (unsigned y = 0; y < 2; y ++)
for (unsigned x = 0; x < 2 * y + 2; x ++, i ++) for (unsigned x = 0; x < (NR_BITS == 4 ? 4 : 2) * (y + 1); x ++, i ++)
storeVisibilities(visibilities, channel, firstReceiver, firstReceiver, 24 * z, 24 * z, y, x, fullTriangle, x < 2 * y, sum[i], scratchSpace, warp); storeVisibilities(visibilities, channel, firstReceiver, firstReceiver, 24 * z, 24 * z, y, x, fullTriangle, x < (NR_BITS == 4 ? 4 : 2) * y, sum[i], scratchSpace, warp);
} }
#endif #endif
...@@ -523,19 +577,19 @@ template <unsigned nrFragmentsY, bool skipLoadYcheck, bool skipLoadXcheck, bool ...@@ -523,19 +577,19 @@ template <unsigned nrFragmentsY, bool skipLoadYcheck, bool skipLoadXcheck, bool
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 16 : 8)) { for (unsigned minorTime = 0; minorTime < NR_TIMES_PER_BLOCK; minorTime += ((NR_BITS) == 4 ? 32 : 8)) {
Afrag aFrag[nrFragmentsY]; Afrag aFrag;
Bfrag bFrag[nrFragmentsX]; Bfrag bFrag[nrFragmentsX];
for (unsigned y = 0; y < nrFragmentsY; y ++)
load_matrix_sync(aFrag[y], &aSamples[buffer][recvYoffset + NR_RECEIVERS_PER_TCM_Y * y][0][minorTime][0], sizeof(aSamples[0][0][0]) * 8 / NR_BITS);
for (unsigned x = 0; x < nrFragmentsX; x ++) for (unsigned x = 0; x < nrFragmentsX; x ++)
load_matrix_sync(bFrag[x], &bSamples[buffer][recvXoffset + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS); load_matrix_sync(bFrag[x], &bSamples[buffer][recvXoffset + NR_RECEIVERS_PER_TCM_X * x][0][0][minorTime][0], sizeof(bSamples[0][0][0][0]) * 8 / NR_BITS);
for (unsigned y = 0; y < nrFragmentsY; y ++) for (unsigned y = 0; y < nrFragmentsY; y ++) {
load_matrix_sync(aFrag, &aSamples[buffer][recvYoffset + NR_RECEIVERS_PER_TCM_Y * y][0][minorTime][0], sizeof(aSamples[0][0][0]) * 8 / NR_BITS);
for (unsigned x = 0; x < nrFragmentsX; x ++) for (unsigned x = 0; x < nrFragmentsX; x ++)
mma_sync(sum[y][x], aFrag[y], bFrag[x], sum[y][x]); mma_sync(sum[y][x], aFrag, bFrag[x], sum[y][x]);
}
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment