Skip to content
Snippets Groups Projects
Select Git revision
  • 580da8f50231bf0c4db5c7a7889e62f1e440de59
  • main default protected
2 results

define_tasks_from_archive.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_s4_computation.py 3.46 KiB
    import os
    import unittest
    from glob import glob
    
    import numpy
    import numpy.random
    from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
    from scintillation.averaging import open_dataset, extract_metadata, decode_str
    
    basepath = os.path.dirname(__file__)
    # glob() returns matches in arbitrary, filesystem-dependent order; consumers
    # index test_datasets[0], so sort to make the selected fixture deterministic.
    test_datasets = sorted(glob(os.path.join(basepath, 'data', '*.h5')))
    
    
    def is_test_data_present():
        """Return True when at least one ``*.h5`` file was found in ``data/``."""
        return bool(test_datasets)
    
    
    def get_filtered_data_from_dataset(dataset, max_samples=10000):
        """Load the first dynamic spectrum of *dataset* and bandpass-filter it.

        Only the first ``max_samples`` rows of ``DATA`` (index 0 of its third
        axis) are read and filtered, which keeps the tests fast.

        :param dataset: opened dataset as returned by ``open_dataset``; indexed
            as ``dataset[dynspec]['DATA']`` and ``dataset[dynspec]['COORDINATES']``.
        :param max_samples: number of leading rows (time samples) to process;
            defaults to 10000.
        :returns: tuple ``(filtered_data, S4_60s_window_in_samples,
            averaging_window_in_samples)``; the two window sizes are sample
            counts derived from the TIME coordinate ``INCREMENT`` attribute.
        """
        metadata = extract_metadata(dataset)
        dynspec, *_ = metadata.keys()
        group = dataset[dynspec]

        # Slice the time range *before* materialising the array: the values are
        # identical to slicing afterwards, but lazily-indexed datasets (e.g.
        # h5py-style) then only read the rows that are actually used.
        subset = group['DATA'][0:max_samples, :, 0]
        frequency = group['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
        time_delta, *_ = decode_str(group['COORDINATES']['TIME'].attrs['INCREMENT'])

        # Frequencies are divided by 1e6 (presumably Hz -> MHz; confirm against
        # apply_bandpass's expectations).
        start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
        frequency_axis = numpy.linspace(start_frequency, end_frequency, subset.shape[1])

        # Smallest whole number of samples covering 1 second.
        averaging_window_in_samples = int(numpy.ceil(1 / time_delta))
        averaging_window_in_seconds = averaging_window_in_samples * time_delta
        # Number of samples in 3 seconds (used to sample every 3 seconds).
        sample_rate = int(averaging_window_in_samples * 3. / averaging_window_in_seconds)
        S4_60s_window_in_samples = int(60. / time_delta)

        filtered_data, _flags, _flux, _bandpass = apply_bandpass(
            subset, frequency_axis,
            freqaxis=1, timeaxis=0, target=metadata[dynspec]["TARGET"],
            sample_rate=sample_rate,  # sample every 3 seconds
            filter_length=600,  # window size 30 minutes (600 x 3 s)
            # window size in time to prevent flagging scintillation
            rfi_filter_length=averaging_window_in_samples // 2,
            flag_value=8, replace_value=1)
        return filtered_data, S4_60s_window_in_samples, averaging_window_in_samples
    
    
    class TestS4Generation(unittest.TestCase):
        """Smoke tests: both S4 implementations run on real data without raising.

        Each test loads the first test dataset, bandpass-filters it and computes
        S4 over 60-second windows, one value per averaging window. The results
        are not compared against expected values.
        """

        # How many times each S4 implementation is invoked per test. Repeated
        # calls exercise the function several times (for the Numba variant the
        # first call presumably includes JIT compilation).
        repetitions = 10

        def _load_filtered_data(self):
            """Open the first test dataset and return get_filtered_data_from_dataset(dataset).

            If the opened dataset exposes a ``close`` method, it is registered
            with ``addCleanup`` so the file handle is released after the test.
            """
            dataset = open_dataset(test_datasets[0])
            close = getattr(dataset, 'close', None)
            if callable(close):
                self.addCleanup(close)
            return get_filtered_data_from_dataset(dataset)

        def _run_s4(self, s4_function):
            """Call *s4_function* ``repetitions`` times on the filtered test data."""
            filtered_data, s4_window, averaging_window = self._load_filtered_data()
            for _ in range(self.repetitions):
                s4_function(filtered_data,
                            window_size=s4_window,  # 60 seconds
                            skip_sample_time=averaging_window,  # create S4 every averaging time
                            axis=0, has_nan=False)

        @unittest.skipUnless(is_test_data_present(), 'missing test data')
        def test_reference_implementation_succeed(self):
            self._run_s4(getS4)

        @unittest.skipUnless(is_test_data_present(), 'missing test data')
        def test_numba_implementation_succeed(self):
            self._run_s4(getS4Numba)
    
    
    
    
    if __name__ == "__main__":
        # Allow running this module directly with the stdlib unittest runner.
        unittest.main()