diff --git a/ddecal/gain_solvers/SolveData.cc b/ddecal/gain_solvers/SolveData.cc index 4cba429e3727b68bb49d17087d98b268f13d105d..f90535c3d22f3ebd391785cf5719b8efe6850b01 100644 --- a/ddecal/gain_solvers/SolveData.cc +++ b/ddecal/gain_solvers/SolveData.cc @@ -15,7 +15,7 @@ namespace dp3 { namespace ddecal { SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, - const std::vector<std::string>& direction_keys, + const std::vector<std::string>& direction_names, size_t n_channel_blocks, size_t n_antennas, const std::vector<size_t>& n_solutions_per_direction, const std::vector<int>& antennas1, @@ -28,7 +28,8 @@ SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, buffers.empty() ? 0 : buffers.front().GetData().shape(0); const size_t n_channels = buffers.empty() ? 0 : buffers.front().GetData().shape(1); - const size_t n_directions = direction_keys.size(); + const size_t n_directions = direction_names.size(); + const bool has_weights = !buffers.empty() && buffers.front().GetWeights().size() != 0; // Count nr of baselines with different antennas.
size_t n_baselines = 0; @@ -57,6 +58,8 @@ SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, channel_begin[channel_block_index]; cb_data.Resize(n_times * n_baselines * channel_block_size, n_directions); + if (has_weights) + cb_data.ResizeWeights(n_times * n_baselines * channel_block_size); cb_data.n_solutions_ = channel_blocks_.front().n_solutions_; } @@ -74,6 +77,8 @@ SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, std::vector<size_t> visibility_indices(n_channel_blocks, 0); for (size_t time_index = 0; time_index < n_times; ++time_index) { const base::DPBuffer::DataType& data = buffers[time_index].GetData(""); + const base::DPBuffer::WeightsType& weights = + buffers[time_index].GetWeights(); for (size_t baseline = 0; baseline < n_baselines_in_buffers; ++baseline) { const size_t antenna1 = antennas1[baseline]; @@ -93,10 +98,17 @@ SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, cb_data.antenna_indices_[vis_index + i] = std::pair<uint32_t, uint32_t>(antenna1, antenna2); } + if (has_weights) { + for (size_t i = 0; i < channel_block_size; ++i) { + for (size_t p = 0; p != weights.shape()[2]; ++p) + cb_data.weights_(vis_index + i, p) = + weights(baseline, first_channel + i, p); + } + } for (size_t direction = 0; direction < n_directions; ++direction) { const base::DPBuffer::DataType& model_data = - buffers[time_index].GetData(direction_keys[direction]); + buffers[time_index].GetData(direction_names[direction]); const size_t n_solutions = channel_blocks_.front().n_solutions_[direction]; // Calculate the absolute index as required for solution_map_ @@ -123,7 +135,7 @@ SolveData::SolveData(const std::vector<base::DPBuffer>& buffers, SolveData::SolveData(const BdaSolverBuffer& buffer, size_t n_channel_blocks, size_t n_directions, size_t n_antennas, const std::vector<int>& antennas1, - const std::vector<int>& antennas2) + const std::vector<int>& antennas2, bool with_weights) : channel_blocks_(n_channel_blocks) { // Count nr of 
visibilities std::vector<size_t> counts(n_channel_blocks, 0); @@ -146,6 +158,9 @@ SolveData::SolveData(const BdaSolverBuffer& buffer, size_t n_channel_blocks, // Allocate for (size_t cb = 0; cb != n_channel_blocks; ++cb) { channel_blocks_[cb].Resize(counts[cb], n_directions); + if (with_weights) { + channel_blocks_[cb].ResizeWeights(counts[cb]); + } } // Fill @@ -173,6 +188,14 @@ SolveData::SolveData(const BdaSolverBuffer& buffer, size_t n_channel_blocks, cb_data.antenna_indices_[vis_index + i] = std::pair<uint32_t, uint32_t>(antenna1, antenna2); } + if (with_weights) { + for (size_t i = 0; i != channel_block_size; ++i) { + for (size_t p = 0; p != cb_data.weights_.shape(1); ++p) { + cb_data.weights_(vis_index + i, p) = + 1.0; // TODO use real vis weights + } + } + } for (size_t dir = 0; dir != n_directions; ++dir) { const BDABuffer::Row& model_data_row = diff --git a/ddecal/gain_solvers/SolveData.h b/ddecal/gain_solvers/SolveData.h index 6b0f068e8dd329eaaa272dfe3826245e4fde7067..b8c4dcbc56ec6b2e876e2af429707a3d46322167 100644 --- a/ddecal/gain_solvers/SolveData.h +++ b/ddecal/gain_solvers/SolveData.h @@ -35,6 +35,10 @@ class SolveData { n_solutions_.resize(n_directions); solution_map_.resize({n_directions, n_visibilities}); } + void ResizeWeights(size_t n_visibilities) { + constexpr size_t kNCorrelations = 4; + weights_.resize({n_visibilities, kNCorrelations}); + } size_t NDirections() const { return model_data_.shape(0); } size_t NVisibilities() const { return data_.size(); } /*** @@ -65,6 +69,7 @@ class SolveData { const uint32_t* SolutionMapData() const { return solution_map_.data(); } + const float& Weight(size_t index) const { return weights_(index, 0); } const aocommon::MC2x2F& Visibility(size_t index) const { return data_[index]; } @@ -92,6 +97,9 @@ class SolveData { void InitializeSolutionIndices(); std::vector<aocommon::MC2x2F> data_; + // weights_(i, pol) contains the weight for data_[i][pol].
The vector will + // be left empty when the algorithm does not need the weights. + xt::xtensor<float, 2> weights_; // model_data_(d, i) is the model data for direction d, element i xt::xtensor<aocommon::MC2x2F, 2> model_data_; // Element i contains the first and second antenna corresponding with @@ -139,7 +147,7 @@ class SolveData { SolveData(const BdaSolverBuffer& buffer, size_t n_channel_blocks, size_t n_directions, size_t n_antennas, const std::vector<int>& antennas1, - const std::vector<int>& antennas2); + const std::vector<int>& antennas2, bool with_weights); size_t NChannelBlocks() const { return channel_blocks_.size(); } diff --git a/ddecal/gain_solvers/SolverTools.cc b/ddecal/gain_solvers/SolverTools.cc index 9b7423b4e44e459794cb707131d57caf5a0106fb..e7225f82887002323c1282f1ed1b130dd67003a1 100644 --- a/ddecal/gain_solvers/SolverTools.cc +++ b/ddecal/gain_solvers/SolverTools.cc @@ -80,7 +80,7 @@ void AssignAndWeight( std::vector<std::unique_ptr<base::DPBuffer>>& unweighted_buffers, const std::vector<std::string>& direction_names, std::vector<base::DPBuffer>& weighted_buffers, - bool keep_unweighted_model_data) { + bool keep_unweighted_model_data, bool linear_weighting_mode) { const std::size_t n_times = unweighted_buffers.size(); assert(weighted_buffers.size() >= n_times); @@ -125,25 +125,33 @@ void AssignAndWeight( // If the flag is set, set both the data and model data to zero. // Storing the result in an xtensor (and not in an expression) ensures // that the square root is evaluated once for each weight. 
- const xt::xtensor<float, 3> weights_sqrt = xt::sqrt(weights); + xt::xtensor<float, 3> prepared_weights; weighted_buffer.GetData().resize(unweighted_data.shape()); - Weigh(unweighted_data, weighted_buffer.GetData(), weights_sqrt); + if (linear_weighting_mode) { + prepared_weights = weights; + weighted_buffer.GetWeights().resize(unweighted_data.shape()); + weighted_buffer.GetWeights() = prepared_weights; + } else { + prepared_weights = xt::sqrt(weights); + } + Weigh(unweighted_data, weighted_buffer.GetData(), prepared_weights); - const std::complex<float> kZeroVisibility(0.0f, 0.0f); // TODO(AST-1278): Use 'const auto' instead of 'auto' for flags_view. // Although flags_view can be const, it may result in compiler errors. auto flags_view = xt::view(flags, xt::all(), xt::all(), xt::newaxis()); + + constexpr std::complex<float> kZeroVisibility(0.0f, 0.0f); xt::masked_view(weighted_buffer.GetData(), flags_view) = kZeroVisibility; for (const std::string& name : direction_names) { if (keep_unweighted_model_data) { if (!weighted_buffer.HasData(name)) weighted_buffer.AddData(name); Weigh(unweighted_buffer.GetData(name), weighted_buffer.GetData(name), - weights_sqrt); + prepared_weights); } else { weighted_buffer.MoveData(unweighted_buffer, name, name); DPBuffer::DataType& direction_buffer = weighted_buffer.GetData(name); - Weigh(direction_buffer, direction_buffer, weights_sqrt); + Weigh(direction_buffer, direction_buffer, prepared_weights); } xt::masked_view(weighted_buffer.GetData(name), flags_view) = kZeroVisibility; diff --git a/ddecal/gain_solvers/SolverTools.h b/ddecal/gain_solvers/SolverTools.h index fbe0399ada3489f4ad8ae62855e4d73c367901b4..fffd443246634e5fb2f4851ea193252253290f8b 100644 --- a/ddecal/gain_solvers/SolverTools.h +++ b/ddecal/gain_solvers/SolverTools.h @@ -45,12 +45,18 @@ void Weigh(const base::DPBuffer::DataType& in, base::DPBuffer::DataType& out, * unweighted_buffers and create new model data buffers in weighted_buffers. 
* If false, avoid creating new model data buffers by moving the model data * from unweighted_buffers to weighted_buffers and weighing the data in-place. + * @param linear_weighting_mode If true, it has two effects: the data + * will be linearly weighted instead of weighting them by the sqrt of the + * weights, and the weights are stored in the buffer. Gradient descent and + * conjugate gradient-like methods require sqrt weighted data but do not need + * the weights, whereas a rank-based approach requires linear weighted data and + * needs to know the applied weights. */ void AssignAndWeight( std::vector<std::unique_ptr<base::DPBuffer>>& unweighted_buffers, const std::vector<std::string>& direction_names, std::vector<base::DPBuffer>& weighted_buffers, - bool keep_unweighted_model_data); + bool keep_unweighted_model_data, bool linear_weighting_mode); } // namespace dp3::ddecal diff --git a/ddecal/test/unit/SolverTester.cc b/ddecal/test/unit/SolverTester.cc index 6b2164694009d3898242c50a7b70957ee70476b8..e4754845dc5e45bd8af4ae8f666df47a6e1c89cf 100644 --- a/ddecal/test/unit/SolverTester.cc +++ b/ddecal/test/unit/SolverTester.cc @@ -180,7 +180,7 @@ std::vector<dp3::base::DPBuffer> SolverTester::FillDdIntervalData() { } dp3::ddecal::AssignAndWeight(unweighted_buffers, direction_names, - weighted_buffers, false); + weighted_buffers, false, false); return weighted_buffers; } diff --git a/ddecal/test/unit/tSolveData.cc b/ddecal/test/unit/tSolveData.cc index 2f3e2af94daeb56b4cf0fd8efe16537a0989edcc..25b4d3b4b212b499103dce4b08a1d83816473bbe 100644 --- a/ddecal/test/unit/tSolveData.cc +++ b/ddecal/test/unit/tSolveData.cc @@ -109,7 +109,7 @@ BOOST_AUTO_TEST_CASE(regular) { } dp3::ddecal::AssignAndWeight(unweighted_buffers, {kDirectionName}, - weighted_buffers, false); + weighted_buffers, false, false); const dp3::ddecal::SolveData data( weighted_buffers, {kDirectionName}, kNChannelBlocks, kNAntennas, @@ -197,7 +197,7 @@ BOOST_AUTO_TEST_CASE(regular_with_dd_intervals) { }
dp3::ddecal::AssignAndWeight(unweighted_buffers, kDirectionNames, - weighted_buffers, false); + weighted_buffers, false, false); const dp3::ddecal::SolveData data( weighted_buffers, kDirectionNames, kNChannelBlocks, kNAntennas, @@ -273,7 +273,7 @@ BOOST_AUTO_TEST_CASE(bda) { const dp3::ddecal::SolveData solve_data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, kAntennas1, - kAntennas2); + kAntennas2, true); BOOST_TEST_REQUIRE(solve_data.NChannelBlocks() == kNChannelBlocks); for (size_t ch_block = 0; ch_block < kNChannelBlocks; ++ch_block) { diff --git a/ddecal/test/unit/tSolverTools.cc b/ddecal/test/unit/tSolverTools.cc index 2117b16038cd39540ec4fa8a7241aa8ba8e4f689..a8edcde3c81d32e7194240c154eb5be8812d55d9 100644 --- a/ddecal/test/unit/tSolverTools.cc +++ b/ddecal/test/unit/tSolverTools.cc @@ -77,7 +77,8 @@ class AssignAndWeightFixture { void DoAssignAndWeight(const bool keep_original_model_data) { dp3::ddecal::AssignAndWeight(unweighted_buffers_, kDirectionNames, - weighted_buffers_, keep_original_model_data); + weighted_buffers_, keep_original_model_data, + false); } private: diff --git a/ddecal/test/unit/tSolvers.cc b/ddecal/test/unit/tSolvers.cc index 6d4321441123db357b397a42d7abd5a2848dfb85..48b82641ef9234eb9d95aa0c936fd504781c8138 100644 --- a/ddecal/test/unit/tSolvers.cc +++ b/ddecal/test/unit/tSolvers.cc @@ -41,7 +41,7 @@ BOOST_FIXTURE_TEST_CASE(diagonal, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -63,7 +63,7 @@ BOOST_FIXTURE_TEST_CASE(scalar, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); 
dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -85,7 +85,7 @@ BOOST_FIXTURE_TEST_CASE(iterative_scalar, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -135,7 +135,7 @@ BOOST_FIXTURE_TEST_CASE(hybrid, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -157,7 +157,8 @@ inline void TestIterativeDiagonal(SolverTester& solver_tester, solver_tester.FillBDAData(); const SolveData data(solver_buffer, SolverTester::kNChannelBlocks, SolverTester::kNDirections, SolverTester::kNAntennas, - solver_tester.Antennas1(), solver_tester.Antennas2()); + solver_tester.Antennas1(), solver_tester.Antennas2(), + false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, solver_tester.GetSolverSolutions(), 0.0, nullptr); @@ -218,7 +219,7 @@ BOOST_FIXTURE_TEST_CASE(full_jones, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); // The full jones test uses full matrices as solutions and copies the // diagonals into the solver solutions from the SolverTester fixture. 
This @@ -265,7 +266,7 @@ BOOST_FIXTURE_TEST_CASE(iterative_full_jones, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); dp3::ddecal::SolveData data(solver_buffer, kNChannelBlocks, kNDirections, - kNAntennas, Antennas1(), Antennas2()); + kNAntennas, Antennas1(), Antennas2(), false); // The full jones test uses full matrices as solutions and copies the // diagonals into the solver solutions from the SolverTester fixture. This @@ -354,7 +355,7 @@ BOOST_FIXTURE_TEST_CASE(scalar_normaleq, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -374,7 +375,7 @@ BOOST_FIXTURE_TEST_CASE(min_iterations, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -396,7 +397,7 @@ BOOST_FIXTURE_TEST_CASE(lbfgs_diagonal, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, nullptr); @@ -419,7 +420,7 @@ BOOST_FIXTURE_TEST_CASE(lbfgs_scalar, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); dp3::ddecal::SolverBase::SolveResult result = solver.Solve(data, GetSolverSolutions(), 0.0, 
nullptr); @@ -444,7 +445,7 @@ BOOST_FIXTURE_TEST_CASE(lbfgs_full_jones, SolverTester, const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData(); const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas, - Antennas1(), Antennas2()); + Antennas1(), Antennas2(), false); // The full jones test uses full matrices as solutions and copies the // diagonals into the solver solutions from the SolverTester fixture. This diff --git a/steps/BdaDdeCal.cc b/steps/BdaDdeCal.cc index 43e5ce6ff807ecd8b5e323c1c8303af56dd7ef87..91e264ea20f0844d2725318da282f04deb132e9c 100644 --- a/steps/BdaDdeCal.cc +++ b/steps/BdaDdeCal.cc @@ -397,7 +397,7 @@ void BdaDdeCal::SolveCurrentInterval() { dp3::ddecal::SolveData data(*solver_buffer_, n_channel_blocks, patches_.size(), n_antennas, antennas1_, - antennas2_); + antennas2_, false); const int current_interval = solutions_.size(); assert(current_interval == solver_buffer_->GetCurrentInterval()); diff --git a/steps/DDECal.cc b/steps/DDECal.cc index 8125d46434579a490af7bea7ba9ab29d64744d30..b339f8d4b9d98f4f9d336990009c937ba7b11c73 100644 --- a/steps/DDECal.cc +++ b/steps/DDECal.cc @@ -409,7 +409,7 @@ void DDECal::show(std::ostream& os) const { os << "Model steps for direction " << itsDirections[i][0] << '\n'; do { step->show(os); - } while (step = step->getNextStep()); + } while (nullptr != (step = step->getNextStep())); } else { os << "Direction " << itsDirections[i][0] << " reuses data from " << itsDirectionNames[i] << ""; @@ -588,8 +588,9 @@ void DDECal::doSolve() { // The last solution interval can be smaller. std::vector<base::DPBuffer> weighted_buffers(itsInputBuffers[i].size()); + const bool linear_mode = false; ddecal::AssignAndWeight(itsInputBuffers[i], itsDirectionNames, - weighted_buffers, keep_model_data); + weighted_buffers, keep_model_data, linear_mode); InitializeSolutions(i);