diff --git a/include/schaapcommon/math/convolution.h b/include/schaapcommon/math/convolution.h
index d88add1e8afb727defb537a1e9b05331ef96f525..96d7f6336b7ed991cac3de675e58e77e2ab6c452 100644
--- a/include/schaapcommon/math/convolution.h
+++ b/include/schaapcommon/math/convolution.h
@@ -5,14 +5,45 @@
 #define SCHAAPCOMMON_FFT_CONVOLUTION_H_
 
 #include <cstring>
+#include <vector>
 
 namespace schaapcommon::math {
 
 /**
  * @brief Make the FFTW float planner thread safe.
  */
 void MakeFftwfPlannerThreadSafe();
+
+/**
+ * @brief Create and cache the 1D FFTW plans that Convolve() looks up when it
+ * is called with use_cache = true.
+ *
+ * Sizes that already have a cached plan are left untouched.
+ *
+ * @param height Sizes for which a forward and a backward complex-to-complex
+ * plan are cached (Convolve() requests these for the image height).
+ * @param width Sizes for which a real-to-complex and a complex-to-real plan
+ * are cached (Convolve() requests these for the image width).
+ */
+void PlanMultiFFTW(const std::vector<size_t>& height,
+                   const std::vector<size_t>& width);
+
+/**
+ * @brief Destroy all cached FFTW plans.
+ *
+ * Must not be called while a Convolve() call with use_cache = true is running,
+ * because that call may still be executing one of the cached plans.
+ */
+void ClearFFTWCache();
+
+/**
+ * @brief Destroy all cached FFTW plans and call fftwf_cleanup().
+ *
+ * fftwf_cleanup() leaves every existing plan undefined, so the cache is
+ * emptied first. The restriction of ClearFFTWCache() applies here too.
+ */
+void CleanWisdom();
 
 /**
  * Convolve an image with a smaller kernel. No preparation of either image is
  * needed.
@@ -54,7 +85,10 @@ void PrepareConvolutionKernel(float* dest, const float* source,
  * called at a sufficiently high level before calling this function.
+ *
+ * @param use_cache When true, the FFTW plans are taken from a process-wide
+ * cache (and added to it when missing) instead of being created for this call.
  */
 void Convolve(float* image, const float* kernel, size_t image_width,
-              size_t image_height);
+              size_t image_height, bool use_cache = false);
 
 }  // namespace schaapcommon::math
 #endif
diff --git a/src/math/convolution.cc b/src/math/convolution.cc
index 2a6e02716e46e3a95d9c41f7c25c1b3e9212cea0..3a6025a484ec1bea5a672aa65f29cb38b7d0443f 100644
--- a/src/math/convolution.cc
+++ b/src/math/convolution.cc
@@ -8,17 +8,141 @@
 #include <aocommon/uvector.h>
 #include <aocommon/staticfor.h>
 
-
+#include <map>
+#include <memory>
+#include <mutex>
+#include <shared_mutex>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
 #include <fftw3.h>
 
 #include "compositefft.h"
 
 #include <iostream>
 
+namespace {
+
+/// The kinds of 1D FFTW plan that Convolve() needs.
+enum class FftwType { kC2CForward, kC2CBackward, kC2R, kR2C };
+
+/// Deleter that lets a std::unique_ptr own an fftwf_plan.
+struct FftwPlanDeleter {
+  void operator()(fftwf_plan plan) const noexcept { fftwf_destroy_plan(plan); }
+};
+using UniqueFftwPlan = std::unique_ptr<fftwf_plan_s, FftwPlanDeleter>;
+
+/**
+ * Create a 1D FFTW_ESTIMATE plan of the given kind and size, without arrays.
+ * @returns The plan, or an empty pointer when FFTW could not create it.
+ */
+UniqueFftwPlan MakePlan(FftwType type, size_t size) {
+  switch (type) {
+    case FftwType::kR2C:
+      return UniqueFftwPlan(
+          fftwf_plan_dft_r2c_1d(size, nullptr, nullptr, FFTW_ESTIMATE));
+    case FftwType::kC2R:
+      return UniqueFftwPlan(
+          fftwf_plan_dft_c2r_1d(size, nullptr, nullptr, FFTW_ESTIMATE));
+    case FftwType::kC2CForward:
+      return UniqueFftwPlan(fftwf_plan_dft_1d(size, nullptr, nullptr,
+                                              FFTW_FORWARD, FFTW_ESTIMATE));
+    case FftwType::kC2CBackward:
+      return UniqueFftwPlan(fftwf_plan_dft_1d(size, nullptr, nullptr,
+                                              FFTW_BACKWARD, FFTW_ESTIMATE));
+  }
+  return UniqueFftwPlan();
+}
+
+/**
+ * Cache of FFTW plans, keyed on plan kind and transform size. Access to the
+ * cache is synchronized; a plan that was handed out stays valid until Clear()
+ * is called.
+ */
+class FftwPlannerCache {
+ public:
+  /// Make sure a plan of the given kind is cached for every size in @p sizes.
+  void PlanMulti(FftwType type, const std::vector<size_t>& sizes) {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    for (const size_t size : sizes) FindOrCreate(type, size);
+  }
+
+  /// Destroy all cached plans.
+  void Clear() {
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    plans_.clear();
+  }
+
+  /**
+   * Get the cached plan of the given kind and size, creating it when needed.
+   * @returns nullptr when FFTW could not create the plan.
+   */
+  fftwf_plan GetOrPlan(FftwType type, size_t size) {
+    {
+      std::shared_lock<std::shared_mutex> lock(mutex_);
+      const auto match = plans_.find(Key(type, size));
+      if (match != plans_.end()) return match->second.get();
+    }
+    // Not cached yet. Another thread may add the plan between releasing the
+    // shared lock and acquiring the exclusive one, so FindOrCreate() looks it
+    // up again before planning.
+    std::unique_lock<std::shared_mutex> lock(mutex_);
+    return FindOrCreate(type, size);
+  }
+
+ private:
+  using Key = std::pair<FftwType, size_t>;
+
+  /// Requires mutex_ to be held exclusively. Failed plans are not cached.
+  fftwf_plan FindOrCreate(FftwType type, size_t size) {
+    const Key key(type, size);
+    auto match = plans_.find(key);
+    if (match == plans_.end()) {
+      UniqueFftwPlan plan = MakePlan(type, size);
+      if (!plan) return nullptr;
+      match = plans_.emplace(key, std::move(plan)).first;
+    }
+    return match->second.get();
+  }
+
+  std::shared_mutex mutex_;
+  std::map<Key, UniqueFftwPlan> plans_;
+};
+
+/**
+ * The process-wide plan cache. It is created on first use and intentionally
+ * never deleted, so no FFTW plan is destroyed during static destruction.
+ */
+FftwPlannerCache& GetFftwCache() {
+  static FftwPlannerCache* const cache = new FftwPlannerCache();
+  return *cache;
+}
+
+}  // namespace
+
 namespace schaapcommon::math {
+
+void CleanWisdom() {
+  // fftwf_cleanup() leaves all existing plans undefined; they may not even be
+  // destroyed afterwards. Therefore, the cached plans are dropped first.
+  GetFftwCache().Clear();
+  fftwf_cleanup();
+}
 
 void MakeFftwfPlannerThreadSafe() { fftwf_make_planner_thread_safe(); }
 
+void PlanMultiFFTW(const std::vector<size_t>& height,
+                   const std::vector<size_t>& width) {
+  FftwPlannerCache& cache = GetFftwCache();
+  cache.PlanMulti(FftwType::kC2CForward, height);
+  cache.PlanMulti(FftwType::kC2CBackward, height);
+  cache.PlanMulti(FftwType::kC2R, width);
+  cache.PlanMulti(FftwType::kR2C, width);
+}
+
+void ClearFFTWCache() { GetFftwCache().Clear(); }
+
 void ResizeAndConvolve(float* image, size_t image_width, size_t image_height,
                        const float* kernel, size_t kernel_size) {
   aocommon::UVector<float> scaled_kernel(image_width * image_height, 0.0);
@@ -116,23 +240,48 @@ void PrepareConvolutionKernel(float* dest, const float* source,
 }
 
 void Convolve(float* image, const float* kernel, size_t image_width,
-              size_t image_height) {
+              size_t image_height, bool use_cache) {
   const size_t image_size = image_width * image_height;
   const size_t complex_width = image_width / 2 + 1;
   const size_t complex_size = complex_width * image_height;
 
   float* temp_data = fftwf_alloc_real(image_size);
   fftwf_complex* fft_image_data = fftwf_alloc_complex(complex_size);
   fftwf_complex* fft_kernel_data = fftwf_alloc_complex(complex_size);
 
-  fftwf_plan plan_r2c =
-      fftwf_plan_dft_r2c_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
-  fftwf_plan plan_c2c_forward = fftwf_plan_dft_1d(
-      image_height, nullptr, nullptr, FFTW_FORWARD, FFTW_ESTIMATE);
-  fftwf_plan plan_c2c_backward = fftwf_plan_dft_1d(
-      image_height, nullptr, nullptr, FFTW_BACKWARD, FFTW_ESTIMATE);
-  fftwf_plan plan_c2r =
-      fftwf_plan_dft_c2r_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
+  // Plans that are created for this call only. They are destroyed when the
+  // function returns or throws; cached plans stay owned by the cache.
+  UniqueFftwPlan owned_r2c;
+  UniqueFftwPlan owned_c2c_forward;
+  UniqueFftwPlan owned_c2c_backward;
+  UniqueFftwPlan owned_c2r;
+  fftwf_plan plan_r2c = nullptr;
+  fftwf_plan plan_c2c_forward = nullptr;
+  fftwf_plan plan_c2c_backward = nullptr;
+  fftwf_plan plan_c2r = nullptr;
+  if (use_cache) {
+    FftwPlannerCache& cache = GetFftwCache();
+    plan_r2c = cache.GetOrPlan(FftwType::kR2C, image_width);
+    plan_c2c_forward = cache.GetOrPlan(FftwType::kC2CForward, image_height);
+    plan_c2c_backward = cache.GetOrPlan(FftwType::kC2CBackward, image_height);
+    plan_c2r = cache.GetOrPlan(FftwType::kC2R, image_width);
+  } else {
+    owned_r2c = MakePlan(FftwType::kR2C, image_width);
+    owned_c2c_forward = MakePlan(FftwType::kC2CForward, image_height);
+    owned_c2c_backward = MakePlan(FftwType::kC2CBackward, image_height);
+    owned_c2r = MakePlan(FftwType::kC2R, image_width);
+    plan_r2c = owned_r2c.get();
+    plan_c2c_forward = owned_c2c_forward.get();
+    plan_c2c_backward = owned_c2c_backward.get();
+    plan_c2r = owned_c2r.get();
+  }
+  if (plan_r2c == nullptr || plan_c2c_forward == nullptr ||
+      plan_c2c_backward == nullptr || plan_c2r == nullptr) {
+    fftwf_free(fft_image_data);
+    fftwf_free(fft_kernel_data);
+    fftwf_free(temp_data);
+    throw std::runtime_error("Convolve(): FFTW could not create a plan");
+  }
 
   aocommon::StaticFor<size_t> loop;
   FftR2CComposite(plan_r2c, plan_c2c_forward, image_height, image_width, image,
@@ -159,11 +308,6 @@ void Convolve(float* image, const float* kernel, size_t image_width,
   fftwf_free(fft_image_data);
   fftwf_free(fft_kernel_data);
   fftwf_free(temp_data);
-
-  fftwf_destroy_plan(plan_r2c);
-  fftwf_destroy_plan(plan_c2c_forward);
-  fftwf_destroy_plan(plan_c2c_backward);
-  fftwf_destroy_plan(plan_c2r);
 }
 
 }  // namespace schaapcommon::math
diff --git a/src/math/test/tconvolution.cc b/src/math/test/tconvolution.cc
index 288d0ec0052ce340b00185703151ea6ae020f466..c7df9f0dee16724cc2f6c0f10c025f8d5d3afead 100644
--- a/src/math/test/tconvolution.cc
+++ b/src/math/test/tconvolution.cc
@@ -8,6 +8,10 @@
 #include <aocommon/image.h>
+#include <aocommon/recursivefor.h>
 #include <aocommon/threadpool.h>
+#include <chrono>
+#include <cmath>
 #include <iostream>
+#include <vector>
 
 namespace {
 constexpr size_t kWidth = 4;
@@ -58,6 +62,7 @@ BOOST_AUTO_TEST_CASE(prepare_small_kernel) {
 
 BOOST_AUTO_TEST_CASE(convolve) {
   aocommon::ThreadPool::GetInstance().SetNThreads(kThreadCount);
+  schaapcommon::math::MakeFftwfPlannerThreadSafe();
   const float dirac_scale = 0.5;
   aocommon::Image image(kWidth, kHeight);
   for (size_t i = 0; i != image.Size(); ++i) {
@@ -68,7 +73,6 @@ BOOST_AUTO_TEST_CASE(convolve) {
 
   kernel[kWidth * kHeight / 2 + kWidth / 2] = dirac_scale * 1.0f;
 
-  schaapcommon::math::MakeFftwfPlannerThreadSafe();
   BOOST_CHECK_THROW(
       schaapcommon::math::ResizeAndConvolve(image.Data(), kWidth, kHeight,
                                             kernel.Data(), kWidth * 2),
@@ -84,4 +88,104 @@ BOOST_AUTO_TEST_CASE(convolve) {
   }
 }
 
+BOOST_AUTO_TEST_CASE(convolve_cached_matches_uncached) {
+  aocommon::ThreadPool::GetInstance().SetNThreads(kThreadCount);
+  schaapcommon::math::MakeFftwfPlannerThreadSafe();
+  aocommon::Image cached(kWidth, kHeight);
+  aocommon::Image uncached(kWidth, kHeight);
+  aocommon::Image kernel(kWidth, kHeight);
+  for (size_t i = 0; i != kernel.Size(); ++i) {
+    cached[i] = i;
+    uncached[i] = i;
+    kernel[i] = 0.0f;
+  }
+  kernel[kWidth * kHeight / 2 + kWidth / 2] = 0.5f;
+
+  schaapcommon::math::Convolve(cached.Data(), kernel.Data(), kWidth, kHeight,
+                               true);
+  schaapcommon::math::Convolve(uncached.Data(), kernel.Data(), kWidth, kHeight,
+                               false);
+  schaapcommon::math::ClearFFTWCache();
+
+  // Both paths plan the same transforms, so the results must agree.
+  for (size_t i = 0; i != cached.Size(); ++i) {
+    BOOST_CHECK(std::abs(cached[i] - uncached[i]) < 1e-4f);
+  }
+}
+
+// Benchmark that prints the run time of Convolve() with and without the plan
+// cache. Apart from the calls completing, nothing is verified.
+BOOST_AUTO_TEST_CASE(convolve_threads) {
+  aocommon::ThreadPool::GetInstance().SetNThreads(kThreadCount);
+  schaapcommon::math::MakeFftwfPlannerThreadSafe();
+  constexpr size_t kMaxWidth = 2096;
+  constexpr size_t kMaxHeight = 1096;
+  constexpr size_t kNImages = 1000;
+  constexpr float kDiracScale = 0.5f;
+
+  // Preparation. Every 10 consecutive images share one size.
+  std::vector<aocommon::Image> images(kNImages * kThreadCount);
+  std::vector<aocommon::Image> kernels(images.size());
+  std::vector<size_t> height(images.size());
+  std::vector<size_t> width(images.size());
+  for (size_t i = 0; i != images.size(); ++i) {
+    const size_t divisor = i / 10 + 1;
+    images[i] = aocommon::Image(kMaxWidth / divisor, kMaxHeight / divisor);
+    kernels[i] = aocommon::Image(kMaxWidth / divisor, kMaxHeight / divisor);
+    for (size_t j = 0; j != images[i].Size(); ++j) {
+      images[i][j] = j;
+    }
+    // The kernel needs its centre set only once, not once per image pixel.
+    if (kernels[i].Size() != 0) {
+      kernels[i][kernels[i].Size() / 2 + kernels[i].Width() / 2] = kDiracScale;
+    }
+    height[i] = images[i].Height();
+    width[i] = images[i].Width();
+  }
+  // Preparation end
+
+  // Runs 'work' and prints the elapsed time in seconds after 'label'.
+  const auto report_elapsed = [](const char* label, const auto& work) {
+    const auto start = std::chrono::steady_clock::now();
+    work();
+    const std::chrono::duration<double> elapsed =
+        std::chrono::steady_clock::now() - start;
+    std::cout << label << elapsed.count() << " s\n";
+  };
+  // Convolves every image with its kernel through the RecursiveFor loop.
+  aocommon::RecursiveFor loop;
+  const auto convolve_all = [&](bool use_cache) {
+    loop.Run(0, images.size(), [&](size_t k, size_t) {
+      schaapcommon::math::Convolve(images[k].Data(), kernels[k].Data(),
+                                   width[k], height[k], use_cache);
+    });
+  };
+
+  // Test with cache but not precaching
+  report_elapsed("Elapsed time without precaching: ",
+                 [&] { convolve_all(true); });
+
+  schaapcommon::math::ClearFFTWCache();
+  schaapcommon::math::CleanWisdom();
+  report_elapsed("Planning time is ", [&] {
+    schaapcommon::math::PlanMultiFFTW(height, width);
+  });
+
+  // Test with cache with precaching. FFTW is not cleaned up in between,
+  // because fftwf_cleanup() would invalidate the plans that were just cached.
+  report_elapsed("Elapsed time with cache: ", [&] { convolve_all(true); });
+
+  // Test without cache
+  schaapcommon::math::CleanWisdom();
+  report_elapsed("Elapsed time without cache: ", [&] { convolve_all(false); });
+
+  // Test without cache and without the outer RecursiveFor loop
+  report_elapsed("Elapsed time without nestedfor: ", [&] {
+    for (size_t k = 0; k != images.size(); ++k) {
+      schaapcommon::math::Convolve(images[k].Data(), kernels[k].Data(),
+                                   width[k], height[k], false);
+    }
+  });
+}
+
 BOOST_AUTO_TEST_SUITE_END()