diff --git a/CMakeLists.txt b/CMakeLists.txt index c9a903ca8b4006b76b536913d779103a60b161c8..a5a3055a8171e63cfa67faf34b43ae65a18c9e03 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,7 +106,9 @@ if("${isSystemDir}" STREQUAL "-1") set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib") endif("${isSystemDir}" STREQUAL "-1") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror=vla -DNDEBUG -O3 -std=c++11") +#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror=vla -DNDEBUG -O3 -std=c++11") +#set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror=vla -O3 -g -std=c++11") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror=vla -fsanitize=address -O3 -g -std=c++11") if(CMAKE_BUILD_TYPE STREQUAL "Debug") message(STATUS "Debug build selected: setting linking flag --no-undefined") @@ -221,6 +223,7 @@ set(MSIO_FILES msio/pngfile.cpp msio/rspreader.cpp msio/singlebaselinefile.cpp + msio/singlemandmemorybaselinereader.cpp msio/spatialtimeloader.cpp) set(STRUCTURES_FILES @@ -272,6 +275,7 @@ set(IMAGESETS_FILES imagesets/imageset.cpp imagesets/indexableset.cpp imagesets/msimageset.cpp + imagesets/multibandmsimageset.cpp imagesets/parmimageset.cpp imagesets/pngreader.cpp imagesets/rfibaselineset.cpp diff --git a/aoluarunner/baselineiterator.cpp b/aoluarunner/baselineiterator.cpp index 695ff69d3908526ff67346643ff6f38427cdbd9f..520f6126c892c713e4dfc445710e7348b8682269 100644 --- a/aoluarunner/baselineiterator.cpp +++ b/aoluarunner/baselineiterator.cpp @@ -14,6 +14,7 @@ #include "../imagesets/fitsimageset.h" #include "../imagesets/imageset.h" #include "../imagesets/msimageset.h" +#include "../imagesets/multibandmsimageset.h" // TODO(RAP-310) Remove #include "../imagesets/filterbankset.h" #include "../imagesets/qualitystatimageset.h" #include "../imagesets/rfibaselineset.h" @@ -42,7 +43,9 @@ void BaselineIterator::Run(imagesets::ImageSet& imageSet, LuaThreadGroup& lua, _imageSet = &imageSet; _threadCount = _options.CalculateThreadCount(); - _writeThread.reset(new WriteThread(imageSet, _threadCount, _ioMutex)); + // TODO(RAP-310) Remove test + if (!dynamic_cast<imagesets::MultiBandMsImageSet*>(_imageSet)) + _writeThread.reset(new WriteThread(imageSet, _threadCount, _ioMutex)); _globalScriptData = &scriptData; imagesets::MSImageSet* msImageSet = @@ -231,7 +234,8 @@ void BaselineIterator::ProcessingThread::operator()() { _parent._lua->Execute(_threadIndex, data, baseline->MetaData(), scriptData, executeFunctionName); - _parent._writeThread->SaveFlags(data, baseline->Index()); + if (_parent._writeThread) // TODO(RAP-310) Remove test + _parent._writeThread->SaveFlags(data, baseline->Index()); baseline = _parent.GetNextBaseline(); _parent.IncBaselineProgress(); diff --git a/aoluarunner/options.h b/aoluarunner/options.h index ccc3f9a03f1ab500085d9db74e53a867244b7dab..8da4ab4c14703192af342ea4f088ee7997e54590 100644 --- a/aoluarunner/options.h +++ b/aoluarunner/options.h @@ -54,6 +54,7 @@ struct Options { BaselineIntegration baselineIntegration; size_t chunkSize; boost::optional<bool> combineSPWs; + boost::optional<bool> concatenateFrequency; std::string dataColumn; std::string executeFilename; std::string executeFunctionName; @@ -86,6 +87,8 @@ struct Options { if (other.baselineSelection) baselineSelection = other.baselineSelection; if (other.chunkSize) chunkSize = other.chunkSize; if (other.combineSPWs) combineSPWs = other.combineSPWs; + if (other.concatenateFrequency) + concatenateFrequency = other.concatenateFrequency; if (!other.dataColumn.empty()) dataColumn = other.dataColumn; if (!other.executeFilename.empty()) executeFilename = other.executeFilename; if (!other.executeFunctionName.empty()) @@ -114,6 +117,7 @@ struct Options { baselineIntegration == rhs.baselineIntegration && baselineSelection == rhs.baselineSelection && chunkSize == rhs.chunkSize && combineSPWs == rhs.combineSPWs && + concatenateFrequency == rhs.concatenateFrequency && dataColumn == rhs.dataColumn && executeFilename == rhs.executeFilename && executeFunctionName == rhs.executeFunctionName && diff --git a/aoluarunner/runner.cpp b/aoluarunner/runner.cpp index 7b442dd461292c59e3c46778dd37fc4ee8a0c314..cbd71d3ffab95a4310bdcc67a21f735e91ac55ef 100644 --- a/aoluarunner/runner.cpp +++ b/aoluarunner/runner.cpp @@ -13,11 +13,14 @@ #include "../imagesets/joinedspwset.h" #include "../imagesets/msimageset.h" #include "../imagesets/msoptions.h" +#include "../imagesets/multibandmsimageset.h" #include "../util/logger.h" #include <boost/date_time/posix_time/posix_time.hpp> +#include <boost/make_unique.hpp> +#include <algorithm> #include <fstream> using namespace imagesets; @@ -110,14 +113,43 @@ void Runner::loadStrategy( } } +static std::vector<std::string> FilterProcessedFiles( + const std::vector<std::string>& ms_names) { + std::vector<std::string> result; + std::copy_if(ms_names.begin(), ms_names.end(), std::back_inserter(result), + [](const std::string& ms_name) { + MSMetaData set(ms_name); + if (!set.HasAOFlaggerHistory()) { + return true; + } + Logger::Info + << "Skipping " << ms_name + << ",\n" + "because the set contains AOFlagger history and " + "-skip-flagged was given.\n"; + return false; + }); + return result; +} + void Runner::run(const Options& options) { Logger::SetVerbosity(options.logVerbosity.value_or(Logger::NormalVerbosity)); size_t threadCount = options.CalculateThreadCount(); Logger::Debug << "Number of threads: " << options.threadCount << "\n"; - for (const std::string& filename : _cmdLineOptions.filenames) { - processFile(options, filename, threadCount); + const std::vector<std::string>& ms_files = + options.skipFlagged ? FilterProcessedFiles(_cmdLineOptions.filenames) + : _cmdLineOptions.filenames; + + if (_cmdLineOptions.concatenateFrequency && ms_files.size() > 1) { + // Only use the multi-band image set when there at least 2 files. + // Else just use the simpler code. + processFiles(options, std::move(ms_files), threadCount); + } else { + for (const std::string& filename : ms_files) { + processFile(options, filename, threadCount); + } } } @@ -198,45 +230,53 @@ void Runner::processFile(const Options& options, const std::string& filename, boost::posix_time::microsec_clock::local_time()) << '\n'; - bool skip = false; - if (options.skipFlagged) { - MSMetaData set(filename); - if (set.HasAOFlaggerHistory()) { - skip = true; - Logger::Info << "Skipping " << filename - << ",\n" - "because the set contains AOFlagger history and " - "-skip-flagged was given.\n"; - } - } + ScriptData scriptData; + FileOptions fileOptions; + fileOptions.intervalStart = options.startTimestep; + fileOptions.intervalEnd = options.endTimestep; + fileOptions.filename = filename; + bool isMS = false; + while (fileOptions.intervalIndex < fileOptions.nIntervals) { + std::unique_ptr<ImageSet> imageSet = + initializeImageSet(options, fileOptions); + isMS = dynamic_cast<MSImageSet*>(imageSet.get()) != nullptr; - if (!skip) { - ScriptData scriptData; - FileOptions fileOptions; - fileOptions.intervalStart = options.startTimestep; - fileOptions.intervalEnd = options.endTimestep; - fileOptions.filename = filename; - bool isMS = false; - while (fileOptions.intervalIndex < fileOptions.nIntervals) { - std::unique_ptr<ImageSet> imageSet = - initializeImageSet(options, fileOptions); - isMS = dynamic_cast<MSImageSet*>(imageSet.get()) != nullptr; + LuaThreadGroup lua(threadCount); - LuaThreadGroup lua(threadCount); + loadStrategy(lua, options, imageSet); - loadStrategy(lua, options, imageSet); + std::mutex ioMutex; + BaselineIterator blIterator(&ioMutex, options); + blIterator.Run(*imageSet, lua, scriptData); - std::mutex ioMutex; - BaselineIterator blIterator(&ioMutex, options); - blIterator.Run(*imageSet, lua, scriptData); + ++fileOptions.intervalIndex; + } - ++fileOptions.intervalIndex; - } + if (isMS) writeHistory(options, filename); - if (isMS) writeHistory(options, filename); + finishStatistics(filename, scriptData, isMS); +} - finishStatistics(filename, scriptData, isMS); - } +void Runner::processFiles(const Options& options, + std::vector<std::string> ms_names, + size_t num_threads) { + Logger::Info << "Starting strategy on " + << to_simple_string( + boost::posix_time::microsec_clock::local_time()) + << '\n'; + + ScriptData scriptData; + std::unique_ptr<ImageSet> imageSet = + boost::make_unique<imagesets::MultiBandMsImageSet>( + std::move(ms_names), + options.readMode.value_or(BaselineIOMode::AutoReadMode), num_threads); + + LuaThreadGroup thread_pool(num_threads); + loadStrategy(thread_pool, options, imageSet); + + std::mutex ioMutex; + BaselineIterator blIterator(&ioMutex, options); + blIterator.Run(*imageSet, thread_pool, scriptData); } void Runner::writeHistory(const Options& options, const std::string& filename) { diff --git a/aoluarunner/runner.h b/aoluarunner/runner.h index f9ef19067e5ece9a2960897d04bb972c257bb1ca..f5c220ed3b111514da3c8b4699fcd0a5ea51900c 100644 --- a/aoluarunner/runner.h +++ b/aoluarunner/runner.h @@ -27,6 +27,8 @@ class Runner { void processFile(const Options& options, const std::string& filename, size_t threadCount); + void processFiles(const Options& options, std::vector<std::string> filenames, + size_t threadCount); std::unique_ptr<imagesets::ImageSet> initializeImageSet( const Options& options, FileOptions& fileOptions); void writeHistory(const Options& options, const std::string& filename); diff --git a/applications/aoflagger.cpp b/applications/aoflagger.cpp index a0f8a4f4d466878ad5ef004be209a4a52c8abce2..6795ada16cb574973331e4615ea0a07f01f20fa9 100644 --- a/applications/aoflagger.cpp +++ b/applications/aoflagger.cpp @@ -161,6 +161,8 @@ int main(int argc, char** argv) { NumberList::ParseIntList(argv[parameterIndex], options.fields); } else if (flag == "combine-spws") { options.combineSPWs = true; + } else if (flag == "concatenate-frequency") { + options.concatenateFrequency = true; } else if (flag == "preamble") { ++parameterIndex; options.preamble.emplace_back(argv[parameterIndex]); diff --git a/imagesets/indexableset.h b/imagesets/indexableset.h index 48036087c9aa2cb22a02e340e7380b3a869713a2..04e289a5efaf8bae044b1bdb96927066550bf564 100644 --- a/imagesets/indexableset.h +++ b/imagesets/indexableset.h @@ -28,7 +28,7 @@ class IndexableSet : public ImageSet { size_t sequenceId) const = 0; virtual FieldInfo GetFieldInfo(unsigned fieldIndex) const = 0; - virtual std::string TelescopeName() final override; + virtual std::string TelescopeName() override; /** * Finds the longest or shortest baseline in the same band/sequence as the diff --git a/imagesets/multibandmsimageset.cpp b/imagesets/multibandmsimageset.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6ebe4fb1687a621d2e63cee11a0c89a87bc7bdf --- /dev/null +++ b/imagesets/multibandmsimageset.cpp @@ -0,0 +1,177 @@ + +#include "multibandmsimageset.h" + +#include "../util/logger.h" +#include "../util/stopwatch.h" + +#include "aocommon/parallelfor.h" + +#include <boost/make_unique.hpp> + +namespace imagesets { + +MultiBandMsImageSet::MultiBandMsImageSet(std::vector<std::string> ms_names, + BaselineIOMode io_mode, + size_t thread_count) + : ms_names_(std::move(ms_names)), io_mode_(io_mode) { + for (const auto& ms_name : ms_names_) + readers_.emplace_back( + boost::make_unique<SingleBandMemoryBaselineReader>(ms_name)); + + ReadData(thread_count); + ProcessMetaData(); +} + +static std::vector<BandInfo> CombineBands( + const std::vector<std::reference_wrapper<MSMetaData>>& meta_data) { + std::vector<BandInfo> result; + std::for_each(meta_data.begin(), meta_data.end(), + [&result](const std::reference_wrapper<MSMetaData> element) { + // There's only one band. + result.emplace_back(element.get().GetBandInfo(0)); + }); + return result; +} + +static const std::vector<std::pair<size_t, size_t>>& GetBaselines( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetBaselines(); +} + +static const std::set<double>& GetObservationTimes( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetObservationTimes(); +} + +static const std::vector<std::set<double>>& GetObservationTimesPerSequence( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetObservationTimesPerSequence(); +} + +static const std::vector<AntennaInfo>& GetAntennae( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetAntennas(); +} + +static const std::vector<FieldInfo>& GetFields( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetFields(); +} + +static const std::vector<MSMetaData::Sequence>& GetSequences( + std::reference_wrapper<MSMetaData> meta_data) { + return meta_data.get().GetSequences(); +} + +template <class T, class F> +static void ValidateEqual( + const T& lhs, + std::vector<std::reference_wrapper<MSMetaData>>::const_iterator first, + std::vector<std::reference_wrapper<MSMetaData>>::const_iterator last, + F&& functor) { + if (!std::all_of(first, last, + [&](const std::reference_wrapper<MSMetaData>& element) { + return lhs == functor(element); + })) + throw std::runtime_error("The MS not compatible"); +} + +template <class F> +static auto ExtractField( + const std::vector<std::reference_wrapper<MSMetaData>>& meta_data, F functor) + -> typename std::remove_reference< + typename std::remove_cv<decltype(functor(meta_data[0]))>::type>::type { + assert(!meta_data.empty()); + auto result = functor(meta_data[0]); + ValidateEqual(result, meta_data.begin() + 1, meta_data.end(), + std::move(functor)); + return result; +} + +void MultiBandMsImageSet::ReadData(size_t thread_count) { + Stopwatch watch(true); + aocommon::ParallelFor<size_t> executor(thread_count); + executor.Run(0, readers_.size(), + [&](size_t i, size_t) { readers_[i]->Read(); }); + Logger::Debug << "Reading took " << watch.ToString() << ".\n"; +} + +// Returns the metadata of the readers and initialized their main tables. +static std::vector<std::reference_wrapper<MSMetaData>> GetInitializedMetaData( + std::vector<std::unique_ptr<SingleBandMemoryBaselineReader>>::iterator + first, + std::vector<std::unique_ptr<SingleBandMemoryBaselineReader>>::iterator + last) { + std::vector<std::reference_wrapper<MSMetaData>> result; + std::transform(first, last, std::back_inserter(result), + [](std::unique_ptr<SingleBandMemoryBaselineReader>& reader) { + MSMetaData& meta_data{reader->MetaData()}; + meta_data.initializeMainTableData(); + return std::reference_wrapper<MSMetaData>{meta_data}; + }); + return result; +} + +void MultiBandMsImageSet::ProcessMetaData() { + const std::vector<std::reference_wrapper<MSMetaData>> meta_data = + GetInitializedMetaData(readers_.begin(), readers_.end()); + + // These fields are only validated. + ExtractField(meta_data, GetBaselines); + ExtractField(meta_data, GetObservationTimes); + + // These fields are validated and cached. + observation_times_per_sequence_ = + ExtractField(meta_data, GetObservationTimesPerSequence); + antennae_ = ExtractField(meta_data, GetAntennae); + fields_ = ExtractField(meta_data, GetFields); + sequences_ = ExtractField(meta_data, GetSequences); + bands_ = CombineBands(meta_data); +} + +size_t MultiBandMsImageSet::findBaselineIndex(size_t antenna1, size_t antenna2, + size_t band, + size_t sequenceId) const { + // TODO This is a linear search... + size_t index = 0; + for (const auto& sequence : sequences_) { + const bool antennaMatch = + (sequence.antenna1 == antenna1 && sequence.antenna2 == antenna2) || + (sequence.antenna1 == antenna2 && sequence.antenna2 == antenna1); + + if (antennaMatch && sequence.sequenceId == sequenceId) { + return index + band * sequences_.size(); + } + ++index; + } + + return not_found; +} + +void MultiBandMsImageSet::PerformReadRequests(ProgressListener&) { + if (!data_.empty()) + throw std::runtime_error( + "ReadRequest() called, but a previous read request was not " + "completely processed by calling GetNextRequested()."); + + for (const auto& index : read_requests_) { + const auto& sequence = sequences_[GetSequenceId(index)]; + SingleBandMemoryBaselineReader& reader = *readers_[GetBandIndex(index)]; + data_.emplace_back(reader.GetData(sequence, index)); + } + + read_requests_.clear(); +} + +std::unique_ptr<BaselineData> MultiBandMsImageSet::GetNextRequested() { + std::unique_ptr<BaselineData> result = std::move(data_.front()); + data_.pop_front(); + + if (result->Data().IsEmpty()) + throw std::runtime_error( + "Calling GetNextRequested(), but requests were not read with " + "LoadRequests."); + return result; +} + +} // namespace imagesets diff --git a/imagesets/multibandmsimageset.h b/imagesets/multibandmsimageset.h new file mode 100644 index 0000000000000000000000000000000000000000..fe9c4983b600f74f58a1b446293038cc6cb3f398 --- /dev/null +++ b/imagesets/multibandmsimageset.h @@ -0,0 +1,187 @@ +#ifndef MULTIBANDMSIMAGESET_H +#define MULTIBANDMSIMAGESET_H + +#include "../../msio/singlemandmemorybaselinereader.h" + +#include "indexableset.h" + +#include <deque> +#include <string> +#include <vector> + +namespace imagesets { + +struct MetaData; + +/** + * The multiband images combine multiple single band measurement sets. + * + * @pre Each image contains a different bands of the same measurement. + */ +class MultiBandMsImageSet final : public IndexableSet { + public: + MultiBandMsImageSet(std::vector<std::string> names, BaselineIOMode io_mode, + size_t thread_count); + MultiBandMsImageSet(const MultiBandMsImageSet &) = default; + + ~MultiBandMsImageSet() = default; + +#if 0 + size_t StartTimeIndex(const ImageSetIndex &index) const { return 0; } + + size_t EndTimeIndex(const ImageSetIndex &index) const; +#endif + std::unique_ptr<ImageSet> Clone() override { + throw std::runtime_error("Not available"); + } + + size_t Size() const override { return sequences_.size() * BandCount(); } + + std::string Description(const ImageSetIndex &index) const override { + // TODO Find a better description... + return "Multiband set"; + } // XXX + std::string Name() const override { + // TODO Find a better name + return "Multiband set"; + } + std::vector<std::string> Files() const override { return ms_names_; } + + void AddReadRequest(const ImageSetIndex &index) override { + read_requests_.emplace_back(index); + } + void PerformReadRequests(class ProgressListener &progress) override; + std::unique_ptr<BaselineData> GetNextRequested() override; + + void AddWriteFlagsTask(const ImageSetIndex &index, + std::vector<Mask2DCPtr> &flags) override { + throw std::runtime_error("Not implemented"); + } + void PerformWriteFlagsTask() override { + throw std::runtime_error("Not implemented"); + } + + void Initialize() override { + // TODO should this do somoething + /* Do nothing.*/ + } + + void PerformWriteDataTask(const ImageSetIndex &index, + std::vector<Image2DCPtr> realImages, + std::vector<Image2DCPtr> imaginaryImages) override { + throw std::runtime_error("Not implemented"); + } + + BaselineReaderPtr Reader() const override { + throw std::runtime_error("Not available"); + } + + std::string TelescopeName() override { + // TODO Should we cache this value? + auto ms = readers_[0]->OpenMS(); + return MSMetaData::GetTelescopeName(ms); + } + + size_t GetAntenna1(const ImageSetIndex &index) const override { + return sequences_[GetSequenceIndex(index)].antenna1; + } + size_t GetAntenna2(const ImageSetIndex &index) const override { + return sequences_[GetSequenceIndex(index)].antenna2; + } + size_t GetBand(const ImageSetIndex &index) const override { + return GetBandIndex(index); + } + size_t GetField(const ImageSetIndex &index) const override { + return sequences_[GetSequenceIndex(index)].fieldId; + } + size_t GetSequenceId(const ImageSetIndex &index) const override { + return sequences_[GetSequenceIndex(index)].sequenceId; + } + + boost::optional<ImageSetIndex> Index(size_t antenna1, size_t antenna2, + size_t bandIndex, + size_t sequenceId) const override { + size_t value = findBaselineIndex(antenna1, antenna2, bandIndex, sequenceId); + if (value == not_found) + return boost::optional<ImageSetIndex>(); + else + return ImageSetIndex(Size(), value); + } + /* + const std::string &DataColumnName() const { return _dataColumnName; } + void SetDataColumnName(const std::string &name) { + if (_reader != 0) + throw std::runtime_error( + "Trying to set data column after creating the reader!"); + _dataColumnName = name; + } + */ + size_t BandCount() const override { return bands_.size(); } + AntennaInfo GetAntennaInfo(unsigned antennaIndex) const override { + return antennae_[antennaIndex]; + } + BandInfo GetBandInfo(unsigned bandIndex) const override { + return bands_[bandIndex]; + } + + size_t SequenceCount() const override { + return observation_times_per_sequence_.size() * BandCount(); + } + + size_t AntennaCount() const override { return antennae_.size(); } + FieldInfo GetFieldInfo(unsigned fieldIndex) const override { + return fields_[fieldIndex]; + } +#if 0 + std::vector<double> ObservationTimesVector(const ImageSetIndex &index); + size_t FieldCount() const { return _fieldCount; } + void SetReadFlags(bool readFlags) { _readFlags = readFlags; } + void SetReadUVW(bool readUVW) { _readUVW = readUVW; } + const std::vector<MSMetaData::Sequence> &Sequences() const { + return _sequences; + } + + void SetInterval(boost::optional<size_t> start, boost::optional<size_t> end) { + _intervalStart = start; + _intervalEnd = end; + if (start) _metaData.SetIntervalStart(start.get()); + if (end) _metaData.SetIntervalEnd(end.get()); + } +#endif + private: + size_t GetSequenceIndex(const ImageSetIndex &index) const { + return index.Value() % sequences_.size(); + } + + size_t GetBandIndex(const ImageSetIndex &index) const { + return index.Value() / sequences_.size(); + } + + const static size_t not_found = std::numeric_limits<size_t>::max(); + size_t findBaselineIndex(size_t antenna1, size_t antenna2, size_t band, + size_t sequenceId) const; + + void ReadData(size_t thread_count); + void ProcessMetaData(); + + std::vector<std::string> ms_names_; + BaselineIOMode io_mode_; + + // XXX other readers use a vector and pop_front. + std::deque<std::unique_ptr<BaselineData>> data_; + std::vector<ImageSetIndex> read_requests_; + + // A copy of the meta data + std::vector<AntennaInfo> antennae_; + std::vector<FieldInfo> fields_; + std::vector<BandInfo> bands_; + std::vector<MSMetaData::Sequence> sequences_; + // Only used for the sequence count, can this be removed? + std::vector<std::set<double>> observation_times_per_sequence_; + + std::vector<std::unique_ptr<SingleBandMemoryBaselineReader>> readers_; +}; + +} // namespace imagesets + +#endif // MULTIBANDMSIMAGESET_H diff --git a/msio/memorybaselinereader.cpp b/msio/memorybaselinereader.cpp index d41cb5cd97be0d1f1396cefcf4b783bbd760e8fe..b1a85dce57300cc0e34c13401aff70863d197c1b 100644 --- a/msio/memorybaselinereader.cpp +++ b/msio/memorybaselinereader.cpp @@ -69,8 +69,8 @@ void MemoryBaselineReader::readSet(ProgressListener& progress) { size_t antennaCount = MetaData().AntennaCount(), polarizationCount = Polarizations().size(), - bandCount = MetaData().BandCount(), - sequenceCount = MetaData().SequenceCount(), intStart = IntervalStart(), + bandCount = MetaData().BandCount(), // XXX + sequenceCount = MetaData().SequenceCount(), intStart = IntervalStart(), intEnd = IntervalEnd(); std::vector<size_t> dataDescIdToSpw; @@ -120,7 +120,7 @@ void MemoryBaselineReader::readSet(ProgressListener& progress) { std::unique_ptr<Result>& result = baselineCube[spwFieldIndex][ant1][ant2]; if (result == nullptr) { const size_t timeStepCount = ObservationTimes(sequenceId).size(); - const size_t nFreq = MetaData().FrequencyCount(spw); + const size_t nFreq = MetaData().FrequencyCount(spw); // XXX result.reset(new Result()); for (size_t p = 0; p != polarizationCount; ++p) { result->_realImages.emplace_back( @@ -184,7 +184,7 @@ void MemoryBaselineReader::readSet(ProgressListener& progress) { // Move elements from matrix into the baseline map. for (size_t s = 0; s != sequenceCount; ++s) { - for (size_t b = 0; b != bandCount; ++b) { + for (size_t b = 0; b != bandCount; ++b) { // XXX size_t fbIndex = s * bandCount + b; for (size_t a1 = 0; a1 != antennaCount; ++a1) { for (size_t a2 = a1; a2 != antennaCount; ++a2) { diff --git a/msio/singlemandmemorybaselinereader.cpp b/msio/singlemandmemorybaselinereader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8bf3f68118683e8970ac649fa9e61c5ecd7d2c31 --- /dev/null +++ b/msio/singlemandmemorybaselinereader.cpp @@ -0,0 +1,227 @@ +#include "singlemandmemorybaselinereader.h" + +#include "../util/logger.h" +#include "../util/progresslistener.h" +#include "../structures/timefrequencydata.h" +#include "../structures/timefrequencymetadata.h" +#include "msselection.h" + +#include <casacore/tables/Tables/ArrayColumn.h> +#include <casacore/tables/Tables/ScalarColumn.h> + +#include <boost/make_unique.hpp> + +TimeFrequencyData SingleBandMemoryBaselineReader::GetData(Result& baseline) { + // Need a reference to BaselineReader::_polarizations to avoid + // having a reference to a temporary in the TimeFrequencyData. + std::vector<aocommon::PolarizationEnum>& polarizations = + const_cast<std::vector<aocommon::PolarizationEnum>&>(Polarizations()); + TimeFrequencyData data{polarizations.data(), polarizations.size(), + baseline._realImages.data(), + baseline._imaginaryImages.data()}; + data.SetIndividualPolarizationMasks(baseline._flags.data()); + + return data; +} + +TimeFrequencyMetaDataCPtr SingleBandMemoryBaselineReader::GetMetaData( + const MSMetaData::Sequence& sequence, Result& baseline) { + auto result = boost::make_unique<TimeFrequencyMetaData>(); + + MSMetaData& meta_data = MetaData(); + result->SetAntenna1(meta_data.GetAntennaInfo(sequence.antenna1)); + result->SetAntenna2(meta_data.GetAntennaInfo(sequence.antenna2)); + result->SetBand(meta_data.GetBandInfo(0)); + result->SetField(meta_data.GetFieldInfo(sequence.fieldId)); + const std::set<double>& observation_times = + meta_data.GetObservationTimesSet(sequence.sequenceId); + result->SetObservationTimes( + std::vector<double>(observation_times.begin(), observation_times.end())); + result->SetUVW(baseline._uvw); + + return TimeFrequencyMetaDataCPtr{result.release()}; +} + +// XXX move from WSClean to aocommon + +#if __cplusplus < 201703L +namespace detail { + +inline void Stream(std::stringstream&) {} + +template <class T, class... Args> +void Stream(std::stringstream& sstr, const T& value, const Args&... args) { + sstr << value; + Stream(sstr, args...); +} +} // namespace detail +#endif + +/** + * Helper function to throw a @c std::runtime_error. + * + * The function concatenates the @a args to a string and uses that message as + * error message for the exception. + */ +template <class... Args> +[[noreturn]] void ThrowRuntimeError(const Args&... args) { + std::stringstream sstr; +#if __cplusplus > 201402L + ((sstr << args), ...); +#else + detail::Stream(sstr, args...); +#endif + throw std::runtime_error(sstr.str()); +} + +std::unique_ptr<imagesets::BaselineData> +SingleBandMemoryBaselineReader::GetData(const MSMetaData::Sequence& sequence, + const imagesets::ImageSetIndex& index) { + assert(sequence.antenna1 < MetaData().AntennaCount()); + assert(sequence.antenna2 < MetaData().AntennaCount()); + assert(sequence.sequenceId < MetaData().SequenceCount()); + + size_t antenna_1 = sequence.antenna1; + size_t antenna_2 = sequence.antenna2; + if (antenna_1 > antenna_2) std::swap(antenna_1, antenna_2); + + const std::unique_ptr<Result>& baseline = + baselines_[sequence.sequenceId][antenna_1][antenna_2]; + + if (!baseline) + ThrowRuntimeError( + "Exception in PerformReadRequests(): requested baseline is not " + "available in measurement set (antenna1=", + sequence.antenna1, ", antenna2=", sequence.antenna2, + ", sequenceId=", sequence.sequenceId, ")"); + + return boost::make_unique<imagesets::BaselineData>( + GetData(*baseline), GetMetaData(sequence, *baseline), index); +} + +template <class R> +static R GetColumn(casacore::MeasurementSet& ms, + const std::string& column_name) { + return R(ms, column_name); +} +template <class R> +static R GetColumn(casacore::MeasurementSet& ms, + casacore::MSMainEnums::PredefinedColumns column_name) { + return GetColumn<R>(ms, casacore::MeasurementSet::columnName(column_name)); +} + +struct Columns { + explicit Columns(casacore::MeasurementSet& ms, + const std::string& data_column_name) + : antenna_1(GetColumn<decltype(antenna_1)>( + ms, casacore::MSMainEnums::ANTENNA1)), + antenna_2(GetColumn<decltype(antenna_2)>( + ms, casacore::MSMainEnums::ANTENNA2)), + data_desc_id(GetColumn<decltype(data_desc_id)>( + ms, casacore::MSMainEnums::DATA_DESC_ID)), + data(GetColumn<decltype(data)>(ms, data_column_name)), + flag(GetColumn<decltype(flag)>(ms, casacore::MSMainEnums::FLAG)), + uvw(GetColumn<decltype(uvw)>(ms, casacore::MSMainEnums::UVW)) {} + + casacore::ScalarColumn<int> antenna_1; + casacore::ScalarColumn<int> antenna_2; + casacore::ScalarColumn<int> data_desc_id; + casacore::ArrayColumn<casacore::Complex> data; + casacore::ArrayColumn<bool> flag; + casacore::ArrayColumn<double> uvw; +}; + +void SingleBandMemoryBaselineReader::InitialiseBaseLines() { + baselines_.resize(MetaData().SequenceCount()); + + const size_t num_antennae = MetaData().AntennaCount(); + for (auto& matrix : baselines_) { + matrix.resize(num_antennae); + for (size_t a1 = 0; a1 != num_antennae; ++a1) { + matrix[a1].resize(num_antennae); + for (size_t a2 = 0; a2 != num_antennae; ++a2) matrix[a1][a2] = 0; + } + } +} + +void SingleBandMemoryBaselineReader::Read() { + initializeMeta(); + if (MetaData().BandCount() != 1) + throw std::runtime_error( + "Only Measurement Sets with one band are supported."); + + Logger::Debug << "Reading the data (interval={" << IntervalStart() << "..." + << IntervalEnd() << "})...\n"; + + casacore::MeasurementSet ms(OpenMS()); + Columns columns{ms, DataColumnName()}; + + std::vector<size_t> data_desc_to_band; + MetaData().GetDataDescToBandVector(data_desc_to_band); + InitialiseBaseLines(); + + const BandInfo& band_info = MetaData().GetBandInfo(0); + const size_t num_frequencies = band_info.channels.size(); + const size_t num_polarizations = Polarizations().size(); + + DummyProgressListener progress; + MSSelection msSelection(ms, ObservationTimesPerSequence(), progress); + msSelection.Process([&](size_t row_index, size_t sequenceId, + size_t timeIndexInSequence) { + size_t antenna_1 = columns.antenna_1(row_index); + size_t antenna_2 = columns.antenna_2(row_index); + if (antenna_1 > antenna_2) std::swap(antenna_1, antenna_2); + + const size_t spw = data_desc_to_band[columns.data_desc_id(row_index)]; + const size_t spw_index = spw + sequenceId; + std::unique_ptr<Result>& result = + baselines_[spw_index][antenna_1][antenna_2]; + + if (result == nullptr) { + const size_t num_time_steps = ObservationTimes(sequenceId).size(); + result.reset(new Result()); + for (size_t p = 0; p != num_polarizations; ++p) { + result->_realImages.emplace_back( + Image2D::CreateZeroImagePtr(num_time_steps, num_frequencies)); + result->_imaginaryImages.emplace_back( + Image2D::CreateZeroImagePtr(num_time_steps, num_frequencies)); + result->_flags.emplace_back( + Mask2D::CreateSetMaskPtr<true>(num_time_steps, num_frequencies)); + } + result->_bandInfo = band_info; + result->_uvw.resize(num_time_steps); + } + + casacore::Array<double> uvw = columns.uvw.get(row_index); + double* uvw_ptr = uvw.data(); + result->_uvw[timeIndexInSequence] = UVW(uvw_ptr[0], uvw_ptr[1], uvw_ptr[2]); + + auto data = columns.data.get(row_index); + casacore::Array<bool> flagArray = columns.flag.get(row_index); + for (size_t p = 0; p != num_polarizations; ++p) { + Image2D& real = *result->_realImages[p]; + Image2D& imag = *result->_imaginaryImages[p]; + Mask2D& mask = *result->_flags[p]; + const size_t image_stride = real.Stride(); + const size_t mask_stride = mask.Stride(); + num_t* real_out_ptr = real.ValuePtr(timeIndexInSequence, 0); + num_t* imag_out_ptr = imag.ValuePtr(timeIndexInSequence, 0); + bool* flag_out_ptr = mask.ValuePtr(timeIndexInSequence, 0); + + auto data_ptr = data.cbegin() + p; + auto flag_ptr = flagArray.cbegin() + p; + for (size_t ch = 0; ch != num_frequencies; ++ch) { + *real_out_ptr = data_ptr->real(); + *imag_out_ptr = data_ptr->imag(); + *flag_out_ptr = *flag_ptr; + + real_out_ptr += image_stride; + imag_out_ptr += image_stride; + flag_out_ptr += mask_stride; + + data_ptr += num_polarizations; + flag_ptr += num_polarizations; + } + } + }); +} diff --git a/msio/singlemandmemorybaselinereader.h b/msio/singlemandmemorybaselinereader.h new file mode 100644 index 0000000000000000000000000000000000000000..12ae1dfff14ea937ccdf14a896e875d0acdc3069 --- /dev/null +++ b/msio/singlemandmemorybaselinereader.h @@ -0,0 +1,78 @@ +#ifndef SINGLE_BAND_MEMORY_BASELINE_READER_H +#define SINGLE_BAND_MEMORY_BASELINE_READER_H + +#include "baselinereader.h" + +#include "../imagesets/imageset.h" + +/** + * Single band memory based baseline reader. + * + * This reads all data in a memory structure. The class offers a different + * interface to extract the data, @ref GetData returns the requested data. This + * differs from the other since the controlling @ref MultiBandMsImageSet needs + * to retrieve its data from multiple readers. Using the normal request method + * doesn't work to well in this case. + * + * @pre Each image contains one band. + */ +class SingleBandMemoryBaselineReader final : public BaselineReader { + public: + explicit SingleBandMemoryBaselineReader(const std::string& msFile) + : BaselineReader(msFile) {} + + /** + * Reads the data. + * + * This could be called from the constructor, but is separated. + * After construction the @ref MultiBandMsImageSet executes this function on + * several instances of this class in parallel. + */ + void Read(); + + void PerformReadRequests(ProgressListener&) override { + throw std::runtime_error( + "The full mem reader can not write data back to file: use the indirect " + "reader"); + }; + + void PerformFlagWriteRequests() override { + throw std::runtime_error( + "The full mem reader can not write data back to file: use the indirect " + "reader"); + } + + void PerformDataWriteTask(std::vector<Image2DCPtr>, std::vector<Image2DCPtr>, + size_t, size_t, size_t, size_t) override { + throw std::runtime_error( + "The full mem reader can not write data back to file: use the indirect " + "reader"); + } + + size_t GetMinRecommendedBufferSize(size_t) override { return 1; } + size_t GetMaxRecommendedBufferSize(size_t) override { return 2; } + + /** + * Returns the requested data. + * + * @pre Read has been called. + */ + std::unique_ptr<imagesets::BaselineData> GetData( + const MSMetaData::Sequence& sequence, + const imagesets::ImageSetIndex& index); + + private: + TimeFrequencyData GetData(Result& baseline); + TimeFrequencyMetaDataCPtr GetMetaData(const MSMetaData::Sequence& sequence, + Result& baseline); + + using MatrixElement = std::unique_ptr<Result>; + using MatrixRow = std::vector<MatrixElement>; + using BaselineMatrix = std::vector<MatrixRow>; + using BaselineCube = std::vector<BaselineMatrix>; + + BaselineCube baselines_; + + void InitialiseBaseLines(); +}; +#endif // SINGLE_BAND_MEMORY_BASELINE_READER_H