Skip to content
Snippets Groups Projects
Select Git revision
  • 913c21d8e11d5e95e525fc7a7ab8cfc8029a9530
  • master default protected
  • dither_on_off_disabled
  • yocto
  • pypcc2
  • pypcc3
  • 2020-12-07-the_only_working_copy
  • v2.1
  • v2.0
  • v1.0
  • v0.9
  • Working-RCU_ADC,ID
  • 2020-12-11-Holiday_Season_release
13 results

test_clk.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_s4_computation.py 3.46 KiB
    import os
    import unittest
    from glob import glob
    
    import numpy
    import numpy.random
    from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
    from scintillation.averaging import open_dataset, extract_metadata, decode_str
    
    # Directory that contains this test module; test data is located relative to it.
    basepath = os.path.dirname(__file__)
    # Paths of every '*.h5' file in the 'data' directory next to this module.
    # The list is empty when no test data has been placed there.
    test_datasets = glob(os.path.join(basepath, 'data', '*.h5'))
    
    
    def is_test_data_present():
        """Return True when at least one test dataset file was found."""
        return bool(test_datasets)
    
    
    def get_filtered_data_from_dataset(dataset):
        """Bandpass-filter the start of the first dynamic spectrum in *dataset*.

        The first key reported by ``extract_metadata`` selects the dynamic
        spectrum. Index 0 of the last axis of its ``DATA`` array is read, the
        first 10000 rows are kept and passed to ``apply_bandpass`` with rows as
        the time axis and columns as the frequency axis.

        Args:
            dataset: an opened dataset, as returned by ``open_dataset``.

        Returns:
            A tuple ``(filtered_data, S4_60s_window_in_samples,
            averaging_window_in_samples)`` where ``filtered_data`` is the first
            value returned by ``apply_bandpass``.
        """
        metadata = extract_metadata(dataset)
        dynspec, *_ = metadata.keys()
        spectrum = dataset[dynspec]['DATA'][:, :, 0]
        spectral_values = dataset[dynspec]['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
        time_delta, *_ = decode_str(dataset[dynspec]['COORDINATES']['TIME'].attrs['INCREMENT'])

        # Evenly spaced frequency axis between the first and last spectral value,
        # scaled by 1e-6 (presumably Hz -> MHz; confirm against the file format).
        first_frequency = spectral_values[0] / 1.e6
        last_frequency = spectral_values[-1] / 1.e6
        n_channels = spectrum.shape[1]
        frequency_axis = numpy.linspace(first_frequency, last_frequency, n_channels)

        # Only the leading 10000 rows are filtered.
        leading_rows = spectrum[0:10000, :]

        # Window sizes expressed in samples (assumes time_delta is in seconds).
        samples_per_average = int(numpy.ceil(1 / time_delta))
        seconds_per_average = samples_per_average * time_delta
        samples_per_60s = int(60. / time_delta)
        # sample every 3 seconds
        bandpass_sample_rate = int(samples_per_average * 3. / seconds_per_average)

        filtered, flags, flux, bandpass = apply_bandpass(
            leading_rows, frequency_axis,
            freqaxis=1, timeaxis=0,
            target=metadata[dynspec]["TARGET"],
            sample_rate=bandpass_sample_rate,
            filter_length=600,  # window size 30 minutes
            # window size in time to prevent flagging scintillation
            rfi_filter_length=samples_per_average // 2,
            flag_value=8, replace_value=1)

        return filtered, samples_per_60s, samples_per_average
    
    
    class TestS4Generation(unittest.TestCase):
        """Smoke tests: both S4 implementations must run on the first test dataset.

        The tests only check that the calls complete without raising; the
        computed S4 values are discarded.
        """

        def _run_s4_repeatedly(self, s4_function, repetitions=10):
            """Filter the first test dataset and call *s4_function* on it *repetitions* times."""
            dataset = open_dataset(test_datasets[0])
            filtered, window_60s, averaging_window = get_filtered_data_from_dataset(dataset)
            print(averaging_window, window_60s, filtered.shape)
            for _ in range(repetitions):
                s4_function(
                    filtered,
                    window_size=window_60s,  # 60 seconds
                    skip_sample_time=averaging_window,  # create S4 every averaging time
                    axis=0,
                    has_nan=False,
                )

        @unittest.skipUnless(is_test_data_present(), 'missing test data')
        def test_reference_implementation_succeed(self):
            """The reference ``getS4`` implementation runs without raising."""
            self._run_s4_repeatedly(getS4)

        @unittest.skipUnless(is_test_data_present(), 'missing test data')
        def test_numba_implementation_succeed(self):
            """The ``getS4Numba`` implementation runs without raising."""
            self._run_s4_repeatedly(getS4Numba)
    
    
    
    
    # Run this module's tests when the file is executed directly as a script.
    if __name__ == '__main__':
        unittest.main()