Skip to content
Snippets Groups Projects
Commit a177d13a authored by Andre Offringa's avatar Andre Offringa
Browse files

Bug 1491: aofrequencyfilter is now multithreaded

parent d700a480
No related branches found
No related tags found
No related merge requests found
......@@ -27,7 +27,14 @@ class System
public:
static long TotalMemory()
{
return casa::HostInfo::memoryTotal()*1024;
return casa::HostInfo::memoryTotal()*1024;
}
static unsigned ProcessorCount()
{
unsigned cpus = casa::HostInfo::numCPUs();
if(cpus == 0) cpus = 1;
return cpus;
}
};
......
#include <iostream>
#include <string>
#include <deque>
#include <ms/MeasurementSets/MeasurementSet.h>
#include <ms/MeasurementSets/MSColumns.h>
#include <ms/MeasurementSets/MSTable.h>
#include <boost/thread.hpp>
#include <tables/Tables/TableIter.h>
#include <AOFlagger/rfi/thresholdtools.h>
#include <AOFlagger/imaging/uvimager.h>
#include <AOFlagger/msio/system.h>
using namespace std;
......@@ -24,6 +28,171 @@ double uvDist(double u, double v, double firstFrequency, double lastFrequency)
return sqrt(ud * ud + vd * vd) / UVImager::SpeedOfLight();
}
// This represents data that is constant for all threads
struct SetInfo
{
unsigned bandCount;
unsigned frequencyCount;
unsigned polarizationCount;
};
// This represents the data that is specific for a single thread
struct TaskInfo
{
TaskInfo() : length(0), convolutionSize(0), table(0), dataColumn(0), realData(0), imagData(0) { }
TaskInfo(const TaskInfo &source) :
length(source.length),
convolutionSize(source.convolutionSize),
table(source.table),
dataColumn(source.dataColumn),
realData(source.realData),
imagData(source.imagData)
{
}
void operator=(const TaskInfo &source)
{
length = source.length;
convolutionSize = source.convolutionSize;
table = source.table;
dataColumn = source.dataColumn;
realData = source.realData;
imagData = source.imagData;
}
unsigned length, convolutionSize;
casa::Table *table;
casa::ArrayColumn<casa::Complex> *dataColumn;
float
**realData,
**imagData;
};
void performAndWriteConvolution(const SetInfo &set, TaskInfo task, boost::mutex &mutex)
{
// Convolve the data
for(unsigned p=0;p<set.polarizationCount;++p)
{
ThresholdTools::OneDimensionalSincConvolution(task.realData[p], task.length, task.convolutionSize);
ThresholdTools::OneDimensionalSincConvolution(task.imagData[p], task.length, task.convolutionSize);
}
boost::mutex::scoped_lock lock(mutex);
// Copy data back to tables
for(unsigned i=0;i<set.bandCount;++i)
{
casa::Array<casa::Complex> dataArray = (*task.dataColumn)(i);
casa::Array<casa::Complex>::iterator dataIterator = dataArray.begin();
unsigned index = i * set.frequencyCount;
for(unsigned f=0;f<set.frequencyCount;++f)
{
for(unsigned p=0;p<set.polarizationCount;++p)
{
*dataIterator = casa::Complex(task.realData[p][index], task.imagData[p][index]);
++dataIterator;
}
++index;
}
task.dataColumn->basePut(i, dataArray);
}
lock.unlock();
// Free memory
for(unsigned p=0;p<set.polarizationCount;++p)
{
delete[] task.realData[p];
delete[] task.imagData[p];
}
delete[] task.realData;
delete[] task.imagData;
delete task.dataColumn;
delete task.table;
}
struct ThreadFunction
{
void operator()();
class ThreadControl *threadControl;
int number;
};
class ThreadControl
{
public:
ThreadControl(unsigned threadCount, const SetInfo &setInfo)
: _setInfo(setInfo), _threadCount(threadCount), _isFinishing(false)
{
for(unsigned i=0;i<threadCount;++i)
{
ThreadFunction function;
function.number = i;
function.threadControl = this;
_threadGroup.create_thread(function);
}
}
void PushTask(const TaskInfo &taskInfo)
{
boost::mutex::scoped_lock lock(_mutex);
while(_tasks.size() > _threadCount * 10)
{
_queueFullCondition.wait(lock);
}
_tasks.push_back(taskInfo);
_dataAvailableCondition.notify_one();
}
bool WaitForTask(TaskInfo &taskInfo)
{
boost::mutex::scoped_lock lock(_mutex);
while(_tasks.empty() && !_isFinishing)
{
_dataAvailableCondition.wait(lock);
}
if(_isFinishing)
return false;
else
{
taskInfo = _tasks.front();
_tasks.pop_front();
_queueFullCondition.notify_one();
return true;
}
}
void Finish()
{
boost::mutex::scoped_lock lock(_mutex);
_isFinishing = true;
lock.unlock();
_dataAvailableCondition.notify_all();
_threadGroup.join_all();
}
const struct SetInfo &SetInfo() const { return _setInfo; }
boost::mutex &WriteMutex() { return _writeMutex; }
private:
const struct SetInfo _setInfo;
boost::thread_group _threadGroup;
unsigned _threadCount;
bool _isFinishing;
boost::mutex _mutex;
boost::condition_variable _dataAvailableCondition;
boost::condition_variable _queueFullCondition;
std::deque<TaskInfo> _tasks;
boost::mutex _writeMutex;
};
void ThreadFunction::operator()()
{
cout << "Thread " << number << " started\n";
TaskInfo task;
bool hasTask = threadControl->WaitForTask(task);
while(hasTask)
{
performAndWriteConvolution(threadControl->SetInfo(), task, threadControl->WriteMutex());
hasTask = threadControl->WaitForTask(task);
}
cout << "Thread " << number << " finished\n";
}
int main(int argc, char *argv[])
{
if(argc != 3)
......@@ -36,13 +205,14 @@ int main(int argc, char *argv[])
cout << "Fringe size: " << fringeSize << '\n';
casa::MeasurementSet ms(msFilename);
casa::Table table(msFilename, casa::Table::Update);
casa::Table mainTable(msFilename, casa::Table::Update);
class SetInfo setInfo;
//count number of polarizations
casa::Table polTable = ms.polarization();
casa::ROArrayColumn<int> corTypeColumn(polTable, "CORR_TYPE");
const unsigned polarizationCount = corTypeColumn(0).shape()[0];
cout << "Number of polarizations: " << polarizationCount << '\n';
setInfo.polarizationCount = corTypeColumn(0).shape()[0];
cout << "Number of polarizations: " << setInfo.polarizationCount << '\n';
// Find lowest and highest frequency and check order
double lowestFrequency = 0.0, highestFrequency = 0.0;
......@@ -65,18 +235,23 @@ int main(int argc, char *argv[])
++frequencyIterator;
}
}
setInfo.bandCount = spectralWindowTable.nrow();
cout
<< "Number of bands: " << spectralWindowTable.nrow()
<< "Number of bands: " << setInfo.bandCount
<< " (" << round(lowestFrequency/1e6) << " MHz - "
<< round(highestFrequency/1e6) << " MHz)\n";
const unsigned frequencyCount =
casa::ROArrayColumn<casa::Complex>(table, "DATA")(0).shape()[1];
cout << "Channels per band: " << frequencyCount << '\n';
setInfo.frequencyCount =
casa::ROArrayColumn<casa::Complex>(mainTable, "DATA")(0).shape()[1];
cout << "Channels per band: " << setInfo.frequencyCount << '\n';
const unsigned long totalIterations =
table.nrow() / spectralWindowTable.nrow();
mainTable.nrow() / spectralWindowTable.nrow();
cout << "Total iterations: " << totalIterations << '\n';
unsigned processorCount = System::ProcessorCount();
cout << "CPUs: " << processorCount << '\n';
ThreadControl threads(processorCount, setInfo);
// Create the sorted table and iterate over it
casa::Block<casa::String> names(4);
......@@ -85,102 +260,84 @@ int main(int argc, char *argv[])
names[2] = "ANTENNA2";
names[3] = "DATA_DESC_ID";
cout << "Sorting...\n";
casa::Table sortab = table.sort(names);
casa::Table sortab = mainTable.sort(names);
cout << "Iterating...\n";
unsigned long iterSteps = 0;
names.resize(3, true, true);
casa::TableIterator iter (sortab, names, casa::TableIterator::Ascending, casa::TableIterator::NoSort);
double maxFringeChannels = 0.0, minFringeChannels = 1e100;
while (! iter.pastEnd()) {
casa::Table table = iter.table();
TaskInfo task;
task.table = new casa::Table(iter.table());
casa::ROScalarColumn<int> antenna1Column =
casa::ROScalarColumn<int>(table, "ANTENNA1");
casa::ROScalarColumn<int>(*task.table, "ANTENNA1");
casa::ROScalarColumn<int> antenna2Column =
casa::ROScalarColumn<int>(table, "ANTENNA2");
casa::ROScalarColumn<int>(*task.table, "ANTENNA2");
// Skip autocorrelations
const int antenna1 = antenna1Column(0), antenna2 = antenna2Column(0);
if(antenna1 != antenna2)
if(antenna1 == antenna2)
{
delete task.table;
} else
{
casa::ArrayColumn<casa::Complex> dataColumn =
casa::ArrayColumn<casa::Complex>(table, "DATA");
task.dataColumn = new casa::ArrayColumn<casa::Complex>(*task.table, "DATA");
casa::ROArrayColumn<double> uvwColumn =
casa::ROArrayColumn<double>(table, "UVW");
const casa::IPosition &dataShape = dataColumn.shape(0);
if(dataShape[1] != frequencyCount) {
casa::ROArrayColumn<double>(*task.table, "UVW");
// Check number of channels & bands
const casa::IPosition &dataShape = task.dataColumn->shape(0);
if(dataShape[1] != setInfo.frequencyCount) {
std::cerr << "ERROR: bands do not have equal number of channels!\n";
abort();
}
const unsigned bandCount = table.nrow();
const unsigned length = bandCount * frequencyCount;
// Allocate memory for putting the data of all channels in an array, for each polarization
float
*realData[polarizationCount],
*imagData[polarizationCount];
for(unsigned p=0;p<polarizationCount;++p)
{
realData[p] = new float[length];
imagData[p] = new float[length];
}
// Copy data from tables in arrays
for(unsigned i=0;i<bandCount;++i)
{
casa::Array<casa::Complex> dataArray = dataColumn(i);
casa::Array<casa::Complex>::const_iterator dataIterator = dataArray.begin();
unsigned index = i * frequencyCount;
for(unsigned f=0;f<frequencyCount;++f)
{
for(unsigned p=0;p<polarizationCount;++p)
{
realData[p][index] = (*dataIterator).real();
imagData[p][index] = (*dataIterator).imag();
}
++index;
}
if(task.table->nrow() != setInfo.bandCount) {
std::cerr << "ERROR: inconsistent band information in specific correlation/time step\n"
" (rows in table iterator's table: " << task.table->nrow() <<
", in set: " << setInfo.bandCount << ")\n";
abort();
}
// Convolve the data
casa::Array<double>::const_iterator uvwIterator = uvwColumn(0).begin();
// Retrieve uv info and calculate the convolution size
casa::Array<double> uvwArray = uvwColumn(0);
casa::Array<double>::const_iterator uvwIterator = uvwArray.begin();
const double u = *uvwIterator;
++uvwIterator;
const double v = *uvwIterator;
const double convSize = fringeSize * (double) length / uvDist(u, v, lowestFrequency, highestFrequency);
if(convSize > maxFringeChannels) maxFringeChannels = convSize;
if(convSize < minFringeChannels) minFringeChannels = convSize;
for(unsigned p=0;p<polarizationCount;++p)
task.length = setInfo.bandCount * setInfo.frequencyCount;
task.convolutionSize = fringeSize * (double) task.length / uvDist(u, v, lowestFrequency, highestFrequency);
if(task.convolutionSize > maxFringeChannels) maxFringeChannels = task.convolutionSize;
if(task.convolutionSize < minFringeChannels) minFringeChannels = task.convolutionSize;
// Allocate memory for putting the data of all channels in an array, for each polarization
task.realData = new float*[setInfo.polarizationCount];
task.imagData = new float*[setInfo.polarizationCount];
for(unsigned p=0;p<setInfo.polarizationCount;++p)
{
ThresholdTools::OneDimensionalSincConvolution(realData[p], length, convSize);
ThresholdTools::OneDimensionalSincConvolution(imagData[p], length, convSize);
task.realData[p] = new float[task.length];
task.imagData[p] = new float[task.length];
}
// Copy data back to tables
for(unsigned i=0;i<bandCount;++i)
boost::mutex::scoped_lock lock(threads.WriteMutex());
// Copy data from tables in arrays
for(unsigned i=0;i<setInfo.bandCount;++i)
{
casa::Array<casa::Complex> dataArray = dataColumn(i);
casa::Array<casa::Complex>::iterator dataIterator = dataArray.begin();
unsigned index = i * frequencyCount;
for(unsigned f=0;f<frequencyCount;++f)
casa::Array<casa::Complex> dataArray = (*task.dataColumn)(i);
casa::Array<casa::Complex>::const_iterator dataIterator = dataArray.begin();
unsigned index = i * setInfo.frequencyCount;
for(unsigned f=0;f<setInfo.frequencyCount;++f)
{
for(unsigned p=0;p<polarizationCount;++p)
for(unsigned p=0;p<setInfo.polarizationCount;++p)
{
(*dataIterator).real() = realData[p][index];
(*dataIterator).imag() = imagData[p][index];
task.realData[p][index] = (*dataIterator).real();
task.imagData[p][index] = (*dataIterator).imag();
}
++index;
}
dataColumn.basePut(i, dataArray);
}
// Free memory
for(unsigned p=0;p<polarizationCount;++p)
{
delete[] realData[p];
delete[] imagData[p];
}
lock.unlock();
threads.PushTask(task);
}
iter.next();
......@@ -190,6 +347,8 @@ int main(int argc, char *argv[])
if((iterSteps * 10UL) % totalIterations < 10UL)
cout << (iterSteps*100/totalIterations) << '%' << flush;
}
cout << '\n';
threads.Finish();
cout
<< "Done. " << iterSteps << " steps taken.\n"
<< "Maximum filtering fringe size = " << maxFringeChannels << " channels, "
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment