Skip to content
Snippets Groups Projects

Draft: Add working cache code for fftw

Open Mattia Mancini requested to merge add_fftw_cache into master
4 unresolved threads
+ 156
15
@@ -8,17 +8,148 @@
#include <aocommon/uvector.h>
#include <aocommon/staticfor.h>
#include <map>
#include <fftw3.h>
#include <mutex>
#include <shared_mutex>
#include "compositefft.h"
#include <iostream>
namespace {
// Kinds of 1D single-precision FFTW transforms for which plans are made.
// The numeric order of the enumerators is part of the cache-key ordering
// (see cmpByPair), so it is spelled out explicitly.
// NOTE: "C2C_BACKARD" is a misspelling of "backward"; the identifier is kept
// because other code in this file refers to it by that name.
enum FFTWType {
  C2C_FORWARD = 0,  // complex -> complex, forward
  C2C_BACKARD = 1,  // complex -> complex, backward
  C2R = 2,          // complex -> real
  R2C = 3           // real -> complex
};

// Reader/writer mutex that is locked by the plan cache around every access to
// its map of cached plans.
static std::shared_mutex fftwplanner_rw_mutex_;
// Deleter for FFTW plans: destroys `plan` through fftwf_destroy_plan() and
// does nothing when the handle is null.
void manual_destroy(fftwf_plan plan) {
  if (plan == nullptr) {
    return;
  }
  fftwf_destroy_plan(plan);
}
class SmartFFTPlan
: public std::unique_ptr<fftwf_plan_s, decltype(&manual_destroy)> {
public:
SmartFFTPlan()
: std::unique_ptr<fftwf_plan_s, decltype(&manual_destroy)>(
nullptr, manual_destroy) {}
explicit SmartFFTPlan(fftwf_plan plan)
: std::unique_ptr<fftwf_plan_s, decltype(&manual_destroy)>(
plan, manual_destroy) {}
SmartFFTPlan(const SmartFFTPlan&) = delete;
SmartFFTPlan& operator=(const SmartFFTPlan&) = delete;
SmartFFTPlan(SmartFFTPlan&&) = default;
Please register or sign in to reply
SmartFFTPlan& operator=(SmartFFTPlan&&) = default;
explicit operator fftwf_plan() const { return get(); }
};
// Empty (null) plan sentinel. MakePlan() does not move from this object, so
// it is never modified by plan creation.
SmartFFTPlan kInvalidPlan = SmartFFTPlan(nullptr);

// Creates a 1D single-precision FFTW plan of the requested kind and size.
//
// All plans are created with null input/output pointers and FFTW_ESTIMATE.
// The returned handle owns whatever the FFTW planner returned, so callers
// should check it for null before use.
//
// @param type Kind of transform to plan.
// @param size Transform length.
// @return Owning handle for the new plan; an empty handle when `type` is not
//         one of the known enumerators.
SmartFFTPlan MakePlan(FFTWType type, size_t size) {
  // The FFTW 1D planner functions take the length as an int; make the
  // conversion explicit instead of relying on implicit narrowing.
  const int length = static_cast<int>(size);
  switch (type) {
    case FFTWType::R2C:
      return SmartFFTPlan(
          fftwf_plan_dft_r2c_1d(length, nullptr, nullptr, FFTW_ESTIMATE));
    case FFTWType::C2R:
      return SmartFFTPlan(
          fftwf_plan_dft_c2r_1d(length, nullptr, nullptr, FFTW_ESTIMATE));
    case FFTWType::C2C_FORWARD:
      return SmartFFTPlan(fftwf_plan_dft_1d(length, nullptr, nullptr,
                                            FFTW_FORWARD, FFTW_ESTIMATE));
    case FFTWType::C2C_BACKARD:
      return SmartFFTPlan(fftwf_plan_dft_1d(length, nullptr, nullptr,
                                            FFTW_BACKWARD, FFTW_ESTIMATE));
  }
  // Unknown enumerator value: hand out a fresh empty plan rather than moving
  // out of the shared global sentinel.
  return SmartFFTPlan();
}
struct cmpByPair {
bool operator()(const std::pair<FFTWType, size_t>& a,
const std::pair<FFTWType, size_t>& b) const {
return a.first < b.first || (a.first == b.first && a.second < b.second);
}
};
class FTTWPlannerCache {
public:
template <FFTWType T>
void PlanMulti(std::vector<size_t> sizes) {
std::unique_lock<std::shared_mutex> lock_r(fftwplanner_rw_mutex_);
for (size_t size : sizes) {
fftwplanner_cached_plans_[FFTW_KEY(T, size)] = MakePlan(T, size);
}
}
void ClearCache() {
std::unique_lock<std::shared_mutex> lock_rw(fftwplanner_rw_mutex_);
fftwplanner_cached_plans_.clear();
}
template <FFTWType T>
fftwf_plan GetOrPlan(size_t size) {
std::shared_lock<std::shared_mutex> read_lock(fftwplanner_rw_mutex_);
Please register or sign in to reply
fftwf_plan plan = SearchPlan<T>(size);
if (plan != nullptr) {
return plan;
}
read_lock.unlock();
read_lock.release();
std::unique_lock<std::shared_mutex> lock(fftwplanner_rw_mutex_);
fftwplanner_cached_plans_.emplace(FFTW_KEY(T, size),
std::move(MakePlan(T, size)));
return static_cast<fftwf_plan>(
fftwplanner_cached_plans_[FFTW_KEY(T, size)]);
}
private:
static constexpr std::pair<FFTWType, size_t> FFTW_KEY(FFTWType type,
size_t size) {
return std::make_pair(type, size);
}
template <FFTWType T>
fftwf_plan SearchPlan(size_t size) {
std::pair<FFTWType, size_t> key = std::make_pair(T, size);
if (auto const& match = fftwplanner_cached_plans_.find(key);
match != fftwplanner_cached_plans_.end()) {
return static_cast<fftwf_plan>(match->second);
}
return nullptr;
}
std::map<std::pair<FFTWType, size_t>, SmartFFTPlan, cmpByPair>
fftwplanner_cached_plans_;
};
// Process-wide plan cache; created on first use by GetFFTWCache() and never
// deleted.
static FTTWPlannerCache* fftwplanner_cache_ = nullptr;

// Returns the process-wide plan cache, creating it on the first call.
// std::call_once guarantees that exactly one cache is allocated even when
// several threads make their first call concurrently, and that the pointer is
// fully published before any of them reads it.
FTTWPlannerCache* GetFFTWCache() {
  static std::once_flag cache_created;
  std::call_once(cache_created,
                 [] { fftwplanner_cache_ = new FTTWPlannerCache(); });
  return fftwplanner_cache_;
}
} // namespace
namespace schaapcommon::math {
void CleanWisdom() { fftwf_cleanup(); };
// Thin wrapper that forwards to fftwf_make_planner_thread_safe().
void MakeFftwfPlannerThreadSafe() {
  fftwf_make_planner_thread_safe();
}
void PlanMultiFFTW(std::vector<size_t> height, std::vector<size_t> width) {
GetFFTWCache()->PlanMulti<FFTWType::C2C_FORWARD>(height);
Please register or sign in to reply
GetFFTWCache()->PlanMulti<FFTWType::C2C_BACKARD>(height);
GetFFTWCache()->PlanMulti<FFTWType::C2R>(width);
GetFFTWCache()->PlanMulti<FFTWType::R2C>(width);
}
void ClearFFTWCache() { GetFFTWCache()->ClearCache(); }
void ResizeAndConvolve(float* image, size_t image_width, size_t image_height,
const float* kernel, size_t kernel_size) {
aocommon::UVector<float> scaled_kernel(image_width * image_height, 0.0);
@@ -116,23 +247,38 @@ void PrepareConvolutionKernel(float* dest, const float* source,
}
void Convolve(float* image, const float* kernel, size_t image_width,
size_t image_height) {
size_t image_height, bool use_cache) {
const size_t image_size = image_width * image_height;
const size_t complex_width = image_width / 2 + 1;
const size_t complex_size = complex_width * image_height;
float* temp_data = fftwf_alloc_real(image_size);
fftwf_complex* fft_image_data = fftwf_alloc_complex(complex_size);
fftwf_complex* fft_kernel_data = fftwf_alloc_complex(complex_size);
fftwf_plan plan_r2c = nullptr, plan_c2c_forward = nullptr,
plan_c2c_backward = nullptr, plan_c2r = nullptr;
fftwf_plan plan_r2c =
fftwf_plan_dft_r2c_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
fftwf_plan plan_c2c_forward = fftwf_plan_dft_1d(
image_height, nullptr, nullptr, FFTW_FORWARD, FFTW_ESTIMATE);
fftwf_plan plan_c2c_backward = fftwf_plan_dft_1d(
image_height, nullptr, nullptr, FFTW_BACKWARD, FFTW_ESTIMATE);
fftwf_plan plan_c2r =
fftwf_plan_dft_c2r_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
if (use_cache) {
plan_r2c = GetFFTWCache()->GetOrPlan<FFTWType::R2C>(image_width);
    • Since GetFFTWCache returns a singleton, could you fetch the FTTWPlannerCache pointer once and then reuse it in the subsequent calls? Since the fftwplanner_cache_ pointer is private, it is only deallocated when the FTTWPlannerCache is destroyed (when it goes out of scope).

Please register or sign in to reply
plan_c2c_forward =
GetFFTWCache()->GetOrPlan<FFTWType::C2C_FORWARD>(image_height);
plan_c2c_backward =
GetFFTWCache()->GetOrPlan<FFTWType::C2C_BACKARD>(image_height);
plan_c2r = GetFFTWCache()->GetOrPlan<FFTWType::C2R>(image_width);
} else {
plan_r2c =
fftwf_plan_dft_r2c_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
plan_c2c_forward = fftwf_plan_dft_1d(image_height, nullptr, nullptr,
FFTW_FORWARD, FFTW_ESTIMATE);
plan_c2c_backward = fftwf_plan_dft_1d(image_height, nullptr, nullptr,
FFTW_BACKWARD, FFTW_ESTIMATE);
plan_c2r =
fftwf_plan_dft_c2r_1d(image_width, nullptr, nullptr, FFTW_ESTIMATE);
}
if (plan_c2c_backward == nullptr || plan_c2r == nullptr ||
plan_c2c_forward == nullptr || plan_r2c == nullptr) {
std::runtime_error("the fuck!");
}
aocommon::StaticFor<size_t> loop;
FftR2CComposite(plan_r2c, plan_c2c_forward, image_height, image_width, image,
@@ -159,11 +305,6 @@ void Convolve(float* image, const float* kernel, size_t image_width,
fftwf_free(fft_image_data);
fftwf_free(fft_kernel_data);
fftwf_free(temp_data);
fftwf_destroy_plan(plan_r2c);
fftwf_destroy_plan(plan_c2c_forward);
fftwf_destroy_plan(plan_c2c_backward);
fftwf_destroy_plan(plan_c2r);
}
} // namespace schaapcommon::math
Loading