From 7b1a8341cdf44d5fbb6a37fcf9dbd7a757a23239 Mon Sep 17 00:00:00 2001
From: Bram Veenboer <veenboer@astron.nl>
Date: Fri, 11 Aug 2023 14:33:53 +0000
Subject: [PATCH] PADRE-5: Add GPU implementation of DDECal
 IterativeDiagonalSolver

---
 .gitmodules                                   |   1 -
 CMakeLists.txt                                |  64 ++-
 ddecal/Settings.cc                            |   1 +
 ddecal/Settings.h                             |   1 +
 ddecal/SolverFactory.cc                       |  20 +
 .../IterativeDiagonalSolverCuda.cc            | 458 ++++++++++++++++++
 .../IterativeDiagonalSolverCuda.h             | 135 ++++++
 ddecal/gain_solvers/SolveData.h               |   2 +
 ddecal/gain_solvers/SolverBase.h              |   7 +
 ddecal/gain_solvers/kernels/Common.h          |  17 +
 ddecal/gain_solvers/kernels/Complex.h         |  63 +++
 .../gain_solvers/kernels/IterativeDiagonal.cu | 292 +++++++++++
 .../gain_solvers/kernels/IterativeDiagonal.h  |  39 ++
 .../gain_solvers/kernels/MatrixComplex2x2.h   | 116 +++++
 ddecal/test/unit/tSolvers.cc                  |  40 +-
 docs/schemas/DDECal.yml                       |   8 +-
 scripts/run-format.sh                         |   3 +
 17 files changed, 1253 insertions(+), 14 deletions(-)
 create mode 100644 ddecal/gain_solvers/IterativeDiagonalSolverCuda.cc
 create mode 100644 ddecal/gain_solvers/IterativeDiagonalSolverCuda.h
 create mode 100644 ddecal/gain_solvers/kernels/Common.h
 create mode 100644 ddecal/gain_solvers/kernels/Complex.h
 create mode 100644 ddecal/gain_solvers/kernels/IterativeDiagonal.cu
 create mode 100644 ddecal/gain_solvers/kernels/IterativeDiagonal.h
 create mode 100644 ddecal/gain_solvers/kernels/MatrixComplex2x2.h

diff --git a/.gitmodules b/.gitmodules
index fbd993e15..2d1be4e94 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,4 +10,3 @@
 [submodule "external/schaap-packaging"]
 	path = external/schaap-packaging
 	url = https://git.astron.nl/RD/schaap-packaging.git
-	
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d371ee283..169363fae 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,7 +21,32 @@ else()
   message(FATAL_ERROR "Failed to parse DP3_VERSION='${DP3_VERSION}'")
 endif()
 
-project(DP3 VERSION ${DP3_VERSION})
+option(BUILD_WITH_CUDA "Build with CUDA support" FALSE)
+if(BUILD_WITH_CUDA)
+  # Set before project(): enabling CUDA would cache the compiler default first.
+  set(CMAKE_CUDA_ARCHITECTURES
+      "70"
+      CACHE STRING "Specify GPU architecture(s) to compile for")
+  project(
+    DP3
+    VERSION ${DP3_VERSION}
+    LANGUAGES CUDA C CXX)
+  set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
+  add_compile_definitions(HAVE_CUDA)
+  find_package(CUDAToolkit REQUIRED)
+
+  # SolverFactory.cc includes the CUDA solver header, which requires cuda.h.
+  include_directories(${CUDAToolkit_INCLUDE_DIRS})
+
+  include(FetchContent)
+  FetchContent_Declare(
+    cudawrappers
+    GIT_REPOSITORY https://github.com/nlesc-recruit/cudawrappers.git
+    GIT_TAG main)
+  FetchContent_MakeAvailable(cudawrappers)
+else()
+  project(DP3 VERSION ${DP3_VERSION})
+endif()
 
 include(CheckCXXCompilerFlag)
 include(FetchContent)
@@ -391,6 +416,36 @@ add_library(
   ${DDE_ARMADILLO_FILES})
 target_link_libraries(DDECal xsimd xtensor)
 
+if(BUILD_WITH_CUDA)
+  target_link_libraries(DDECal cudawrappers::cu)
+
+  add_library(
+    CudaSolvers SHARED ddecal/gain_solvers/IterativeDiagonalSolverCuda.cc
+                       ddecal/gain_solvers/kernels/IterativeDiagonal.cu)
+
+  # Only request flags that CMake does not already derive from the target:
+  # POSITION_INDEPENDENT_CODE supplies -fPIC, CUDA_SEPARABLE_COMPILATION
+  # supplies relocatable device code (-dc), and -shared is a link-time flag
+  # that SHARED libraries get automatically. The generator expression is
+  # quoted so that it reaches CMake as a single argument.
+  target_compile_options(
+    CudaSolvers PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--use_fast_math>")
+
+  target_link_libraries(CudaSolvers PUBLIC cudawrappers::cu CUDA::nvToolsExt
+                                           CUDA::cudart_static xsimd xtensor)
+  # CUDA_ARCHITECTURES is quoted: a multi-architecture list such as "70;80"
+  # would otherwise expand into several arguments and break the pairing.
+  set_target_properties(
+    CudaSolvers
+    PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}"
+               CUDA_RESOLVE_DEVICE_SYMBOLS ON
+               CUDA_SEPARABLE_COMPILATION ON
+               POSITION_INDEPENDENT_CODE ON
+               RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/build")
+
+  install(TARGETS CudaSolvers)
+endif()
+
 # 'Base' files and steps are only included in DP3.
 add_library(
   DP3_OBJ OBJECT
@@ -539,6 +594,10 @@ set(DP3_LIBRARIES
     Threads::Threads
     pybind11::embed)
 
+if(BUILD_WITH_CUDA)
+  list(APPEND DP3_LIBRARIES CudaSolvers)
+endif()
+
 # If libdirac is found, use it
 if(LIBDIRAC_FOUND)
   if(HAVE_CUDA)
@@ -778,6 +837,9 @@ if(BUILD_TESTING)
 
   add_executable(unittests ${TEST_FILENAMES})
   target_include_directories(unittests PRIVATE "${CMAKE_SOURCE_DIR}")
+  if(BUILD_WITH_CUDA)
+    set_target_properties(unittests PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+  endif()
   target_link_libraries(unittests LIBDP3 ${DP3_LIBRARIES} xtensor)
   add_dependencies(unittests schaapcommon)
 
diff --git a/ddecal/Settings.cc b/ddecal/Settings.cc
index a7c6310be..808226e64 100644
--- a/ddecal/Settings.cc
+++ b/ddecal/Settings.cc
@@ -119,6 +119,7 @@ Settings::Settings(const common::ParameterSet& _parset,
       lbfgs_minibatches((solver_algorithm == SolverAlgorithm::kLBFGS)
                             ? GetUint("solverlbfgs.minibatches", 1)
                             : 1),
+      use_gpu(GetBool("usegpu", false)),
 
       // Column reader settings
       model_data_columns(ReadModelDataColumns()),
diff --git a/ddecal/Settings.h b/ddecal/Settings.h
index ab618d412..b4253cbdf 100644
--- a/ddecal/Settings.h
+++ b/ddecal/Settings.h
@@ -135,6 +135,7 @@ struct Settings {
   const size_t lbfgs_history_size;
   // LBFGS minibatches
   const size_t lbfgs_minibatches;
+  const bool use_gpu;  // Request a GPU solver (parset key "usegpu")
 
   const std::vector<std::string> model_data_columns;
   const std::vector<std::string> reuse_model_data;
diff --git a/ddecal/SolverFactory.cc b/ddecal/SolverFactory.cc
index 9e7560e5f..67b01707b 100644
--- a/ddecal/SolverFactory.cc
+++ b/ddecal/SolverFactory.cc
@@ -12,6 +12,9 @@
 #include "gain_solvers/LBFGSSolver.h"
 #include "gain_solvers/HybridSolver.h"
 #include "gain_solvers/IterativeDiagonalSolver.h"
+#if defined(HAVE_CUDA)
+#include "gain_solvers/IterativeDiagonalSolverCuda.h"
+#endif
 #include "gain_solvers/IterativeFullJonesSolver.h"
 #include "gain_solvers/IterativeScalarSolver.h"
 #include "gain_solvers/ScalarSolver.h"
@@ -56,6 +59,23 @@ std::unique_ptr<SolverBase> CreateScalarSolver(SolverAlgorithm algorithm,
 
 std::unique_ptr<SolverBase> CreateDiagonalSolver(SolverAlgorithm algorithm,
                                                  const Settings& settings) {
+  // A GPU solver was requested: either create it or fail, there is no silent
+  // fallback to the CPU implementation.
+  if (settings.use_gpu) {
+#if defined(HAVE_CUDA)
+    // Only the direction-iterative algorithm has a GPU implementation.
+    if (algorithm == SolverAlgorithm::kDirectionIterative) {
+      return std::make_unique<IterativeDiagonalSolverCuda>();
+    }
+    throw std::runtime_error(
+        "usegpu=true, but no GPU implementation for solver algorithm is "
+        "available.");
+#else
+    // This build cannot provide any GPU solver.
+    throw std::runtime_error(
+        "usegpu=true, but DP3 is built without CUDA support.");
+#endif
+  }
   switch (algorithm) {
     case SolverAlgorithm::kDirectionIterative:
       return std::make_unique<IterativeDiagonalSolver>();
diff --git a/ddecal/gain_solvers/IterativeDiagonalSolverCuda.cc b/ddecal/gain_solvers/IterativeDiagonalSolverCuda.cc
new file mode 100644
index 000000000..5922e7d78
--- /dev/null
+++ b/ddecal/gain_solvers/IterativeDiagonalSolverCuda.cc
@@ -0,0 +1,458 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#include "IterativeDiagonalSolverCuda.h"
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+
+#include <cuda_runtime.h>
+#include <nvToolsExt.h>
+
+#include <aocommon/matrix2x2.h>
+#include <aocommon/matrix2x2diag.h>
+
+#include "kernels/IterativeDiagonal.h"
+
+using aocommon::MC2x2F;
+using aocommon::MC2x2FDiag;
+
+namespace {
+// Size in bytes of the model data: an MC2x2F per direction and visibility.
+size_t SizeOfModel(size_t n_directions, size_t n_visibilities) {
+  return sizeof(MC2x2F) * n_directions * n_visibilities;
+}
+// Size in bytes of the residual data: an MC2x2F per visibility.
+size_t SizeOfResidual(size_t n_visibilities) {
+  return sizeof(MC2x2F) * n_visibilities;
+}
+// Size in bytes of 'n_visibilities' complex double precision solution values.
+size_t SizeOfSolutions(size_t n_visibilities) {
+  return sizeof(std::complex<double>) * n_visibilities;
+}
+// Size in bytes of the antenna pairs: two uint32_t indices per visibility.
+size_t SizeOfAntennaPairs(size_t n_visibilities) {
+  return sizeof(uint32_t) * 2 * n_visibilities;
+}
+// Size in bytes of the solution map: a uint32_t per direction and visibility.
+size_t SizeOfSolutionMap(size_t n_directions, size_t n_visibilities) {
+  return sizeof(uint32_t) * n_directions * n_visibilities;
+}
+// Size in bytes of the next solutions; the same formula as SizeOfSolutions().
+size_t SizeOfNextSolutions(size_t n_visibilities) {
+  return sizeof(std::complex<double>) * n_visibilities;
+}
+// Size in bytes of the numerator: an MC2x2FDiag per antenna and solution.
+size_t SizeOfNumerator(size_t n_antennas, size_t n_direction_solutions) {
+  return sizeof(MC2x2FDiag) * n_antennas * n_direction_solutions;
+}
+// Size in bytes of the denominator: two floats per antenna and solution.
+size_t SizeOfDenominator(size_t n_antennas, size_t n_direction_solutions) {
+  return sizeof(float) * 2 * n_antennas * n_direction_solutions;
+}
+
+void SolveDirection(
+    const dp3::ddecal::SolveData::ChannelBlockData& channel_block_data,
+    cu::Stream& stream, size_t n_antennas, size_t n_solutions, size_t direction,
+    cu::DeviceMemory& device_residual_in,
+    cu::DeviceMemory& device_residual_temp,
+    cu::DeviceMemory& device_solution_map, cu::DeviceMemory& device_solutions,
+    cu::DeviceMemory& device_model, cu::DeviceMemory& device_next_solutions,
+    cu::DeviceMemory& device_antenna_pairs, cu::DeviceMemory& device_numerator,
+    cu::DeviceMemory& device_denominator) {
+  // Calculate this equation, given ant a:
+  //
+  //          sum_b data_ab * solutions_b * model_ab^*
+  // sol_a =  ----------------------------------------
+  //             sum_b norm(model_ab * solutions_b)
+  const size_t n_direction_solutions =
+      channel_block_data.NSolutionsForDirection(direction);
+  const size_t n_visibilities = channel_block_data.NVisibilities();
+  const size_t numerator_size =
+      SizeOfNumerator(n_antennas, n_direction_solutions);
+  const size_t denominator_size =
+      SizeOfDenominator(n_antennas, n_direction_solutions);
+
+  // Clear the accumulators and make a scratch copy of the residual.
+  device_numerator.zero(numerator_size, stream);
+  device_denominator.zero(denominator_size, stream);
+  stream.memcpyDtoDAsync(device_residual_temp, device_residual_in,
+                         SizeOfResidual(n_visibilities));
+
+  LaunchSolveDirectionKernel(
+      stream, n_visibilities, n_direction_solutions, n_solutions, direction,
+      device_antenna_pairs, device_solution_map, device_solutions, device_model,
+      device_residual_in, device_residual_temp, device_numerator,
+      device_denominator);
+  LaunchSolveNextSolutionKernel(
+      stream, n_antennas, n_visibilities, n_direction_solutions, n_solutions,
+      direction, device_antenna_pairs, device_solution_map,
+      device_next_solutions, device_numerator, device_denominator);
+}
+
+void PerformIteration(
+    bool phase_only, double step_size,
+    const dp3::ddecal::SolveData::ChannelBlockData& channel_block_data,
+    cu::Stream& stream, size_t n_antennas, size_t n_solutions,
+    size_t n_directions, cu::DeviceMemory& device_solution_map,
+    cu::DeviceMemory& device_solutions, cu::DeviceMemory& device_next_solutions,
+    cu::DeviceMemory& device_residual, cu::DeviceMemory& device_residual_temp,
+    cu::DeviceMemory& device_model, cu::DeviceMemory& device_antenna_pairs,
+    cu::DeviceMemory& device_numerator, cu::DeviceMemory& device_denominator) {
+  const size_t n_visibilities = channel_block_data.NVisibilities();
+
+  // Step 1: subtract all directions, using their current solutions.
+  // This works in-place: residual -> residual
+  LaunchSubtractKernel(stream, n_directions, n_visibilities, n_solutions,
+                       device_antenna_pairs, device_solution_map,
+                       device_solutions, device_model, device_residual);
+
+  // Step 2: solve the directions one by one. The subtraction above purposely
+  // used the 'old' solutions, because the new solutions have not been
+  // constrained yet. Each direction is added back before it is solved, which
+  // works out-of-place: residual -> residual_temp
+  for (size_t direction = 0; direction < n_directions; ++direction) {
+    SolveDirection(channel_block_data, stream, n_antennas, n_solutions,
+                   direction, device_residual, device_residual_temp,
+                   device_solution_map, device_solutions, device_model,
+                   device_next_solutions, device_antenna_pairs,
+                   device_numerator, device_denominator);
+  }
+
+  // Step 3: launch the step kernel, which gets phase_only and step_size.
+  LaunchStepKernel(stream, n_visibilities, device_solutions,
+                   device_next_solutions, phase_only, step_size);
+}
+
+std::tuple<size_t, size_t, size_t> ComputeArrayDimensions(
+    const dp3::ddecal::SolveData& data) {
+  // Running maxima over all channel blocks (and over all directions therein).
+  size_t largest_n_direction_solutions = 0;
+  size_t largest_n_visibilities = 0;
+  size_t largest_n_directions = 0;
+
+  for (size_t block = 0; block != data.NChannelBlocks(); ++block) {
+    const dp3::ddecal::SolveData::ChannelBlockData& block_data =
+        data.ChannelBlock(block);
+    const size_t n_visibilities = block_data.NVisibilities();
+    const size_t n_directions = block_data.NDirections();
+    largest_n_visibilities = std::max(largest_n_visibilities, n_visibilities);
+    largest_n_directions = std::max(largest_n_directions, n_directions);
+    for (size_t direction = 0; direction != n_directions; ++direction) {
+      const size_t n_direction_solutions =
+          block_data.NSolutionsForDirection(direction);
+      largest_n_direction_solutions =
+          std::max(largest_n_direction_solutions, n_direction_solutions);
+    }
+  }
+
+  return std::make_tuple(largest_n_direction_solutions, largest_n_visibilities,
+                         largest_n_directions);
+}
+}  // namespace
+
+namespace dp3 {
+namespace ddecal {
+
+IterativeDiagonalSolverCuda::IterativeDiagonalSolverCuda() {
+  cu::init();
+  device_ = std::make_unique<cu::Device>(0);  // Presumably device ordinal 0
+  context_ = std::make_unique<cu::Context>(0, *device_);
+  context_->setCurrent();
+  execute_stream_ = std::make_unique<cu::Stream>();  // Kernels, DtoD copies
+  host_to_device_stream_ = std::make_unique<cu::Stream>();  // Input copies
+  device_to_host_stream_ = std::make_unique<cu::Stream>();  // Output copies
+}
+
+void IterativeDiagonalSolverCuda::AllocateGPUBuffers(const SolveData& data) {
+  // All GPU buffers are sized for the largest channel block in 'data'.
+  size_t n_direction_solutions = 0;
+  size_t n_visibilities = 0;
+  size_t n_directions = 0;
+  std::tie(n_direction_solutions, n_visibilities, n_directions) =
+      ComputeArrayDimensions(data);
+
+  // The numerator and the denominator are allocated only once.
+  gpu_buffers_.numerator = std::make_unique<cu::DeviceMemory>(
+      SizeOfNumerator(NAntennas(), n_direction_solutions));
+  gpu_buffers_.denominator = std::make_unique<cu::DeviceMemory>(
+      SizeOfDenominator(NAntennas(), n_direction_solutions));
+  // The buffers below are allocated twice, which allows double buffering.
+  for (size_t buffer = 0; buffer != 2; ++buffer) {
+    gpu_buffers_.antenna_pairs.emplace_back(SizeOfAntennaPairs(n_visibilities));
+    gpu_buffers_.solution_map.emplace_back(
+        SizeOfSolutionMap(n_directions, n_visibilities));
+    gpu_buffers_.solutions.emplace_back(SizeOfSolutions(n_visibilities));
+    gpu_buffers_.next_solutions.emplace_back(
+        SizeOfNextSolutions(n_visibilities));
+    gpu_buffers_.model.emplace_back(SizeOfModel(n_directions, n_visibilities));
+  }
+
+  // The residual gets a third buffer, next to the two for double buffering:
+  // it is used for the per-direction add/subtract.
+  for (size_t buffer = 0; buffer != 3; ++buffer) {
+    gpu_buffers_.residual.emplace_back(SizeOfResidual(n_visibilities));
+  }
+}
+
+void IterativeDiagonalSolverCuda::AllocateHostBuffers(const SolveData& data) {
+  // A single buffer receives the next solutions of all channel blocks.
+  host_buffers_.next_solutions =
+      std::make_unique<cu::HostMemory>(SizeOfNextSolutions(NVisibilities()));
+  // All other buffers are allocated per channel block, using its own sizes.
+  for (size_t ch_block = 0; ch_block != NChannelBlocks(); ++ch_block) {
+    const SolveData::ChannelBlockData& block = data.ChannelBlock(ch_block);
+    const size_t n_directions = block.NDirections();
+    const size_t n_visibilities = block.NVisibilities();
+    host_buffers_.model.emplace_back(SizeOfModel(n_directions, n_visibilities));
+    host_buffers_.residual.emplace_back(SizeOfResidual(n_visibilities));
+    host_buffers_.solutions.emplace_back(SizeOfSolutions(n_visibilities));
+    host_buffers_.antenna_pairs.emplace_back(
+        SizeOfAntennaPairs(n_visibilities));
+    host_buffers_.solution_map.emplace_back(
+        SizeOfSolutionMap(n_directions, n_visibilities));
+    // The antenna pairs are filled in once, here, as interleaved
+    // (antenna1, antenna2) indices.
+    uint32_t* pair =
+        static_cast<uint32_t*>(host_buffers_.antenna_pairs[ch_block]);
+    for (size_t visibility = 0; visibility != n_visibilities; ++visibility) {
+      *pair++ = block.Antenna1Index(visibility);
+      *pair++ = block.Antenna2Index(visibility);
+    }
+  }
+}
+
+void IterativeDiagonalSolverCuda::CopyHostToHost(
+    size_t ch_block, bool first_iteration, const SolveData& data,
+    const std::vector<std::complex<double>>& solutions, cu::Stream& stream) {
+  const SolveData::ChannelBlockData& channel_block_data =
+      data.ChannelBlock(ch_block);
+  const size_t n_directions = channel_block_data.NDirections();
+  const size_t n_visibilities = channel_block_data.NVisibilities();
+  stream.memcpyHtoHAsync(host_buffers_.model[ch_block],
+                         &channel_block_data.ModelVisibility(0, 0),
+                         SizeOfModel(n_directions, n_visibilities));
+  // 'solutions' may hold fewer than n_visibilities values: do not over-read.
+  stream.memcpyHtoHAsync(
+      host_buffers_.solutions[ch_block], solutions.data(),
+      SizeOfSolutions(std::min(solutions.size(), n_visibilities)));
+  if (first_iteration) {
+    stream.memcpyHtoHAsync(host_buffers_.residual[ch_block],
+                           &channel_block_data.Visibility(0),
+                           SizeOfResidual(n_visibilities));
+    stream.memcpyHtoHAsync(host_buffers_.solution_map[ch_block],
+                           channel_block_data.SolutionMapData(),
+                           SizeOfSolutionMap(n_directions, n_visibilities));
+  }
+}
+
+void IterativeDiagonalSolverCuda::CopyHostToDevice(size_t ch_block,
+                                                   size_t buffer_id,
+                                                   cu::Stream& stream,
+                                                   cu::Event& event,
+                                                   const SolveData& data) {
+  const SolveData::ChannelBlockData& block = data.ChannelBlock(ch_block);
+  const size_t n_directions = block.NDirections();
+  const size_t n_visibilities = block.NVisibilities();
+
+  // Every copy goes from the host buffer of this channel block to the GPU
+  // buffer that is selected by buffer_id. All copies are enqueued on
+  // 'stream', using the sizes of this channel block.
+  stream.memcpyHtoDAsync(gpu_buffers_.solution_map[buffer_id],
+                         host_buffers_.solution_map[ch_block],
+                         SizeOfSolutionMap(n_directions, n_visibilities));
+
+  stream.memcpyHtoDAsync(gpu_buffers_.model[buffer_id],
+                         host_buffers_.model[ch_block],
+                         SizeOfModel(n_directions, n_visibilities));
+
+  stream.memcpyHtoDAsync(gpu_buffers_.residual[buffer_id],
+                         host_buffers_.residual[ch_block],
+                         SizeOfResidual(n_visibilities));
+
+  stream.memcpyHtoDAsync(gpu_buffers_.antenna_pairs[buffer_id],
+                         host_buffers_.antenna_pairs[ch_block],
+                         SizeOfAntennaPairs(n_visibilities));
+
+  stream.memcpyHtoDAsync(gpu_buffers_.solutions[buffer_id],
+                         host_buffers_.solutions[ch_block],
+                         SizeOfSolutions(n_visibilities));
+
+  // Record the event after the last copy, so that other streams can wait
+  // until all input for this channel block is on the GPU.
+  stream.record(event);
+}
+
+void IterativeDiagonalSolverCuda::PostProcessing(
+    size_t& iteration, double time, bool has_previously_converged,
+    bool& has_converged, bool& constraints_satisfied, bool& done,
+    SolverBase::SolveResult& result,
+    std::vector<std::vector<std::complex<double>>>& solutions,
+    SolutionSpan& next_solutions, std::vector<double>& step_magnitudes,
+    std::ostream* stat_stream) {
+  constraints_satisfied =
+      ApplyConstraints(iteration, time, has_previously_converged, result,
+                       next_solutions, stat_stream);
+  double avg_squared_diff;
+  has_converged =
+      AssignSolutions(solutions, next_solutions, !constraints_satisfied,
+                      avg_squared_diff, step_magnitudes);
+  iteration++;
+
+  // NOTE(review): has_previously_converged is a by-value parameter, so this
+  // assignment only changes the local copy and never reaches the caller.
+  has_previously_converged = has_converged || has_previously_converged;
+
+  done = ReachedStoppingCriterion(iteration, has_converged,
+                                  constraints_satisfied, step_magnitudes);
+}
+
+IterativeDiagonalSolver::SolveResult IterativeDiagonalSolverCuda::Solve(
+    const SolveData& data,
+    std::vector<std::vector<std::complex<double>>>& solutions, double time,
+    std::ostream* stat_stream) {
+  PrepareConstraints();
+
+  const bool phase_only = GetPhaseOnly();
+  const double step_size = GetStepSize();
+
+  SolveResult result;
+
+  // The host and GPU buffers are allocated once, by the first call.
+  if (!buffers_initialized_) {
+    AllocateHostBuffers(data);
+    AllocateGPUBuffers(data);
+    buffers_initialized_ = true;
+  }
+
+  const std::array<size_t, 4> next_solutions_shape = {
+      NChannelBlocks(), NAntennas(), NSolutions(), NSolutionPolarizations()};
+  std::complex<double>* next_solutions_ptr = *(host_buffers_.next_solutions);
+  SolutionSpan next_solutions =
+      aocommon::xt::CreateSpan(next_solutions_ptr, next_solutions_shape);
+
+  // Allocate events for each channel block
+  std::vector<cu::Event> input_copied_events(NChannelBlocks());
+  std::vector<cu::Event> compute_finished_events(NChannelBlocks());
+  std::vector<cu::Event> output_copied_events(NChannelBlocks());
+
+  // Start iterating. Every iteration handles all channel blocks on the GPU
+  // and then applies the constraints and convergence checks on the CPU.
+  size_t iteration = 0;
+  bool has_converged = false;
+  bool has_previously_converged = false;
+  bool constraints_satisfied = false;
+  bool done = false;
+
+  std::vector<double> step_magnitudes;
+  step_magnitudes.reserve(GetMaxIterations());
+
+  do {
+    MakeSolutionsFinite2Pol(solutions);
+
+    nvtxRangeId_t nvtx_range_gpu = nvtxRangeStart("GPU");
+
+    for (size_t ch_block = 0; ch_block < NChannelBlocks(); ch_block++) {
+      const SolveData::ChannelBlockData& channel_block_data =
+          data.ChannelBlock(ch_block);
+      const int buffer_id = ch_block % 2;
+
+      // Copy input data for first channel block
+      if (ch_block == 0) {
+        CopyHostToHost(ch_block, iteration == 0, data, solutions[ch_block],
+                       *host_to_device_stream_);
+
+        CopyHostToDevice(ch_block, buffer_id, *host_to_device_stream_,
+                         input_copied_events[0], data);
+      }
+
+      // As soon as input_copied_events[0] is triggered, the input data is
+      // copied to the GPU and the host buffers could theoretically be reused.
+      // However, since the size of these buffers may differ, every channel
+      // block has its own set of host buffers anyway.
+      // Before starting kernel execution for the current channel block (on a
+      // different stream), the copy of data for the next channel block (if any)
+      // is scheduled using a second set of GPU buffers.
+      if (ch_block < NChannelBlocks() - 1) {
+        CopyHostToHost(ch_block + 1, iteration == 0, data,
+                       solutions[ch_block + 1], *host_to_device_stream_);
+
+        // Channel block <ch_block + 1> uses the same set of GPU buffers as
+        // channel block <ch_block - 1>. Wait for the compute_finished event
+        // of that block before overwriting their contents: its kernels may
+        // still be reading them.
+        if (ch_block > 0) {
+          host_to_device_stream_->wait(compute_finished_events[ch_block - 1]);
+        }
+
+        CopyHostToDevice(ch_block + 1, (ch_block + 1) % 2,
+                         *host_to_device_stream_,
+                         input_copied_events[ch_block + 1], data);
+      }
+
+      // Wait for input of the current channel block to be copied
+      execute_stream_->wait(input_copied_events[ch_block]);
+
+      // Wait for output buffer to be free
+      if (ch_block > 1) {
+        execute_stream_->wait(output_copied_events[ch_block - 2]);
+      }
+
+      // Start iteration (dtod copies and kernel execution only)
+      PerformIteration(phase_only, step_size, channel_block_data,
+                       *execute_stream_, NAntennas(), NSolutions(),
+                       NDirections(), gpu_buffers_.solution_map[buffer_id],
+                       gpu_buffers_.solutions[buffer_id],
+                       gpu_buffers_.next_solutions[buffer_id],
+                       gpu_buffers_.residual[buffer_id],
+                       gpu_buffers_.residual[2], gpu_buffers_.model[buffer_id],
+                       gpu_buffers_.antenna_pairs[buffer_id],
+                       *gpu_buffers_.numerator, *gpu_buffers_.denominator);
+
+      execute_stream_->record(compute_finished_events[ch_block]);
+
+      // Wait for the computation to finish
+      device_to_host_stream_->wait(compute_finished_events[ch_block]);
+
+      // Copy the next solutions of this channel block back to the host. The
+      // number of values follows from the shape of the solution span.
+      const size_t n_solution_values = next_solutions.shape(1) *
+                                       next_solutions.shape(2) *
+                                       next_solutions.shape(3);
+      device_to_host_stream_->memcpyDtoHAsync(
+          &next_solutions(ch_block, 0, 0, 0),
+          gpu_buffers_.next_solutions[buffer_id],
+          SizeOfNextSolutions(n_solution_values));
+
+      // Record that the output is copied
+      device_to_host_stream_->record(output_copied_events[ch_block]);
+    }  // end for ch_block
+
+    // Wait for next solutions to be copied
+    device_to_host_stream_->synchronize();
+
+    nvtxRangeEnd(nvtx_range_gpu);
+
+    // CPU-only postprocessing
+    nvtxRangeId_t nvtx_range_cpu = nvtxRangeStart("CPU");
+    PostProcessing(iteration, time, has_previously_converged, has_converged,
+                   constraints_satisfied, done, result, solutions,
+                   next_solutions, step_magnitudes, stat_stream);
+    // PostProcessing() takes has_previously_converged by value, so the flag
+    // must be updated here for the next iteration to see it.
+    has_previously_converged = has_previously_converged || has_converged;
+    nvtxRangeEnd(nvtx_range_cpu);
+  } while (!done);
+
+  // When we have not converged yet, we set the nr of iterations to the max+1,
+  // so that non-converged iterations can be distinguished from converged ones.
+  if (has_converged && constraints_satisfied) {
+    result.iterations = iteration;
+  } else {
+    result.iterations = iteration + 1;
+  }
+  return result;
+}
+
+}  // namespace ddecal
+}  // namespace dp3
diff --git a/ddecal/gain_solvers/IterativeDiagonalSolverCuda.h b/ddecal/gain_solvers/IterativeDiagonalSolverCuda.h
new file mode 100644
index 000000000..48015f534
--- /dev/null
+++ b/ddecal/gain_solvers/IterativeDiagonalSolverCuda.h
@@ -0,0 +1,135 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
+#define DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
+
+#include <vector>
+
+#include <cudawrappers/cu.hpp>
+
+#include "IterativeDiagonalSolver.h"
+#include "SolverBase.h"
+#include "SolveData.h"
+#include "../../common/Timer.h"
+
+namespace dp3 {
+namespace ddecal {
+
+class IterativeDiagonalSolverCuda final : public SolverBase {
+ public:
+  IterativeDiagonalSolverCuda();
+  SolveResult Solve(const SolveData& data,
+                    std::vector<std::vector<DComplex>>& solutions, double time,
+                    std::ostream* stat_stream) override;
+
+  size_t NSolutionPolarizations() const override { return 2; }
+
+  bool SupportsDdSolutionIntervals() const override { return true; }
+
+ private:
+  void AllocateGPUBuffers(const SolveData& data);
+
+  void AllocateHostBuffers(const SolveData& data);
+
+  void CopyHostToHost(size_t ch_block, bool first_iteration,
+                      const SolveData& data,
+                      const std::vector<DComplex>& solutions,
+                      cu::Stream& stream);
+
+  void CopyHostToDevice(size_t ch_block, size_t buffer_id, cu::Stream& stream,
+                        cu::Event& event, const SolveData& data);
+
+  void PostProcessing(size_t& iteration, double time,
+                      bool has_previously_converged, bool& has_converged,
+                      bool& constraints_satisfied, bool& done,
+                      SolverBase::SolveResult& result,
+                      std::vector<std::vector<DComplex>>& solutions,
+                      SolutionSpan& next_solutions,
+                      std::vector<double>& step_magnitudes,
+                      std::ostream* stat_stream);
+
+  /// If this variable is false, gpu_buffers_ and host_buffers_ are not
+  /// initialized
+  bool buffers_initialized_ = false;
+
+  std::unique_ptr<cu::Device> device_;
+  std::unique_ptr<cu::Context> context_;
+  std::unique_ptr<cu::Stream> execute_stream_;
+  std::unique_ptr<cu::Stream> host_to_device_stream_;
+  std::unique_ptr<cu::Stream> device_to_host_stream_;
+
+  /**
+   * GPUBuffers hold the GPU memory used in ::Solve()
+   *
+   * The GPU memory is of type cu::DeviceMemory. This is a wrapper around a
+   * plain CUdeviceptr, provided by the cudawrappers library.
+   *
+   * To facilitate double-buffering, most of the memory is allocated twice (in
+   * ::AllocateGPUBuffers) and stored in a vector. There are three exceptions:
+   *  - residual: three buffers are used
+   *  - numerator: a single buffer is used
+   *  - denominator: a single buffer is used
+   *
+   *
+   * For each element in the struct, the comment
+   * "<x>[a][b], y" denotes:
+   *   - x: the number of elements in the vector (if used)
+   *   - a: length of the first dimension
+   *   - b: length of the second dimension (if any)
+   *   - y: data type
+   *
+   * For instance for antenna_pairs, <2> denotes
+   * that the vector has two elements. [n_visibilities][2] denotes that every
+   * element is a 2D array with dimensions (n_visibilities, 2). Finally,
+   * uint32_t denotes the data type.
+   */
+  struct GPUBuffers {
+    // <2>[n_visibilities][2], uint32_t
+    std::vector<cu::DeviceMemory> antenna_pairs;
+    // <2>[n_directions][n_visibilities], uint32_t
+    std::vector<cu::DeviceMemory> solution_map;
+    // <2>[n_visibilities], DComplex
+    std::vector<cu::DeviceMemory> solutions;
+    // <2>[n_visibilities], DComplex
+    std::vector<cu::DeviceMemory> next_solutions;
+    // <2>[n_directions][n_visibilities], MC2x2F
+    std::vector<cu::DeviceMemory> model;
+    // <3>[n_visibilities], MC2x2F
+    std::vector<cu::DeviceMemory> residual;
+    // [n_antennas][n_direction_solutions], MC2x2FDiag
+    std::unique_ptr<cu::DeviceMemory> numerator;
+    // [n_antennas][n_direction_solutions][2], float
+    std::unique_ptr<cu::DeviceMemory> denominator;
+  } gpu_buffers_;
+
+  /**
+   * HostBuffers hold the host memory used in ::Solve()
+   *
+   * The host memory is of type cu::HostMemory. This is a wrapper around a
+   * plain void*, provided by the cudawrappers library.
+   *
+   * These buffers contain a copy of data that is elsewhere in host memory.
+   * Using an extra host-to-cuda-host-memory copy and then a
+   * cuda-host-memory-to-gpu copy is faster than a direct host-to-gpu copy.
+   */
+  struct HostBuffers {
+    // <n_channelblocks>[n_directions][n_visibilities], MC2x2F
+    std::vector<cu::HostMemory> model;
+    // <n_channelblocks>[n_visibilities], MC2x2F
+    std::vector<cu::HostMemory> residual;
+    // <n_channelblocks>[n_visibilities], DComplex
+    std::vector<cu::HostMemory> solutions;
+    // [n_channelblocks][n_antennas][n_solutions][n_polarizations], DComplex
+    std::unique_ptr<cu::HostMemory> next_solutions;
+    // <n_channelblocks>[n_visibilities], std::pair<uint32_t, uint32_t>
+    std::vector<cu::HostMemory> antenna_pairs;
+    // <n_channelblocks>[n_directions][n_visibilities], uint32_t
+    std::vector<cu::HostMemory> solution_map;
+  } host_buffers_;
+};
+
+}  // namespace ddecal
+}  // namespace dp3
+
+#endif  // DDECAL_GAIN_SOLVERS_ITERATIVE_DIAGONAL_SOLVER_CUDA_H_
diff --git a/ddecal/gain_solvers/SolveData.h b/ddecal/gain_solvers/SolveData.h
index e07d87adf..6b0f068e8 100644
--- a/ddecal/gain_solvers/SolveData.h
+++ b/ddecal/gain_solvers/SolveData.h
@@ -63,6 +63,8 @@ class SolveData {
       return solution_map_(direction_index, visibility_index);
     }
 
+    const uint32_t* SolutionMapData() const { return solution_map_.data(); }
+
     const aocommon::MC2x2F& Visibility(size_t index) const {
       return data_[index];
     }
diff --git a/ddecal/gain_solvers/SolverBase.h b/ddecal/gain_solvers/SolverBase.h
index 44e11b219..06360083f 100644
--- a/ddecal/gain_solvers/SolverBase.h
+++ b/ddecal/gain_solvers/SolverBase.h
@@ -270,6 +270,13 @@ class SolverBase {
    * intervals.
    */
   size_t NSolutions() const { return n_solutions_; }
+  /**
+   * Total number of solution values (not visibilities) in all channel blocks
+   */
+  size_t NVisibilities() const {
+    return NChannelBlocks() * NAntennas() * NSolutions() *
+           NSolutionPolarizations();
+  }
 
   /**
    * Create an LLSSolver with the given matrix dimensions.
diff --git a/ddecal/gain_solvers/kernels/Common.h b/ddecal/gain_solvers/kernels/Common.h
new file mode 100644
index 000000000..e0484a65e
--- /dev/null
+++ b/ddecal/gain_solvers/kernels/Common.h
@@ -0,0 +1,17 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMMON_H_
+#define DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMMON_H_
+
+#include <cudawrappers/cu.hpp>
+
+/// This helper function is needed because the Launch*Kernel functions receive
+/// cu::DeviceMemory references, while the GPU kernels require the actual
+/// pointer type instead.
+template <typename T>
+T* Cast(cu::DeviceMemory& m) {
+  return reinterpret_cast<T*>(static_cast<CUdeviceptr>(m));
+}
+
+#endif  // DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMMON_H_
\ No newline at end of file
diff --git a/ddecal/gain_solvers/kernels/Complex.h b/ddecal/gain_solvers/kernels/Complex.h
new file mode 100644
index 000000000..ad7795e96
--- /dev/null
+++ b/ddecal/gain_solvers/kernels/Complex.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMPLEX_H_
+#define DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMPLEX_H_
+
+#include <cuComplex.h>
+
+__host__ __device__ static __inline__ cuDoubleComplex make_cuDoubleComplex(
+    const cuFloatComplex& a) {
+  return make_cuDoubleComplex(a.x, a.y);
+}
+
+/*
+ * Taken the below utility functions from
+ * https://forums.developer.nvidia.com/t/additional-cucomplex-functions-cucnorm-cucsqrt-cucexp-and-some-complex-double-functions/36892
+ */
+__host__ __device__ static __inline__ cuDoubleComplex cuCadd(cuDoubleComplex x,
+                                                             double y) {
+  return make_cuDoubleComplex(cuCreal(x) + y, cuCimag(x));
+}
+__host__ __device__ static __inline__ cuDoubleComplex cuCdiv(cuDoubleComplex x,
+                                                             double y) {
+  return make_cuDoubleComplex(cuCreal(x) / y, cuCimag(x) / y);
+}
+__host__ __device__ static __inline__ cuDoubleComplex cuCmul(cuDoubleComplex x,
+                                                             double y) {
+  return make_cuDoubleComplex(cuCreal(x) * y, cuCimag(x) * y);
+}
+__host__ __device__ static __inline__ cuDoubleComplex cuCsub(cuDoubleComplex x,
+                                                             double y) {
+  return make_cuDoubleComplex(cuCreal(x) - y, cuCimag(x));
+}
+
+__host__ __device__ static __inline__ cuDoubleComplex cuCexp(
+    cuDoubleComplex x) {
+  double factor = exp(x.x);
+  return make_cuDoubleComplex(factor * cos(x.y), factor * sin(x.y));
+}
+
+/*
+ * Cuda complex implementation of std::arg; returns the phase in double
+ * precision: https://en.cppreference.com/w/cpp/numeric/complex/arg
+ */
+__device__ static __inline__ double cuCarg(const cuDoubleComplex& z) {
+  return atan2(cuCimag(z), cuCreal(z));
+}
+
+/*
+ * Cuda complex implementation of std::polar
+ * https://en.cppreference.com/w/cpp/numeric/complex/polar
+ */
+__device__ static __inline__ cuDoubleComplex cuCpolar(const double r,
+                                                      const double z) {
+  return make_cuDoubleComplex(r * cos(z), r * sin(z));
+}
+
+template <typename T>
+__device__ static __inline__ double cuNorm(const T& a) {
+  return a.x * a.x + a.y * a.y;
+}
+
+#endif  // DP3_DDECAL_GAIN_SOLVERS_KERNELS_COMPLEX_H_
\ No newline at end of file
diff --git a/ddecal/gain_solvers/kernels/IterativeDiagonal.cu b/ddecal/gain_solvers/kernels/IterativeDiagonal.cu
new file mode 100644
index 000000000..dbb790d9b
--- /dev/null
+++ b/ddecal/gain_solvers/kernels/IterativeDiagonal.cu
@@ -0,0 +1,292 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#include "IterativeDiagonal.h"
+
+#include <cuComplex.h>
+#include <math_constants.h>
+
+#include "Common.h"
+#include "Complex.h"
+#include "MatrixComplex2x2.h"
+
+#define BLOCK_SIZE 128
+
+template <bool Add>
+__device__ void AddOrSubtract(size_t vis_index, size_t n_solutions,
+                              const unsigned int* antenna_pairs,
+                              const unsigned int* solution_map,
+                              const cuDoubleComplex* solutions,
+                              const cuM2x2FloatComplex* model,
+                              const cuM2x2FloatComplex* residual_in,
+                              cuM2x2FloatComplex* residual_out) {
+  const uint32_t antenna_1 = antenna_pairs[vis_index * 2 + 0];
+  const uint32_t antenna_2 = antenna_pairs[vis_index * 2 + 1];
+  const size_t solution_index = solution_map[vis_index];
+  const cuDoubleComplex* solution_1 =
+      &solutions[(antenna_1 * n_solutions + solution_index) * 2];
+  const cuDoubleComplex* solution_2 =
+      &solutions[(antenna_2 * n_solutions + solution_index) * 2];
+
+  const cuFloatComplex solution_1_0 = cuComplexDoubleToFloat(solution_1[0]);
+  const cuFloatComplex solution_1_1 = cuComplexDoubleToFloat(solution_1[1]);
+  const cuFloatComplex solution_2_0_conj =
+      cuComplexDoubleToFloat(cuConj(solution_2[0]));
+  const cuFloatComplex solution_2_1_conj =
+      cuComplexDoubleToFloat(cuConj(solution_2[1]));
+
+  const cuM2x2FloatComplex contribution(
+      cuCmulf(cuCmulf(solution_1_0, model[vis_index][0]), solution_2_0_conj),
+      cuCmulf(cuCmulf(solution_1_0, model[vis_index][1]), solution_2_1_conj),
+      cuCmulf(cuCmulf(solution_1_1, model[vis_index][2]), solution_2_0_conj),
+      cuCmulf(cuCmulf(solution_1_1, model[vis_index][3]), solution_2_1_conj));
+
+  if (Add) {
+    residual_out[vis_index] = residual_in[vis_index] + contribution;
+  } else {
+    residual_out[vis_index] = residual_in[vis_index] - contribution;
+  }
+}
+
+__device__ void SolveDirection(size_t vis_index, size_t n_visibilities,
+                               size_t n_direction_solutions, size_t n_solutions,
+                               const unsigned int* antenna_pairs,
+                               const unsigned int* solution_map,
+                               const cuDoubleComplex* solutions,
+                               const cuM2x2FloatComplex* model,
+                               const cuM2x2FloatComplex* residual,
+                               cuFloatComplex* numerator, float* denominator) {
+  // Load correct variables to compute on.
+  const size_t antenna_1 = antenna_pairs[vis_index * 2];
+  const size_t antenna_2 = antenna_pairs[vis_index * 2 + 1];
+  const size_t solution_index = solution_map[vis_index];
+
+  const cuDoubleComplex* solution_antenna_1 =
+      &solutions[(antenna_1 * n_solutions + solution_index) * 2];
+  const cuDoubleComplex* solution_antenna_2 =
+      &solutions[(antenna_2 * n_solutions + solution_index) * 2];
+
+  const size_t rel_solution_index = solution_index - solution_map[0];
+
+  // Calculate the contribution of this baseline for both antennas
+  // For antenna2,
+  // data_ba = data_ab^H, etc., therefore, numerator and denominator
+  // become:
+  // - num = data_ab^H * solutions_a * model_ab
+  // - den = norm(model_ab^H * solutions_a)
+  for (size_t i = 0; i < 2; i++) {
+    const size_t antenna = antenna_pairs[vis_index * 2 + i];
+
+    cuM2x2FloatComplex result;
+    cuM2x2FloatComplex changed_model;
+
+    if (i == 0) {
+      const cuM2x2FloatComplexDiagonal solution(
+          make_cuFloatComplex(solution_antenna_2[0].x, solution_antenna_2[0].y),
+          make_cuFloatComplex(solution_antenna_2[1].x,
+                              solution_antenna_2[1].y));
+      changed_model = solution * cuConj(model[vis_index]);
+      result = residual[vis_index] * changed_model;
+    } else {
+      const cuM2x2FloatComplexDiagonal solution(
+          make_cuFloatComplex(solution_antenna_1[0].x, solution_antenna_1[0].y),
+          make_cuFloatComplex(solution_antenna_1[1].x,
+                              solution_antenna_1[1].y));
+      changed_model = solution * model[vis_index];
+      result = cuConj(residual[vis_index]) * changed_model;
+    }
+
+    const size_t full_solution_index =
+        antenna * n_direction_solutions + rel_solution_index;
+
+    // Atomic reduction into global memory
+    atomicAdd(&numerator[full_solution_index * 2 + 0].x, result[0].x);
+    atomicAdd(&numerator[full_solution_index * 2 + 0].y, result[0].y);
+    atomicAdd(&numerator[full_solution_index * 2 + 1].x, result[3].x);
+    atomicAdd(&numerator[full_solution_index * 2 + 1].y, result[3].y);
+
+    atomicAdd(&denominator[full_solution_index * 2],
+              cuNorm(changed_model[0]) + cuNorm(changed_model[2]));
+    atomicAdd(&denominator[full_solution_index * 2 + 1],
+              cuNorm(changed_model[1]) + cuNorm(changed_model[3]));
+  }
+}
+
+__global__ void SolveDirectionKernel(
+    size_t n_visibilities, size_t n_direction_solutions, size_t n_solutions,
+    const unsigned int* antenna_pairs, const unsigned int* solution_map,
+    const cuDoubleComplex* solutions, const cuM2x2FloatComplex* model,
+    const cuM2x2FloatComplex* residual_in, cuM2x2FloatComplex* residual_temp,
+    cuFloatComplex* numerator, float* denominator) {
+  const size_t vis_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (vis_index >= n_visibilities) {
+    return;
+  }
+
+  AddOrSubtract<true>(vis_index, n_solutions, antenna_pairs, solution_map,
+                      solutions, model, residual_in, residual_temp);
+
+  SolveDirection(vis_index, n_visibilities, n_direction_solutions, n_solutions,
+                 antenna_pairs, solution_map, solutions, model, residual_temp,
+                 numerator, denominator);
+}
+
+void LaunchSolveDirectionKernel(
+    cudaStream_t stream, size_t n_visibilities, size_t n_direction_solutions,
+    size_t n_solutions, size_t direction, cu::DeviceMemory& antenna_pairs,
+    cu::DeviceMemory& solution_map, cu::DeviceMemory& solutions,
+    cu::DeviceMemory& model, cu::DeviceMemory& residual_in,
+    cu::DeviceMemory& residual_temp, cu::DeviceMemory& numerator,
+    cu::DeviceMemory& denominator) {
+  const size_t block_dim = BLOCK_SIZE;
+  const size_t grid_dim = (n_visibilities + block_dim) / block_dim;
+
+  const size_t direction_offset = direction * n_visibilities;
+  const unsigned int* solution_map_direction =
+      Cast<const unsigned int>(solution_map) + direction_offset;
+  const cuM2x2FloatComplex* model_direction =
+      Cast<const cuM2x2FloatComplex>(model) + direction_offset;
+  SolveDirectionKernel<<<grid_dim, block_dim, 0, stream>>>(
+      n_visibilities, n_direction_solutions, n_solutions,
+      Cast<const unsigned int>(antenna_pairs), solution_map_direction,
+      Cast<const cuDoubleComplex>(solutions), model_direction,
+      Cast<const cuM2x2FloatComplex>(residual_in),
+      Cast<cuM2x2FloatComplex>(residual_temp), Cast<cuFloatComplex>(numerator),
+      Cast<float>(denominator));
+}
+
+__global__ void SubtractKernel(size_t n_directions, size_t n_visibilities,
+                               size_t n_solutions,
+                               const unsigned int* antenna_pairs,
+                               const unsigned int* solution_map,
+                               const cuDoubleComplex* solutions,
+                               const cuFloatComplex* model,
+                               cuM2x2FloatComplex* residual) {
+  const size_t vis_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (vis_index >= n_visibilities) {
+    return;
+  }
+
+  for (size_t direction = 0; direction < n_directions; direction++) {
+    const size_t direction_offset = direction * n_visibilities;
+    const unsigned int* solution_map_direction =
+        solution_map + direction_offset;
+    const cuFloatComplex* model_direction = model + (4 * direction_offset);
+    AddOrSubtract<false>(
+        vis_index, n_solutions, antenna_pairs, solution_map_direction,
+        solutions, reinterpret_cast<const cuM2x2FloatComplex*>(model_direction),
+        residual, residual);  // in-place
+  }
+}
+
+void LaunchSubtractKernel(cudaStream_t stream, size_t n_directions,
+                          size_t n_visibilities, size_t n_solutions,
+                          cu::DeviceMemory& antenna_pairs,
+                          cu::DeviceMemory& solution_map,
+                          cu::DeviceMemory& solutions, cu::DeviceMemory& model,
+                          cu::DeviceMemory& residual) {
+  const size_t block_dim = BLOCK_SIZE;
+  const size_t grid_dim = (n_visibilities + block_dim) / block_dim;
+
+  SubtractKernel<<<grid_dim, block_dim, 0, stream>>>(
+      n_directions, n_visibilities, n_solutions,
+      Cast<const unsigned int>(antenna_pairs),
+      Cast<const unsigned int>(solution_map),
+      Cast<const cuDoubleComplex>(solutions), Cast<const cuFloatComplex>(model),
+      Cast<cuM2x2FloatComplex>(residual));
+}
+
+__global__ void SolveNextSolutionKernel(unsigned int n_antennas,
+                                        unsigned int n_direction_solutions,
+                                        const unsigned int n_solutions,
+                                        const unsigned int* solution_map,
+                                        const cuFloatComplex* numerator,
+                                        const float* denominator,
+                                        cuDoubleComplex* next_solutions) {
+  const size_t antenna = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (antenna >= n_antennas) {
+    return;
+  }
+
+  for (size_t relative_solution = 0; relative_solution < n_direction_solutions;
+       relative_solution++) {
+    const size_t solution_index = relative_solution + solution_map[0];
+    cuDoubleComplex* destination =
+        &next_solutions[(antenna * n_solutions + solution_index) * 2];
+    const size_t index = antenna * n_direction_solutions + relative_solution;
+
+    for (size_t pol = 0; pol < 2; pol++) {
+      if (denominator[index * 2 + pol] == 0.0) {
+        destination[pol] = {CUDART_NAN, CUDART_NAN};
+      } else {
+        // The CPU code performs this computation in double-precision,
+        // however single-precision also seems sufficiently accurate.
+        destination[pol] = {
+            numerator[index * 2 + pol].x / denominator[index * 2 + pol],
+            numerator[index * 2 + pol].y / denominator[index * 2 + pol]};
+      }
+    }
+  }
+}
+
+void LaunchSolveNextSolutionKernel(
+    cudaStream_t stream, size_t n_antennas, size_t n_visibilities,
+    size_t n_direction_solutions, size_t n_solutions, size_t direction,
+    cu::DeviceMemory& antenna_pairs, cu::DeviceMemory& solution_map,
+    cu::DeviceMemory& next_solutions, cu::DeviceMemory& numerator,
+    cu::DeviceMemory& denominator) {
+  const size_t block_dim = BLOCK_SIZE;
+  const size_t grid_dim = (n_antennas + block_dim) / block_dim;
+
+  const size_t direction_offset = direction * n_visibilities;
+  const unsigned int* solution_map_direction =
+      Cast<const unsigned int>(solution_map) + direction_offset;
+  SolveNextSolutionKernel<<<grid_dim, block_dim, 0, stream>>>(
+      n_antennas, n_direction_solutions, n_solutions, solution_map_direction,
+      Cast<const cuFloatComplex>(numerator), Cast<const float>(denominator),
+      Cast<cuDoubleComplex>(next_solutions));
+}
+
+__global__ void StepKernel(const size_t n_visibilities,
+                           const cuDoubleComplex* solutions,
+                           cuDoubleComplex* next_solutions, bool phase_only,
+                           double step_size) {
+  const size_t vis_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (vis_index >= n_visibilities) {
+    return;
+  }
+
+  if (phase_only) {
+    // In phase only mode, a step is made along the complex circle,
+    // towards the shortest direction.
+    double phase_from = cuCarg(solutions[vis_index]);
+    double distance = cuCarg(next_solutions[vis_index]) - phase_from;
+    if (distance > CUDART_PI)
+      distance = distance - 2.0 * CUDART_PI;
+    else if (distance < -CUDART_PI)
+      distance = distance + 2.0 * CUDART_PI;
+
+    next_solutions[vis_index] =
+        cuCpolar(1.0, phase_from + step_size * distance);
+  } else {
+    next_solutions[vis_index] =
+        cuCadd(cuCmul(solutions[vis_index], (1.0 - step_size)),
+               cuCmul(next_solutions[vis_index], step_size));
+  }
+}
+
+void LaunchStepKernel(cudaStream_t stream, size_t n_visibilities,
+                      cu::DeviceMemory& solutions,
+                      cu::DeviceMemory& next_solutions, bool phase_only,
+                      double step_size) {
+  const size_t block_dim = BLOCK_SIZE;
+  const size_t grid_dim = (n_visibilities + block_dim) / block_dim;
+
+  StepKernel<<<grid_dim, block_dim, 0, stream>>>(
+      n_visibilities, Cast<const cuDoubleComplex>(solutions),
+      Cast<cuDoubleComplex>(next_solutions), phase_only, step_size);
+}
diff --git a/ddecal/gain_solvers/kernels/IterativeDiagonal.h b/ddecal/gain_solvers/kernels/IterativeDiagonal.h
new file mode 100644
index 000000000..188f59327
--- /dev/null
+++ b/ddecal/gain_solvers/kernels/IterativeDiagonal.h
@@ -0,0 +1,39 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef DP3_DDECAL_GAIN_SOLVERS_KERNELS_ITERATIVEDIAGONAL_H_
+#define DP3_DDECAL_GAIN_SOLVERS_KERNELS_ITERATIVEDIAGONAL_H_
+
+#include <complex>
+#include <cuda_runtime.h>
+
+#include <cudawrappers/cu.hpp>
+
+void LaunchSubtractKernel(cudaStream_t stream, size_t n_directions,
+                          size_t n_visibilities, size_t n_solutions,
+                          cu::DeviceMemory& antenna_pairs,
+                          cu::DeviceMemory& solution_map,
+                          cu::DeviceMemory& solutions, cu::DeviceMemory& model,
+                          cu::DeviceMemory& residual);
+
+void LaunchSolveNextSolutionKernel(
+    cudaStream_t stream, size_t n_antennas, size_t n_visibilities,
+    size_t n_direction_solutions, size_t n_solutions, size_t direction,
+    cu::DeviceMemory& antenna_pairs, cu::DeviceMemory& solution_map,
+    cu::DeviceMemory& next_solutions, cu::DeviceMemory& numerator,
+    cu::DeviceMemory& denominator);
+
+void LaunchSolveDirectionKernel(
+    cudaStream_t stream, size_t n_visibilities, size_t n_direction_solutions,
+    size_t n_solutions, size_t direction, cu::DeviceMemory& antenna_pairs,
+    cu::DeviceMemory& solution_map, cu::DeviceMemory& solutions,
+    cu::DeviceMemory& model, cu::DeviceMemory& residual_in,
+    cu::DeviceMemory& residual_temp, cu::DeviceMemory& numerator,
+    cu::DeviceMemory& denominator);
+
+void LaunchStepKernel(cudaStream_t stream, size_t n_visibilities,
+                      cu::DeviceMemory& solutions,
+                      cu::DeviceMemory& next_solutions, bool phase_only,
+                      double step_size);
+
+#endif  // DP3_DDECAL_GAIN_SOLVERS_KERNELS_ITERATIVEDIAGONAL_H_
\ No newline at end of file
diff --git a/ddecal/gain_solvers/kernels/MatrixComplex2x2.h b/ddecal/gain_solvers/kernels/MatrixComplex2x2.h
new file mode 100644
index 000000000..bccfe5e3d
--- /dev/null
+++ b/ddecal/gain_solvers/kernels/MatrixComplex2x2.h
@@ -0,0 +1,116 @@
+// Copyright (C) 2023 ASTRON (Netherlands Institute for Radio Astronomy)
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+#ifndef DP3_DDECAL_GAIN_SOLVERS_KERNELS_MATRIXCOMPLEX2X2_H_
+#define DP3_DDECAL_GAIN_SOLVERS_KERNELS_MATRIXCOMPLEX2X2_H_
+
+#include <cuComplex.h>
+
+template <typename T>
+struct cuM2x2 {
+ public:
+  __device__ cuM2x2() {
+    data[0] = {0};
+    data[1] = {0};
+    data[2] = {0};
+    data[3] = {0};
+  }
+  __device__ cuM2x2(T a, T b, T c, T d) {
+    data[0] = a;
+    data[1] = b;
+    data[2] = c;
+    data[3] = d;
+  }
+
+  inline __device__ const T& operator[](int i) const { return data[i]; }
+  inline __device__ T& operator[](int i) { return data[i]; }
+
+ private:
+  T data[4];
+};
+
+template <typename T>
+struct cuM2x2Diagonal {
+ public:
+  __device__ cuM2x2Diagonal() {
+    data[0] = {0};
+    data[1] = {0};
+  }
+  __device__ cuM2x2Diagonal(T a, T b) {
+    data[0] = a;
+    data[1] = b;
+  }
+
+  inline __device__ const T& operator[](int i) const { return data[i]; }
+  inline __device__ T& operator[](int i) { return data[i]; }
+
+ private:
+  T data[2];
+};
+
+using cuM2x2FloatComplex = cuM2x2<cuFloatComplex>;
+using cuM2x2DoubleComplex = cuM2x2<cuDoubleComplex>;
+
+inline __device__ cuM2x2FloatComplex operator+(const cuM2x2FloatComplex& a,
+                                               const cuM2x2FloatComplex& b) {
+  return cuM2x2FloatComplex(cuCaddf(a[0], b[0]), cuCaddf(a[1], b[1]),
+                            cuCaddf(a[2], b[2]), cuCaddf(a[3], b[3]));
+}
+
+inline __device__ cuM2x2FloatComplex operator-(const cuM2x2FloatComplex& a,
+                                               const cuM2x2FloatComplex& b) {
+  return cuM2x2FloatComplex(cuCsubf(a[0], b[0]), cuCsubf(a[1], b[1]),
+                            cuCsubf(a[2], b[2]), cuCsubf(a[3], b[3]));
+}
+
+inline __device__ cuM2x2FloatComplex operator*(const cuM2x2FloatComplex& a,
+                                               const cuM2x2FloatComplex& b) {
+  return cuM2x2FloatComplex(cuCaddf(cuCmulf(a[0], b[0]), cuCmulf(a[1], b[2])),
+                            cuCaddf(cuCmulf(a[0], b[1]), cuCmulf(a[1], b[3])),
+                            cuCaddf(cuCmulf(a[2], b[0]), cuCmulf(a[3], b[2])),
+                            cuCaddf(cuCmulf(a[2], b[1]), cuCmulf(a[3], b[3])));
+}
+
+inline __device__ cuM2x2DoubleComplex operator*(const cuM2x2DoubleComplex& a,
+                                                const cuM2x2DoubleComplex& b) {
+  return cuM2x2DoubleComplex(cuCadd(cuCmul(a[0], b[0]), cuCmul(a[1], b[2])),
+                             cuCadd(cuCmul(a[0], b[1]), cuCmul(a[1], b[3])),
+                             cuCadd(cuCmul(a[2], b[0]), cuCmul(a[3], b[2])),
+                             cuCadd(cuCmul(a[2], b[1]), cuCmul(a[3], b[3])));
+}
+
+inline __device__ cuM2x2FloatComplex cuConj(const cuM2x2FloatComplex& x) {
+  return cuM2x2FloatComplex(cuConjf(x[0]), cuConjf(x[2]), cuConjf(x[1]),
+                            cuConjf(x[3]));
+}
+
+inline __device__ cuM2x2DoubleComplex cuConj(const cuM2x2DoubleComplex& x) {
+  return cuM2x2DoubleComplex(cuConj(x[0]), cuConj(x[2]), cuConj(x[1]),
+                             cuConj(x[3]));  // [1]/[2] swapped: Hermitian
+}
+
+inline __device__ cuM2x2DoubleComplex
+make_cuM2x2ComplexDouble(const cuM2x2FloatComplex& x) {
+  return cuM2x2DoubleComplex(
+      make_cuDoubleComplex(x[0]), make_cuDoubleComplex(x[1]),
+      make_cuDoubleComplex(x[2]), make_cuDoubleComplex(x[3])
+
+  );
+}
+
+using cuM2x2FloatComplexDiagonal = cuM2x2Diagonal<cuFloatComplex>;
+using cuM2x2DoubleComplexDiagonal = cuM2x2Diagonal<cuDoubleComplex>;
+
+inline __device__ cuM2x2FloatComplex
+operator*(const cuM2x2FloatComplexDiagonal& a, const cuM2x2FloatComplex& b) {
+  return cuM2x2FloatComplex(cuCmulf(a[0], b[0]), cuCmulf(a[0], b[1]),
+                            cuCmulf(a[1], b[2]), cuCmulf(a[1], b[3]));
+}
+
+inline __device__ cuM2x2DoubleComplex
+operator*(const cuM2x2DoubleComplexDiagonal& a, const cuM2x2DoubleComplex& b) {
+  return cuM2x2DoubleComplex(cuCmul(a[0], b[0]), cuCmul(a[0], b[1]),
+                             cuCmul(a[1], b[2]), cuCmul(a[1], b[3]));
+}
+
+#endif  // DP3_DDECAL_GAIN_SOLVERS_KERNELS_MATRIXCOMPLEX2X2_H_
diff --git a/ddecal/test/unit/tSolvers.cc b/ddecal/test/unit/tSolvers.cc
index 1bb69e90c..d03a9b3c1 100644
--- a/ddecal/test/unit/tSolvers.cc
+++ b/ddecal/test/unit/tSolvers.cc
@@ -7,6 +7,9 @@
 #include "../../gain_solvers/FullJonesSolver.h"
 #include "../../gain_solvers/HybridSolver.h"
 #include "../../gain_solvers/IterativeDiagonalSolver.h"
+#if defined(HAVE_CUDA)
+#include "../../gain_solvers/IterativeDiagonalSolverCuda.h"
+#endif
 #include "../../gain_solvers/IterativeFullJonesSolver.h"
 #include "../../gain_solvers/IterativeScalarSolver.h"
 #include "../../gain_solvers/ScalarSolver.h"
@@ -141,29 +144,44 @@ BOOST_FIXTURE_TEST_CASE(hybrid, SolverTester,
   BOOST_CHECK_EQUAL(result.iterations, kMaxIterations + 1);
 }
 
-BOOST_FIXTURE_TEST_CASE(iterative_diagonal, SolverTester,
-                        *boost::unit_test::label("slow")) {
-  SetDiagonalSolutions(false);
-  dp3::ddecal::IterativeDiagonalSolver solver;
-  InitializeSolver(solver);
+inline void TestIterativeDiagonal(dp3::ddecal::SolverBase& solver) {
+  SolverTester solver_tester;
+  solver_tester.SetDiagonalSolutions(false);
+  solver_tester.InitializeSolver(solver);
 
   BOOST_CHECK_EQUAL(solver.NSolutionPolarizations(), 2u);
   BOOST_REQUIRE_EQUAL(solver.ConstraintSolvers().size(), 1u);
   BOOST_CHECK_EQUAL(solver.ConstraintSolvers()[0], &solver);
 
-  const dp3::ddecal::BdaSolverBuffer& solver_buffer = FillBDAData();
-  const SolveData data(solver_buffer, kNChannelBlocks, kNDirections, kNAntennas,
-                       Antennas1(), Antennas2());
+  const dp3::ddecal::BdaSolverBuffer& solver_buffer =
+      solver_tester.FillBDAData();
+  const SolveData data(solver_buffer, SolverTester::kNChannelBlocks,
+                       SolverTester::kNDirections, SolverTester::kNAntennas,
+                       solver_tester.Antennas1(), solver_tester.Antennas2());
 
   dp3::ddecal::SolverBase::SolveResult result =
-      solver.Solve(data, GetSolverSolutions(), 0.0, nullptr);
+      solver.Solve(data, solver_tester.GetSolverSolutions(), 0.0, nullptr);
 
-  CheckDiagonalResults(1.0E-2);
+  solver_tester.CheckDiagonalResults(1.0E-2);
   // The iterative solver solves the requested accuracy within the max
   // iterations so just check if the nr of iterations is <= max+1.
-  BOOST_CHECK_LE(result.iterations, kMaxIterations + 1);
+  BOOST_CHECK_LE(result.iterations, SolverTester::kMaxIterations + 1);
+}
+
+BOOST_FIXTURE_TEST_CASE(iterative_diagonal, SolverTester,
+                        *boost::unit_test::label("slow")) {
+  dp3::ddecal::IterativeDiagonalSolver solver;
+  TestIterativeDiagonal(solver);
 }
 
+#if defined(HAVE_CUDA)
+BOOST_FIXTURE_TEST_CASE(iterative_diagonal_cuda, SolverTester,
+                        *boost::unit_test::label("slow")) {
+  dp3::ddecal::IterativeDiagonalSolverCuda solver;
+  TestIterativeDiagonal(solver);
+}
+#endif
+
 BOOST_FIXTURE_TEST_CASE(iterative_diagonal_dd_intervals, SolverTester,
                         *boost::unit_test::label("slow")) {
   SetDiagonalSolutions(true);
diff --git a/docs/schemas/DDECal.yml b/docs/schemas/DDECal.yml
index 4e453cfcc..0c6a86315 100644
--- a/docs/schemas/DDECal.yml
+++ b/docs/schemas/DDECal.yml
@@ -212,4 +212,10 @@ inputs:
   storebuffer:
     default: false
     type: bool
-    doc: Setting this to true will store the solution of DDECal into the buffer, allowing the usage of this solution in a later step. For example, a pipeline with  DDECal -> OneApplyCal  would be able to apply solutions to the data without requiring an intermediate format to be stored to disk. Note that it currently only works for single-direction solutions.
\ No newline at end of file
+    doc: Setting this to true will store the solution of DDECal into the buffer, allowing the usage of this solution in a later step. For example, a pipeline with  DDECal -> OneApplyCal  would be able to apply solutions to the data without requiring an intermediate format to be stored to disk. Note that it currently only works for single-direction solutions.
+  usegpu:
+    default: false
+    type: bool
+    doc: >-
+      Use GPU solver. This is an experimental feature only available for the iterative
+      diagonal solver and requires DP3 to be built with BUILD_WITH_CUDA=1.
\ No newline at end of file
diff --git a/scripts/run-format.sh b/scripts/run-format.sh
index 3bf6a2697..81a842519 100755
--- a/scripts/run-format.sh
+++ b/scripts/run-format.sh
@@ -12,6 +12,9 @@ SOURCE_DIR=$(dirname "$0")/..
 #relative to SOURCE_DIR.
 EXCLUDE_DIRS=(external build CMake)
 
+#The patterns of the C++ source files, which clang-format should format.
+CXX_SOURCES=(*.cc *.h *.cu *.cuh)
+
 # clang-format version 12 and 14 produce slightly different results for
 # pythondp3/parameterset.cc, so hard-code this to 14.
 CLANG_FORMAT_BINARY=clang-format-14
-- 
GitLab