diff --git a/rapthor/scripts/filter_skymodel.py b/rapthor/scripts/filter_skymodel.py
index 17b86ed27b73c373336c31894c51e4ece85b7b23..f7a57b6c57f9b51c40ba62ea1508acbb1af0081c 100755
--- a/rapthor/scripts/filter_skymodel.py
+++ b/rapthor/scripts/filter_skymodel.py
@@ -18,6 +18,8 @@ from rapthor.lib.observation import Observation
 from scipy.interpolate import interp1d
 import subprocess
 import sys
+import shutil
+import tempfile
 
 
 def calc_theoretical_noise(mslist, w_factor=1.5):
@@ -205,6 +207,7 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
         if os.path.exists(tmpdir):
             os.environ["TMPDIR"] = tmpdir
             break
+    temp_ms_dir = tempfile.mkdtemp()  # used for storing a copy of the beam MS file
 
     # Run PyBDSF to make a mask for grouping
     if use_adaptive_threshold:
@@ -239,7 +242,7 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
                              atrous_do=True, atrous_jmax=3, rms_map=True, quiet=True)
 
     # Collect some diagnostic numbers for later reporting. Note: we ensure all
-    # numbers are float, as, e.g., np.float32 is not supported by json.dump()
+    # non-integer numbers are float, as, e.g., np.float32 is not supported by json.dump()
     theoretical_rms, unflagged_fraction = calc_theoretical_noise(beamMS)  # Jy/beam
     min_rms = float(np.min(img.rms_arr))  # Jy/beam
     max_rms = float(np.max(img.rms_arr))  # Jy/beam
@@ -293,9 +296,9 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
         hdu[0].data = data
         hdu.writeto(maskfile, overwrite=True)
 
-        # Now filter the sky model using the mask made above
+        # Select the best MS for the beam attenuation and copy it to TMPDIR,
+        # as this is likely to have faster I/O (important for EveryBeam use)
         if len(beamMS) > 1:
-            # Select the best MS for the beam attenuation
             ms_times = []
             for ms in beamMS:
                 tab = pt.table(ms, ack=False)
@@ -306,14 +309,19 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
             beam_ind = ms_times.index(mid_time)
         else:
             beam_ind = 0
+        beam_ms = os.path.join(temp_ms_dir, os.path.basename(beamMS[beam_ind]))
+        shutil.copytree(beamMS[beam_ind], beam_ms)
+
+        # Load the sky model with the associated beam MS
         try:
-            s_in = lsmtool.load(input_skymodel_pb, beamMS=beamMS[beam_ind])
+            s_in = lsmtool.load(input_skymodel_pb, beamMS=beam_ms)
         except astropy.io.ascii.InconsistentTableError:
             emptysky = True
+
+        # If bright sources were peeled before imaging, add them back
         if input_bright_skymodel_pb is not None:
             try:
-                # If bright sources were peeled before imaging, add them back
-                s_bright = lsmtool.load(input_bright_skymodel_pb, beamMS=beamMS[beam_ind])
+                s_bright = lsmtool.load(input_bright_skymodel_pb)
 
                 # Rename the bright sources, removing the '_sector_*' added previously
                 # (otherwise the '_sector_*' text will be added every iteration,
@@ -327,9 +335,11 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
                     emptysky = False
             except astropy.io.ascii.InconsistentTableError:
                 pass
+
+        # Do final filtering and write out the sky models
         if not emptysky:
-            # Keep only those sources with positive flux densities
             if remove_negative:
+                # Keep only those sources with positive flux densities
                 s_in.select('I > 0.0')
             if s_in and filter_by_mask:
                 # Keep only those sources in PyBDSF masked regions
@@ -413,9 +423,10 @@ def main(input_image, input_skymodel_pb, output_root, vertices_file, beamMS,
         with open(output_root+'.true_sky.txt', 'w') as f:
             f.writelines(dummylines)
 
-    # Set the TMPDIR env var back to its original value
+    # Set the TMPDIR env var back to its original value and clean up
     if old_tmpdir is not None:
         os.environ["TMPDIR"] = old_tmpdir
+    misc.delete_directory(temp_ms_dir)
 
 
 if __name__ == '__main__':