Skip to content
Snippets Groups Projects

RAP-886 Add a Filter sub-step to FlagTransfer

Merged RAP-886 Add a Filter sub-step to FlagTransfer
All threads resolved!
Merged Mick Veldhuis requested to merge rap-886-flagtransfer-filter-substep into master
All threads resolved!
+ 46
13
@@ -68,19 +68,23 @@ FlagTransfer::FlagTransfer(const common::ParameterSet& parset,
// Initialise flags_ to match the shape of the flags stored in DPBuffer
flags_ = xt::xtensor<bool, 3>({n_baselines, n_channels, n_correlations});
// Create a filter substep and connect it to a result step
filter_step_ = std::make_shared<Filter>(parset, prefix);
result_step_ = std::make_shared<ResultStep>();
filter_step_->setNextStep(result_step_);
}
void FlagTransfer::ReadSourceMsFlags() {
casacore::Table table = ms_iterator_.table();
casacore::ArrayColumn<bool> flag_col(
table, casacore::MS::columnName(casacore::MS::FLAG));
casacore::Cube<bool> flag_data = flag_col.getColumn();
const casacore::IPosition casa_shape(3, flags_.shape()[2], flags_.shape()[1],
flags_.shape()[0]);
casacore::Cube<bool> casa_flags(casa_shape, flags_.data(), casacore::SHARE);
casa_flags = flag_data;
casa_flags = flag_col.getColumn();
}
void FlagTransfer::show(std::ostream& os) const {
@@ -89,6 +93,7 @@ void FlagTransfer::show(std::ostream& os) const {
<< " Time avg factor: " << time_averaging_factor_ << '\n'
<< " Source interval: " << time_interval_ << '\n'
<< " Target interval: " << getInfoOut().timeInterval() << '\n';
filter_step_->show(os);
}
void FlagTransfer::showTimings(std::ostream& os, double duration) const {
@@ -108,6 +113,27 @@ bool FlagTransfer::process(std::unique_ptr<DPBuffer> buffer) {
}
}
common::Fields filter_fields = filter_step_->getRequiredFields();
std::unique_ptr<DPBuffer> substep_buffer =
std::make_unique<DPBuffer>(*buffer, filter_fields);
filter_step_->process(std::move(substep_buffer));
std::unique_ptr<DPBuffer> result_buffer = result_step_->take();
std::vector<common::rownr_t> filtered_row_numbers =
result_buffer->GetRowNumbers().tovector();
const common::rownr_t first_row_number = buffer->GetRowNumbers()[0];
for (auto& row_number : filtered_row_numbers) {
row_number -= first_row_number;
}
if (filtered_row_numbers.size() != flags_.shape()[0]) {
throw std::runtime_error(
"FlagTransfer requires that the source and target MS have an equal "
"number of baselines. Note: baselines might not have been properly "
"filtered!");
}
// Fill flags if DPBuffer is empty
base::DPBuffer::FlagsType& flags = buffer->GetFlags();
if (flags.size() == 0) {
@@ -117,16 +143,22 @@ bool FlagTransfer::process(std::unique_ptr<DPBuffer> buffer) {
}
const std::size_t n_source_channels = flags_.shape()[1];
if (n_source_channels == getInfoOut().nchan()) {
const std::size_t n_source_baselines = flags_.shape()[0];
if (n_source_channels == getInfoOut().nchan() &&
n_source_baselines == getInfoOut().nbaselines()) {
flags = flags_;
} else if (n_source_channels == getInfoOut().nchan()) {
for (std::size_t source_row = 0; source_row < n_source_baselines;
++source_row) {
common::rownr_t target_row = filtered_row_numbers[source_row];
xt::view(flags, target_row, xt::all(), xt::all()) =
xt::view(flags_, source_row, xt::all(), xt::all());
}
} else {
std::size_t target_channel = 0;
const std::vector<double> target_frequencies = getInfoOut().chanFreqs(0);
for (std::size_t channel_block = 0; channel_block < n_source_channels;
++channel_block) {
auto source_flags_view =
xt::view(flags_, xt::all(), channel_block, xt::all());
const double source_channel_edge =
source_channel_upper_edges_[channel_block];
while (target_frequencies[target_channel] < source_channel_edge) {
@@ -134,8 +166,13 @@ bool FlagTransfer::process(std::unique_ptr<DPBuffer> buffer) {
break;
}
xt::view(flags, xt::all(), target_channel, xt::all()) =
source_flags_view;
for (std::size_t source_row = 0; source_row < n_source_baselines;
++source_row) {
common::rownr_t target_row = filtered_row_numbers[source_row];
xt::view(flags, target_row, target_channel, xt::all()) =
xt::view(flags_, source_row, channel_block, xt::all());
}
++target_channel;
}
}
@@ -151,11 +188,7 @@ bool FlagTransfer::process(std::unique_ptr<DPBuffer> buffer) {
void FlagTransfer::updateInfo(const base::DPInfo& info_in) {
Step::updateInfo(info_in);
if (getInfoOut().nbaselines() != flags_.shape()[0]) {
throw std::runtime_error(
"FlagTransfer requires that the source and target MS have an equal "
"number of baselines");
}
filter_step_->setInfo(info_in);
if (getInfoOut().ncorr() != flags_.shape()[2]) {
throw std::runtime_error(
Loading