diff --git a/libtcc/Correlator.cc b/libtcc/Correlator.cc index a4e5b7d68207908ea4af8f483029c0e81fb9b733..856ea1069645c0990e79bd67f2a21c51151b39d1 100644 --- a/libtcc/Correlator.cc +++ b/libtcc/Correlator.cc @@ -35,10 +35,11 @@ Correlator::Correlator(unsigned nrBits, unsigned nrChannels, unsigned nrSamplesPerChannel, unsigned nrPolarizations, - unsigned nrReceiversPerBlock + unsigned nrReceiversPerBlock, + const std::string &customStoreVisibility ) : - correlatorModule(compileModule(nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock)), + correlatorModule(compileModule(nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock, customStoreVisibility)), correlatorKernel(correlatorModule, nrBits, nrReceivers, nrChannels, nrSamplesPerChannel, nrPolarizations, nrReceiversPerBlock) { } @@ -49,7 +50,8 @@ cu::Module Correlator::compileModule(unsigned nrBits, unsigned nrChannels, unsigned nrSamplesPerChannel, unsigned nrPolarizations, - unsigned nrReceiversPerBlock + unsigned nrReceiversPerBlock, + const std::string &customStoreVisibility ) { cu::Device device(cu::Context::getCurrent().getDevice()); @@ -69,6 +71,9 @@ cu::Module Correlator::compileModule(unsigned nrBits, "-DNR_RECEIVERS_PER_BLOCK=" + std::to_string(nrReceiversPerBlock), }; + if (!customStoreVisibility.empty()) + options.push_back("-DCUSTOM_STORE_VISIBILITY=" + customStoreVisibility); + //std::for_each(options.begin(), options.end(), [] (const std::string &e) { std::cout << e << ' '; }); std::cout << std::endl; #if 0 diff --git a/libtcc/Correlator.h b/libtcc/Correlator.h index 7b33fa2f3fdffda90843ed22efdb1ab7c8b3638c..5b9cac6744a1d4f194634e967f30dbf67fe10aaa 100644 --- a/libtcc/Correlator.h +++ b/libtcc/Correlator.h @@ -5,6 +5,8 @@ #include "util/cu.h" #include "util/nvrtc.h" +#include <string> + namespace tcc { class Correlator { @@ -14,7 +16,8 @@ namespace tcc { unsigned nrChannels, unsigned nrSamplesPerChannel, unsigned nrPolarizations = 2, - unsigned nrReceiversPerBlock = 64 + unsigned nrReceiversPerBlock = 64, + const std::string &customStoreVisibility = "" ); // throw (cu::Error, nvrtc::Error) void launchAsync(cu::Stream &, cu::DeviceMemory &visibilities, cu::DeviceMemory &samples); // throw (cu::Error) @@ -29,7 +32,8 @@ namespace tcc { unsigned nrChannels, unsigned nrSamplesPerChannel, unsigned nrPolarizations, - unsigned nrReceiversPerBlock + unsigned nrReceiversPerBlock, + const std::string &customStoreVisibility ); cu::Module correlatorModule; diff --git a/libtcc/TCCorrelator.cu b/libtcc/TCCorrelator.cu index 0836edcbeb3e997616f1b1dc35414dea79a86276..928c32055e8900559eda3ff8c0551860cd066003 100644 --- a/libtcc/TCCorrelator.cu +++ b/libtcc/TCCorrelator.cu @@ -39,14 +39,21 @@ using namespace nvcuda::wmma; #if NR_BITS == 4 -typedef char Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; -typedef int2 Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS]; +typedef char Sample; +typedef int2 Visibility; #elif NR_BITS == 8 -typedef char2 Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; -typedef int2 Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS]; +typedef char2 Sample; +typedef int2 Visibility; #elif NR_BITS == 16 -typedef __half2 Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; -typedef float2 Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS]; +typedef __half2 Sample; +typedef float2 Visibility; +#endif + + +typedef Sample Samples[NR_CHANNELS][NR_SAMPLES_PER_CHANNEL / NR_TIMES_PER_BLOCK][NR_RECEIVERS][NR_POLARIZATIONS][NR_TIMES_PER_BLOCK]; + +#if !defined CUSTOM_STORE_VISIBILITY +typedef Visibility Visibilities[NR_CHANNELS][NR_BASELINES][NR_POLARIZATIONS][NR_POLARIZATIONS]; #endif @@ -194,10 +201,22 @@ __device__ inline float2 make_complex(float real, float imag) } +#if defined CUSTOM_STORE_VISIBILITY +CUSTOM_STORE_VISIBILITY +#else + +template <typename T> __device__ inline void storeVisibility(Visibilities visibilities, unsigned channel, unsigned baseline, unsigned polY, unsigned polX, T visibility) +{ + visibilities[channel][baseline][polY][polX] = visibility; +} + +#endif + + template <typename T> __device__ inline void storeVisibility(Visibilities visibilities, unsigned channel, unsigned baseline, unsigned recvY, unsigned recvX, unsigned tcY, unsigned tcX, unsigned polY, unsigned polX, bool skipCheckY, bool skipCheckX, T sumR, T sumI) { if ((skipCheckX || recvX + tcX <= recvY + tcY) && (skipCheckY || recvY + tcY < NR_RECEIVERS)) - visibilities[channel][baseline + tcY * recvY + tcY * (tcY + 1) / 2 + tcX][polY][polX] = make_complex(sumR, sumI); + storeVisibility(visibilities, channel, baseline + tcY * recvY + tcY * (tcY + 1) / 2 + tcX, polY, polX, make_complex(sumR, sumI)); }