diff --git a/idg-lib/src/CUDA/common/InstanceCUDA.cpp b/idg-lib/src/CUDA/common/InstanceCUDA.cpp index 7878a75b9cc884c4abc1c946411abfdf4a9c00d9..0ff9bc0d17df0b382b6dd4d1dbc9b073571f75c2 100644 --- a/idg-lib/src/CUDA/common/InstanceCUDA.cpp +++ b/idg-lib/src/CUDA/common/InstanceCUDA.cpp @@ -24,7 +24,7 @@ using namespace idg::kernel; /* * Use custom FFT kernel */ -#define USE_CUSTOM_FFT 0 +#define USE_CUSTOM_FFT 1 namespace idg::kernel::cuda { @@ -324,9 +324,9 @@ void InstanceCUDA::load_kernels() { // Load FFT function #if USE_CUSTOM_FFT - if (cuModuleGetFunction(&function, *mModules[8], kNameFft.c_str()) == + if (cuModuleGetFunction(&function, *modules_[8], kNameFft.c_str()) == CUDA_SUCCESS) { - function_fft__.reset(new cu::Function(function)); + function_fft_.reset(new cu::Function(*context_, function)); found++; } #endif @@ -782,7 +782,7 @@ void InstanceCUDA::launch_grid_fft(cu::DeviceMemory& d_data, int batch, void InstanceCUDA::plan_subgrid_fft(size_t size, size_t nr_polarizations) { #if USE_CUSTOM_FFT if (size == 32) { - m_fft_subgrid_size = size; + fft_subgrid_size_ = size; return; } #endif @@ -844,11 +844,11 @@ void InstanceCUDA::launch_subgrid_fft(cu::DeviceMemory& d_data, (direction == FourierDomainToImageDomain) ? CUFFT_INVERSE : CUFFT_FORWARD; #if USE_CUSTOM_FFT - if (fft_subgrid_size == 32) { + if (fft_subgrid_size_ == 32) { const void* parameters[] = {&data_ptr, &data_ptr, &sign}; dim3 block(128); - dim3 grid(NR_CORRELATIONS * nr_subgrids); - executestream->launchKernel(*function_fft__, grid, block, 0, parameters); + dim3 grid(nr_polarizations * nr_subgrids); + stream_execute_->launchKernel(*function_fft_, grid, block, 0, parameters); return; } #endif diff --git a/idg-lib/src/CUDA/common/InstanceCUDA.h b/idg-lib/src/CUDA/common/InstanceCUDA.h index a9d2418110e1fbcb3b1dfcdd43666472525e856d..c4d1d252ed69a0b02691e3e4c5719926208ffebe 100644 --- a/idg-lib/src/CUDA/common/InstanceCUDA.h +++ b/idg-lib/src/CUDA/common/InstanceCUDA.h @@ -224,6 +224,7 @@ class InstanceCUDA : public KernelsInstance { std::unique_ptr<cu::Profiler> profiler_; std::unique_ptr<cu::Function> function_gridder_; std::unique_ptr<cu::Function> function_degridder_; + std::unique_ptr<cu::Function> function_fft_; std::unique_ptr<cu::Function> function_adder_; std::unique_ptr<cu::Function> function_splitter_; std::unique_ptr<cu::Function> function_scaler_;