From 19869257fe94fcd67f3c149600e79ada2dd3ca36 Mon Sep 17 00:00:00 2001
From: Mattia Mancini <mancini@astron.nl>
Date: Tue, 11 Sep 2018 08:38:13 +0000
Subject: [PATCH] SSB-42: guard agains memory leak

---
 .../lib/datacontainers/measurementset.py      | 49 ++++++++++++-------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/CAL/CalibrationCommon/lib/datacontainers/measurementset.py b/CAL/CalibrationCommon/lib/datacontainers/measurementset.py
index 6cf95e9c5f5..485bb1ba40e 100644
--- a/CAL/CalibrationCommon/lib/datacontainers/measurementset.py
+++ b/CAL/CalibrationCommon/lib/datacontainers/measurementset.py
@@ -18,7 +18,12 @@ class MeasurementSet(object):
 
     def get_frequency(self):
         spectral_window_table = self.get_spectral_window_table()
-        return spectral_window_table.getcol('REF_FREQUENCY')
+        try:
+            reference_frequency =  spectral_window_table.getcol('REF_FREQUENCY')
+        finally:
+            spectral_window_table.close()
+
+        return reference_frequency
 
     def get_data_table(self):
         data_table = MS_Table(self.path)
@@ -38,40 +43,46 @@ class MeasurementSet(object):
 
     def get_start_end_observation(self):
         observation_table = self.get_observation_table()
-        start_time_in_seconds, end_time_in_seconds = observation_table.getcol('TIME_RANGE')
-        hour_in_seconds = 60 * 60
-        day_in_seconds = hour_in_seconds * 24
-        start_time = astrotime.Time(start_time_in_seconds/day_in_seconds, format='mjd', scale='utc')
-        end_time = astrotime.Time(start_time_in_seconds / day_in_seconds, format='mjd', scale='utc')
+
+        try:
+            start_time_in_seconds, end_time_in_seconds = observation_table.getcol('TIME_RANGE')
+            hour_in_seconds = 60 * 60
+            day_in_seconds = hour_in_seconds * 24
+            start_time = astrotime.Time(start_time_in_seconds/day_in_seconds, format='mjd', scale='utc')
+            end_time = astrotime.Time(start_time_in_seconds / day_in_seconds, format='mjd', scale='utc')
+        finally:
+            observation_table.close()
 
         return start_time.to_datetime(), end_time.to_datetime()
 
     def read_cross_correlation_per_station_names(self, reference, target):
-
         data_table = self.get_data_table()
         antennas_table = self.get_antenna_table()
-        antenna_name_id_map = {name:i for i, name in enumerate(antennas_table.getcol('NAME'))}
-        antenna1_list = data_table.getcol('ANTENNA1')
-        antenna2_list = data_table.getcol('ANTENNA2')
-        timestamp = data_table.getcol('TIME')
-        cross_correlation = data_table.getcol('DATA')
-        reference_antenna_id = antenna_name_id_map[reference]
-        target_antenna_id = antenna_name_id_map[target]
 
-        selected_data = [index for index, (a_i, a_j) in enumerate(zip(antenna1_list, antenna2_list))
-                         if a_i == reference_antenna_id and a_j == target_antenna_id]
+        try:
+            antenna_name_id_map = {name:i for i, name in enumerate(antennas_table.getcol('NAME'))}
+            antenna1_list = data_table.getcol('ANTENNA1')
+            antenna2_list = data_table.getcol('ANTENNA2')
+            timestamp = data_table.getcol('TIME')
+            cross_correlation = data_table.getcol('DATA')
+            reference_antenna_id = antenna_name_id_map[reference]
+            target_antenna_id = antenna_name_id_map[target]
 
-        return (timestamp[selected_data], cross_correlation[selected_data])
+            selected_data = [index for index, (a_i, a_j) in enumerate(zip(antenna1_list, antenna2_list))
+                             if a_i == reference_antenna_id and a_j == target_antenna_id]
+        finally:
+            data_table.close()
+            antennas_table.close()
 
 
+        return (timestamp[selected_data], cross_correlation[selected_data])
+
     def __repr__(self):
         return 'MeasurementSet(%d) located in %s for sas_id %d and sub_band_id %d' % (id(self),
                                                                                   self.name,
                                                                                   self.sas_id,
                                                                                   self.beamlet)
 
-
-
     @staticmethod
     def parse_sas_id_and_sub_band_from_ms_name(ms_name):
         if not MeasurementSet.is_a_valid_ms_name(ms_name):
-- 
GitLab