diff --git a/shaders/nif_image.comp b/shaders/nif_image.comp
index b985305d89c053bbc533c78df1eec8e6c53b7903..cbbae84433edd0f25a5a5e3bab2f326466ba98bf 100644
--- a/shaders/nif_image.comp
+++ b/shaders/nif_image.comp
@@ -30,6 +30,67 @@ shared vec3 s_antenna_a[DIMS];
 shared vec3 s_antenna_b[DIMS];
 shared vec2 s_vis[DIMS];
 
+// Brute-force imager: each invocation walks its pixel in every tile of the
+// radius plane picked by gl_WorkGroupID.z and adds the normalised sum of
+// amplitude * cos(phase) over all correlations to image_cube.
+void main() {
+    // NOTE(review): both axes use 1/nx, which presumably assumes nx == ny - confirm.
+    const float scale = 1.0 / float(nx);
+    const uvec2 local_index = gl_LocalInvocationID.xy;
+    const uvec2 group_index = gl_WorkGroupID.xy;
+    const uvec2 group_size = gl_WorkGroupSize.xy;
+    const uvec2 block_size = group_size * gl_NumWorkGroups.xy;
+    const uvec2 image_size = {nx, ny};
+    // Round up so a trailing partial tile is visited; the bounds check on
+    // pixel_index below discards the overshoot.
+    const uvec2 num_blocks = (image_size + block_size - 1u) / block_size;
+
+    const float a = 2.0 * PI_F * push_const.frequency / SPEED_OF_LIGHT;
+    const float norm = 1.0 / float(n_correlations);
+
+    const uint radius_id = push_const.start_radius + gl_WorkGroupID.z;
+    // Check the bound before indexing radii[] so an overshooting radius_id
+    // never reads past the table.
+    if (radius_id >= nradii) {
+        return;
+    }
+    const float radius = radii[radius_id];
+    const uint skip = radius_id * nx * ny;
+
+    for (uint b_x = 0; b_x < num_blocks.x; b_x++) {
+        for (uint b_y = 0; b_y < num_blocks.y; b_y++) {
+            const uvec2 tile_index = {b_x, b_y};
+            const uvec2 pixel_index = tile_index * block_size + group_size * group_index + local_index;
+            if ((pixel_index.x >= nx) || (pixel_index.y >= ny)) {
+                continue;
+            }
+            float sum = 0.;
+            vec2 lm = 2.0 * (vec2(pixel_index) + 0.5) * scale - 1.0;
+            float square = dot(lm, lm);
+            // sqrt(1 - square) is only real inside the unit disc; other pixels add 0.
+            if (square < 1.0) {
+                vec3 lmn = {lm.x, lm.y, sqrt(1.0 - square)};
+                vec3 position = lmn * radius;
+                for (uint s = 0; s < n_correlations; s++) {
+                    vec3 position_a = {antenna_positions_a[s*3+0], antenna_positions_a[s*3+1], antenna_positions_a[s*3+2]};
+                    vec3 position_b = {antenna_positions_b[s*3+0], antenna_positions_b[s*3+1], antenna_positions_b[s*3+2]};
+                    vec3 diff_a = position_a - position;
+                    vec3 diff_b = position_b - position;
+                    // Difference of the distances from this pixel to the two antennas.
+                    float d_a = sqrt(dot(diff_a, diff_a));
+                    float d_b = sqrt(dot(diff_b, diff_b));
+                    float d_phase = a * (d_a - d_b);
+                    sum = fma(amplitudes[s], cos(d_phase + phases[s]), sum);
+                }
+            }
+            // Accumulate into the stored value rather than overwriting it.
+            const uint out_index = nx * pixel_index.y + pixel_index.x + skip;
+            image_cube[out_index] = fma(sum, norm, image_cube[out_index]);
+        }
+    }
+}
+
+/*
 
 void main() {
 
@@ -51,7 +112,7 @@ void main() {
     uint thread_id = gl_LocalInvocationID.y * group_size.x + gl_LocalInvocationID.x;
     uint boundary = (n_correlations % nr_threads ==0) ? 1 : 0;
     uint nbulks = n_correlations / nr_threads + boundary;
-
+    
     for (int bulk_id = 0; bulk_id < nbulks; bulk_id +=1){
         uint idx = thread_id + bulk_id * nr_threads;
         if (idx < n_correlations){
@@ -68,10 +129,10 @@ void main() {
         }
         memoryBarrierShared();
         barrier();
-
         
-        uint radius_id = 0; //push_const.start_radius + gl_LocalInvocationID.z;
-        const float radius = 1.e8; // radii[radius_id];
+        uint radius_id = push_const.start_radius + gl_WorkGroupID.z;
+        const float radius = radii[radius_id];
+
         if (radius_id < nradii) {
             uint skip = radius_id * nx * ny;
             for(uint b_x = 0; b_x < num_blocks.x; b_x++){    
@@ -82,23 +143,24 @@ void main() {
                     if((pixel_index.x < nx) && (pixel_index.y < ny)) {
                         vec2 lm = 2.0 * (vec2(pixel_index) + 0.5)* scale - 1.0;
                         float square = dot(lm, lm);
-                        float n = sqrt(1.0 - square) - 1.0;
-                        
-                        vec3 lmn = {lm.x, lm.y, n};
-                        vec3 position =  lmn * radius;
-                        if (idx == thread_id) {
-                            image_cube[nx * pixel_index.y + pixel_index.x + skip] = 0.0;
-                        }
-                        for (uint s = 0; s < nr_threads; s++) {
-                            vec3 diff_a = s_antenna_a[s] - position;
-                            vec3 diff_b = s_antenna_b[s] - position;
+                        if(square < 1.0){
+                            float n = sqrt(1.0 - square) - 1.0;
                             
-                            const float d_a = sqrt(dot(diff_a, diff_a));
-                            const float d_b = sqrt(dot(diff_b, diff_b));
-
-                            const float d_phase = a * (d_a - d_b);
-                            sum = fma(s_vis[s].x, cos(d_phase + s_vis[s].y), sum);
-
+                            vec3 lmn = {lm.x, lm.y, n};
+                            vec3 position =  lmn * radius;
+                            if (idx == thread_id) {
+                                image_cube[nx * pixel_index.y + pixel_index.x + skip] = 0.0;
+                            }
+                            for (uint s = 0; s < nr_threads; s++) {
+                                vec3 diff_a = s_antenna_a[s] - position;
+                                vec3 diff_b = s_antenna_b[s] - position;
+                                
+                                const float d_a = sqrt(dot(diff_a, diff_a));
+                                const float d_b = sqrt(dot(diff_b, diff_b));
+
+                                const float d_phase = a * (d_a - d_b);
+                                sum = fma(s_vis[s].x, cos(d_phase + s_vis[s].y), sum);
+                            }
                         }
                     }
                     memoryBarrierShared();
@@ -111,3 +173,4 @@ void main() {
         }
     }
 }
+*/
\ No newline at end of file
diff --git a/src/nif.cpp b/src/nif.cpp
index 44a731d571ccf750b56157aee08f5a1c02fdfc08..0868e85c27855972152cdb4a533d6dda17d0e4e3 100644
--- a/src/nif.cpp
+++ b/src/nif.cpp
@@ -157,7 +157,7 @@ void allocate_device_memory(positions_type &antenna_positions_a,
       antenna_positions_a.data(), antenna_positions_a.size(), sizeof(float),
       kp::Tensor::TensorDataTypes::eFloat);
   device_handles.antenna_positions_b = device_handles.manager->tensor(
-      antenna_positions_a.data(), antenna_positions_a.size(), sizeof(float),
+      antenna_positions_b.data(), antenna_positions_b.size(), sizeof(float),
       kp::Tensor::TensorDataTypes::eFloat);
 
   std::vector<std::shared_ptr<kp::Tensor>> params = {
@@ -172,14 +172,14 @@ void allocate_device_memory(positions_type &antenna_positions_a,
 
 struct run_parameters {
   float frequency;
-  size_t run_id;
+  uint32_t run_id;  // NOTE(review): set to k * n_radii_per_time per dispatch; presumably read by the shader as push_const.start_radius - confirm layout
 };
 
 void process_visibility_slice(visibility_type &visibilities, float frequency,
                               xt::xarray<float> &radii, size_t nx, size_t ny,
                               kompute_context handles) {
   xt::adapt(handles.phase->data<float>(), visibilities.shape()) =
-      xt::angle(visibilities) * 0.f;
+      xt::angle(visibilities);
   xt::adapt(handles.amplitudes->data<float>(), visibilities.shape()) =
       xt::sqrt(xt::norm(visibilities));
   xt::adapt(handles.radii->data<float>(), radii.shape()) = radii;
@@ -198,16 +198,23 @@ void process_visibility_slice(visibility_type &visibilities, float frequency,
 
   run_parameters run_pars = {frequency, radius_id};
   std::vector<run_parameters> push_constants = {run_pars};
-  std::cout << "Processing nradii " << nradii << std::endl;
-  kp::Workgroup workgroup({16, 16, 1});
+  uint32_t n_radii_per_time = 5;
+  std::cout << "Processing nradii " << nradii << " in chunks of "
+            << n_radii_per_time << std::endl;
+
+  kp::Workgroup workgroup({16, 16, n_radii_per_time});
   auto sequence = handles.manager->sequence();
-  auto algorithm =
-      handles.manager->algorithm(params, handles.nif_image_shader, workgroup,
-                                 spec_constants, push_constants);
+  auto algorithm = handles.manager->algorithm<size_t, run_parameters>(
+      params, handles.nif_image_shader, workgroup, spec_constants,
+      push_constants);
   sequence->begin();
-  sequence->record<kp::OpTensorSyncDevice>(
-      {handles.phase, handles.amplitudes, handles.radii});
-  sequence->record<kp::OpAlgoDispatch>(algorithm, push_constants);
+
+  sequence->record<kp::OpTensorSyncDevice>(params);
+  for (uint32_t k = 0; k < nradii / n_radii_per_time; k++) {
+    sequence->record<kp::OpAlgoDispatch>(
+        algorithm,
+        std::vector<run_parameters>{{frequency, k * n_radii_per_time}});
+  }
   sequence->record<kp::OpTensorSyncLocal>({handles.image_cube});
 
   sequence->end();
@@ -319,9 +326,9 @@ void nifimager(const std::string msin, const std::string out,
 
   image_cube.resize({ntimes, nradii, npixels, npixels});
   xt::xarray<float> radii = xt::logspace<float>(3.f, 8.f, nradii);
-  radii = 1.e5;
+
   float frequency = xt::mean(frequencies, 0)();
-  for (uint t = 0; t < ntimes; t++) {
+  for (uint t = 0; t < 1; t++) {
     double time = read_visibilities_slice(msin, data, t, n_correlations);
     times.push_back(time);