Skip to content
Snippets Groups Projects
Commit 7628fc36 authored by Andre Offringa's avatar Andre Offringa
Browse files

Merge branch 'padded-convolution' into 'master'

Add function for padded convolutions

See merge request !169
parents b3517f2e d3d201c1
No related branches found
No related tags found
1 merge request!169Add function for padded convolutions
Pipeline #108681 passed
#include "convolution.h"
#include <aocommon/image.h>
#ifndef SCHAAPCOMMON_MATH_PADDED_CONVOLUTION_H_
#define SCHAAPCOMMON_MATH_PADDED_CONVOLUTION_H_
namespace schaapcommon::math {
/**
* Performs an in-place padded image convolution. The convolution kernel (psf)
* is appropriately transformed so that its centre pixel is the centre of the
* kernel. This overload requires two scratch images that have been allocated
* with the padded dimensions.
*
* This operation is sometimes referred to as correlation, because the kernel is
* not flipped as is formally done when convolving. The operation performed by
* this function is typically used to convolve an interferometric model image
* with the PSF.
*
* This overload allows specifying the scratch images, to minimize allocations
* for use-cases that perform multiple padded convolutions. Both scratch images
* should be allocated with the padded size on input.
*/
inline void PaddedConvolution(aocommon::Image& image,
const aocommon::Image& psf,
aocommon::Image& scratch_a,
aocommon::Image& scratch_b, size_t padded_width,
size_t padded_height) {
assert(image.Width() == psf.Width());
assert(image.Height() == psf.Height());
assert(image.Size() > 0);
assert(scratch_a.Width() * scratch_a.Height() >=
padded_width * padded_height);
assert(scratch_b.Width() * scratch_b.Height() >=
padded_width * padded_height);
assert(padded_width >= image.Width());
assert(padded_height >= image.Height());
using aocommon::Image;
// scratch_a = padded psf
Image::Untrim(scratch_a.Data(), padded_width, padded_height, psf.Data(),
psf.Width(), psf.Height());
// scratch_b = prepared padded psf
PrepareConvolutionKernel(scratch_b.Data(), scratch_a.Data(), padded_width,
padded_height);
// scratch_a = padded image
Image::Untrim(scratch_a.Data(), padded_width, padded_height, image.Data(),
image.Width(), image.Height());
// Convolve and store in scratch_a
Convolve(scratch_a.Data(), scratch_b.Data(), padded_width, padded_height);
Image::Trim(image.Data(), image.Width(), image.Height(), scratch_a.Data(),
padded_width, padded_height);
}
/**
* Convenience overload of the above PaddedConvolution() function that does
* not require scratch images. The required scratch images are allocated
* inside the function. See other overload for help.
*/
inline void PaddedConvolution(aocommon::Image& image,
const aocommon::Image& psf, size_t padded_width,
size_t padded_height) {
aocommon::Image scratch_a(padded_width, padded_height);
aocommon::Image scratch_b(padded_width, padded_height);
PaddedConvolution(image, psf, scratch_a, scratch_b, padded_width,
padded_height);
}
} // namespace schaapcommon::math
#endif
......@@ -3,5 +3,11 @@
include(${SCHAAPCOMMON_SOURCE_DIR}/cmake/unittest.cmake)
add_unittest(math runtests.cc tconvolution.cc tdrawgaussian.cc tresampler.cc
trestoreimage.cc)
add_unittest(
math
runtests.cc
tconvolution.cc
tdrawgaussian.cc
tpaddedconvolution.cc
tresampler.cc
trestoreimage.cc)
#include <boost/test/unit_test.hpp>
#include <aocommon/image.h>
#include "paddedconvolution.h"
using aocommon::Image;
namespace schaapcommon::math {
namespace {
void Check(size_t width, size_t height) {
// Use a padding factor of 3, so that both padded even sizes remain even
// and padded odd sizes remain odd.
constexpr size_t kPadding = 3;
Image image(width, height, 0.0f);
const Image zero(width, height, 0.0f);
PaddedConvolution(image, zero, width * kPadding, height * kPadding);
BOOST_CHECK_LT(image.RMS(), 1e-5);
Image delta(width, height, 0.0f);
delta.Value(width / 2, height / 2) = 1.0f;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
BOOST_CHECK_LT(image.RMS(), 1e-5);
for (size_t i = 0; i != image.Size(); ++i) image[i] = i + 3;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
for (size_t i = 0; i != image.Size(); ++i)
BOOST_CHECK_CLOSE_FRACTION(image[i], i + 3, 1e-6);
delta *= 2.0f;
PaddedConvolution(image, delta, width * kPadding, height * kPadding);
for (size_t i = 0; i != image.Size(); ++i)
BOOST_CHECK_CLOSE_FRACTION(image[i], (i + 3) * 2, 1e-5);
}
} // namespace
BOOST_AUTO_TEST_SUITE(padded_convolution)
BOOST_AUTO_TEST_CASE(simple_even) { Check(10, 8); }
BOOST_AUTO_TEST_CASE(simple_odd) { Check(9, 11); }
BOOST_AUTO_TEST_CASE(full_example) {
Image image(8, 6, 0.0f);
image.Value(1, 3) = 3.0;
image.Value(4, 5) = 1.5;
image.Value(6, 5) = -1.0;
Image psf(8, 6, 0.0f);
psf.Value(4, 3) = 0.5;
psf.Value(3, 3) = 3.0;
psf.Value(5, 3) = 7.0;
psf.Value(4, 2) = 10.5;
psf.Value(4, 4) = 37.0;
PaddedConvolution(image, psf, 20, 30);
const Image reference(
8, 6,
{
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 0
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 1
0.0f, 31.5f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 2
9.0f, 1.5f, 21.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, // row 3
0.0f, 111.0f, 0.0f, 0.0f, 15.75f, 0.0f, -10.5f, 0.0f, // row 4
0.0f, 0.0f, 0.0f, 4.5f, 0.75f, 7.5f, -0.5f, -7.0f // row 5
});
for (size_t i = 0; i != reference.Size(); ++i) {
if (reference[i] == 0.0f)
BOOST_CHECK_LT(std::fabs(image[i]), 1e-5);
else
BOOST_CHECK_CLOSE_FRACTION(reference[i], image[i], 1e-5);
}
}
BOOST_AUTO_TEST_SUITE_END()
} // namespace schaapcommon::math
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment