diff --git a/CMakeLists.txt b/CMakeLists.txt
index 752bf4b6fbd336db66a7e7135f8bc999b67d74c4..90ded5d4faa2a9fd554b698a91a23f7849302f6c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -531,6 +531,7 @@ add_library(
   steps/Step.cc
   steps/Upsample.cc
   steps/UVWFlagger.cc
+  steps/WGridderPredict.cc
   steps/ApplyBeam.cc
   steps/NullStokes.cc
   steps/SagecalPredict.cc)
diff --git a/pythondp3/steps/idgpredict.py b/pythondp3/steps/idgpredict.py
new file mode 100644
index 0000000000000000000000000000000000000000..d41f090d4da338ff7e5b1b000f073ddac3f0a699
--- /dev/null
+++ b/pythondp3/steps/idgpredict.py
@@ -0,0 +1,244 @@
+# Copyright (C) 2024 ASTRON (Netherlands Institute for Radio Astronomy)
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import dp3
+import everybeam as eb
+import numpy as np
+import idg
+import astropy.io.fits as fits
+import time
+import logging
+
+
+class IdgPredictStep(dp3.Step):
+    def __init__(self, parset, prefix):
+        super().__init__()
+        self.read_parset(parset, prefix)
+        self.dpbuffers = []
+        self.is_initialized = False
+
+    def show(self):
+        print()
+        print(self.__class__.__name__)
+        print("  usebeammodel:  ", self.usebeammodel)
+        if self.usebeammodel:
+            print("  beammode:      ", self.beammode)
+        print()
+
+    def get_required_fields(self):
+        return dp3.Fields.FLAGS | dp3.Fields.WEIGHTS | dp3.Fields.UVW
+
+    def get_provided_fields(self):
+        return dp3.Fields.DATA
+
+    def process(self, dpbuffer):
+        # Accumulate buffers
+        self.dpbuffers.append(dpbuffer)
+
+        # If we have accumulated enough data, process it
+        if len(self.dpbuffers) == self.ampl_interval:
+            self.process_buffers()
+
+            # Send processed data to the next step
+            for dpbuffer in self.dpbuffers:
+                self.get_next_step().process(dpbuffer)
+
+            # Clear accumulated data
+            self.dpbuffers = []
+
+    def finish(self):
+        # If there is any remaining data, process it
+        if len(self.dpbuffers):
+            self.process_buffers()
+            for dpbuffer in self.dpbuffers:
+                self.get_next_step().process(dpbuffer)
+            self.dpbuffers = []
+        self.get_next_step().finish()
+
+    def _update_info(self, dpinfo):
+        super()._update_info(dpinfo)
+
+    def read_parset(self, parset, prefix):
+        """
+        Read relevant information from a given parset
+
+        Parameters
+        ----------
+        parset : dp3.ParameterSet
+            ParameterSet object provided by DP3
+        prefix : str
+            Prefix to be used when reading the parset.
+        """
+
+        self.imagename = parset.get_string(prefix + "modelimage")
+        self.padding = parset.get_float(prefix + "padding", 1.2)
+        self.nr_correlations = parset.get_int(prefix + "nrcorrelations", 4)
+        self.subgrid_size = parset.get_int(prefix + "subgridsize", 32)
+
+        self.taper_support = parset.get_int(prefix + "tapersupport", 7)
+        wterm_support = parset.get_int(prefix + "wtermsupport", 5)
+        aterm_support = parset.get_int(prefix + "atermsupport", 5)
+        self.kernel_size = self.taper_support + wterm_support + aterm_support
+
+    def initialize(self):
+        self.is_initialized = True
+
+        self.shift = np.array((0.0, 0.0), dtype=np.float32)
+
+        self.nr_stations = self.info.n_antenna
+        self.nr_baselines = (self.nr_stations * (self.nr_stations - 1)) // 2
+        self.frequencies = np.array(self.info.channel_frequencies, dtype=np.float32)
+        self.nr_channels = len(self.frequencies)
+
+        self.baselines = np.zeros(shape=(self.nr_baselines, 2), dtype=np.int32)
+
+        station1 = np.array(self.info.first_antenna_indices)
+        station2 = np.array(self.info.second_antenna_indices)
+        self.auto_corr_mask = station1 != station2
+        self.baselines[:, 0] = station1[self.auto_corr_mask]
+        self.baselines[:, 1] = station2[self.auto_corr_mask]
+
+        if self.proxytype.lower() == "gpu":
+            self.proxy = idg.HybridCUDA.GenericOptimized()
+        else:
+            self.proxy = idg.CPU.Optimized()
+
+        # read image dimensions from fits header
+        h = fits.getheader(self.imagename)
+        N0 = h["NAXIS1"]
+        self.cell_size = np.deg2rad(abs(h["CDELT1"]))
+
+        # Pointing of image
+        # TODO This should be checked against the pointing in the MS
+        self.ra = np.deg2rad(h["CRVAL1"])
+        self.dec = np.deg2rad(h["CRVAL2"])
+
+        # compute padded image size
+        N = next_composite(int(N0 * self.padding))
+        self.grid_size = N
+        self.image_size = N * self.cell_size
+
+        # Initialize empty grid
+        self.grid = np.zeros(
+            shape=(self.nr_correlations, self.grid_size, self.grid_size),
+            dtype=idg.gridtype,
+        )
+
+        # Initialize taper
+        taper = idgwindow(self.subgrid_size, self.taper_support, self.padding)
+        self.taper2 = np.outer(taper, taper).astype(np.float32)
+
+        taper_ = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(taper)))
+        taper_grid = np.zeros(self.grid_size, dtype=np.complex128)
+        taper_grid[
+            (self.grid_size - self.subgrid_size)
+            // 2 : (self.grid_size + self.subgrid_size)
+            // 2
+        ] = taper_ * np.exp(
+            -1j * np.linspace(-np.pi / 2, np.pi / 2, self.subgrid_size, endpoint=False)
+        )
+        taper_grid = (
+            np.fft.fftshift(np.fft.ifft(np.fft.ifftshift(taper_grid))).real
+            * self.grid_size
+            / self.subgrid_size
+        )
+        taper_grid0 = taper_grid[(N - N0) // 2 : (N + N0) // 2]
+
+        # read image data, assume Stokes I
+        d = fits.getdata(self.imagename)
+        self.grid[0, (N - N0) // 2 : (N + N0) // 2, (N - N0) // 2 : (N + N0) // 2] = d[
+            0, 0, :, :
+        ] / np.outer(taper_grid0, taper_grid0)
+        self.grid[3, (N - N0) // 2 : (N + N0) // 2, (N - N0) // 2 : (N + N0) // 2] = d[
+            0, 0, :, :
+        ] / np.outer(taper_grid0, taper_grid0)
+
+        self.proxy.set_grid(self.grid)
+        self.proxy.transform(idg.ImageDomainToFourierDomain)
+
+        self.proxy.init_cache(
+            self.subgrid_size, self.cell_size, self.w_step, self.shift
+        )
+
+    def process_buffers(self):
+        """
+        Processing the buffers. This is the central method within any class that
+        derives from dp3.Step
+        """
+
+        if not self.is_initialized:
+            self.initialize()
+
+        # Concatenate accumulated data and display just the shapes
+        visibilities = self._extract_buffer("visibilities")
+        uvw_ = self._extract_buffer("uvw", apply_autocorr_mask=False)
+        uvw = np.zeros(
+            shape=(self.nr_baselines, self.ampl_interval, 3), dtype=np.float32
+        )
+
+        uvw[..., 0] = uvw_[self.auto_corr_mask, :, 0]
+        uvw[..., 1] = -uvw_[self.auto_corr_mask, :, 1]
+        uvw[..., 2] = -uvw_[self.auto_corr_mask, :, 2]
+
+        aterms = np.ascontiguousarray(aterms.astype(idg.idgtypes.atermtype))
+
+        self.proxy.degridding(
+            self.kernel_size,
+            frequencies,
+            visibilities,
+            uvw,
+            self.baselines,
+            aterms,
+            self.aterm_offsets,
+            self.taper2,
+        )
+        
+        for idx, dpbuffer in enumerate(self.dpbuffers):
+            visibilities_out = np.array(dpbuffer.get_data(), copy=False)
+            visibilities_out[
+                :,
+                channel_block
+                * self.nr_channels_per_block : (channel_block + 1)
+                * self.nr_channels_per_block,
+                :,
+            ] = visibilities[:, idx, :, :]
+
+
+    def _extract_buffer(self, name, apply_autocorr_mask=True):
+        """
+        Extract buffer from buffered data.
+
+        Parameters
+        ----------
+        name : str
+            Should be any of ("visibilities", "weights", "flags", "uvw")
+        apply_autocorr_mask : bool, optional
+            Remove autocorrelation from returned result? Defaults to True
+
+        Returns
+        -------
+        np.ndarray
+        """
+
+        if name == "visibilities":
+            result = [
+                np.array(dpbuffer.get_data(), copy=False) for dpbuffer in self.dpbuffers
+            ]
+        elif name == "flags":
+            result = [
+                np.array(dpbuffer.get_flags(), copy=False)
+                for dpbuffer in self.dpbuffers
+            ]
+        elif name == "weights":
+            result = [
+                np.array(dpbuffer.get_weights(), copy=False)
+                for dpbuffer in self.dpbuffers
+            ]
+        elif name == "uvw":
+            result = [
+                np.array(dpbuffer.get_uvw(), copy=False) for dpbuffer in self.dpbuffers
+            ]
+        else:
+            raise ValueError("Name not recognized")
+        result = np.stack(result, axis=1)
+        return result[self.auto_corr_mask, :, :] if apply_autocorr_mask else result