Skip to content
Snippets Groups Projects
Commit d3d201c1 authored by Andre Offringa's avatar Andre Offringa
Browse files

Add function for padded convolutions

parent b3517f2e
No related branches found
No related tags found
1 merge request!169Add function for padded convolutions
Pipeline #107914 passed
#include "convolution.h"
#include <aocommon/image.h>
#ifndef SCHAAPCOMMON_MATH_PADDED_CONVOLUTION_H_
#define SCHAAPCOMMON_MATH_PADDED_CONVOLUTION_H_
namespace schaapcommon::math {
/**
* Performs an in-place padded image convolution. The convolution kernel (psf)
* is appropriately transformed so that its centre pixel is the centre of the
* kernel. This overload requires two scratch images that have been allocated
* with the padded dimensions.
*
* This operation is sometimes referred to as correlation, because the kernel is
* not flipped as is formally done when convolving. The operation performed by
* this function is typically used to convolve an interferometric model image
* with the PSF.
*
* This overload allows specifying the scratch images, to minimize allocations
* for use-cases that perform multiple padded convolutions. Both scratch images
* should be allocated with the padded size on input.
*/
inline void PaddedConvolution(aocommon::Image& image,
const aocommon::Image& psf,
aocommon::Image& scratch_a,
aocommon::Image& scratch_b, size_t padded_width,
size_t padded_height) {
assert(image.Width() == psf.Width());
assert(image.Height() == psf.Height());
assert(image.Size() > 0);
assert(scratch_a.Width() * scratch_a.Height() >=
padded_width * padded_height);
assert(scratch_b.Width() * scratch_b.Height() >=
padded_width * padded_height);
assert(padded_width >= image.Width());
assert(padded_height >= image.Height());
using aocommon::Image;
// scratch_a = padded psf
Image::Untrim(scratch_a.Data(), padded_width, padded_height, psf.Data(),
psf.Width(), psf.Height());
// scratch_b = prepared padded psf
PrepareConvolutionKernel(scratch_b.Data(), scratch_a.Data(), padded_width,
padded_height);
// scratch_a = padded image
Image::Untrim(scratch_a.Data(), padded_width, padded_height, image.Data(),
image.Width(), image.Height());
// Convolve and store in scratch_a
Convolve(scratch_a.Data(), scratch_b.Data(), padded_width, padded_height);
Image::Trim(image.Data(), image.Width(), image.Height(), scratch_a.Data(),
padded_width, padded_height);
}
/**
* Convenience overload of the above PaddedConvolution() function that does
* not require scratch images. The required scratch images are allocated
* inside the function. See other overload for help.
*/
inline void PaddedConvolution(aocommon::Image& image,
const aocommon::Image& psf, size_t padded_width,
size_t padded_height) {
aocommon::Image scratch_a(padded_width, padded_height);
aocommon::Image scratch_b(padded_width, padded_height);
PaddedConvolution(image, psf, scratch_a, scratch_b, padded_width,
padded_height);
}
} // namespace schaapcommon::math
#endif
......@@ -3,5 +3,11 @@
include(${SCHAAPCOMMON_SOURCE_DIR}/cmake/unittest.cmake)
add_unittest(math runtests.cc tconvolution.cc tdrawgaussian.cc tresampler.cc
trestoreimage.cc)
add_unittest(
math
runtests.cc
tconvolution.cc
tdrawgaussian.cc
tpaddedconvolution.cc
tresampler.cc
trestoreimage.cc)
#include <boost/test/unit_test.hpp>
#include <aocommon/image.h>
#include "paddedconvolution.h"
using aocommon::Image;
namespace schaapcommon::math {
namespace {
void Check(size_t width, size_t height) {
// Use a padding factor of 3, so that both padded even sizes remain even
// and padded odd sizes remain odd.
constexpr size_t kPadding = 3;
Image image(width, height, 0.0f);
const Image zero(width, height, 0.0f);
PaddedConvolution(image, zero, width * kPadding, height * kPadding);
BOOST_CHECK_LT(image.RMS(), 1e-5);
Image delta(width, height, 0.0f);
delta.Value(width / 2, height / 2) = 1.0f;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
BOOST_CHECK_LT(image.RMS(), 1e-5);
for (size_t i = 0; i != image.Size(); ++i) image[i] = i + 3;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
for (size_t i = 0; i != image.Size(); ++i)
BOOST_CHECK_CLOSE_FRACTION(image[i], i + 3, 1e-6);
delta *= 2.0f;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
for (size_t i = 0; i != image.Size(); ++i)
BOOST_CHECK_CLOSE_FRACTION(image[i], (i + 3) * 2, 1e-5);
}
} // namespace
BOOST_AUTO_TEST_SUITE(padded_convolution)
BOOST_AUTO_TEST_CASE(simple_even) { Check(10, 8); }
BOOST_AUTO_TEST_CASE(simple_odd) { Check(9, 11); }
BOOST_AUTO_TEST_CASE(full_example) {
Image image(8, 6, 0.0f);
image.Value(1, 3) = 3.0;
image.Value(4, 5) = 1.5;
image.Value(6, 5) = -1.0;
Image psf(8, 6, 0.0f);
psf.Value(4, 3) = 0.5;
psf.Value(3, 3) = 3.0;
psf.Value(5, 3) = 7.0;
psf.Value(4, 2) = 10.5;
psf.Value(4, 4) = 37.0;
PaddedConvolution(image, psf, 20, 30);
const Image reference(
8, 6,
{
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 0
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 1
0.0f, 31.5f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 2
9.0f, 1.5f, 21.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 3
0.0f, 111.0f, 0.0f, 0.0f, 15.75f, 0.0f, -10.5f, 0.0f, // row 4
0.0f, 0.0f, 0.0f, 4.5f, 0.75f, 7.5f, -0.5f, -7.0f // row 5
});
for (size_t i = 0; i != reference.Size(); ++i) {
if (reference[i] == 0.0f)
BOOST_CHECK_LT(std::fabs(image[i]), 1e-5);
else
BOOST_CHECK_CLOSE_FRACTION(reference[i], image[i], 1e-5);
}
}
BOOST_AUTO_TEST_SUITE_END()
} // namespace schaapcommon::math
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment