Select Git revision
test_clk.py
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_s4_computation.py 3.46 KiB
import os
import unittest
from glob import glob
import numpy
import numpy.random
from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
from scintillation.averaging import open_dataset, extract_metadata, decode_str
# Directory containing this test module; test inputs live in its ``data`` subfolder.
basepath = os.path.dirname(__file__)
# HDF5 input files for the tests. ``glob`` yields paths in a filesystem-dependent
# order, so sort them: the tests use ``test_datasets[0]`` and must pick the same
# file on every machine.
test_datasets = sorted(glob(os.path.join(basepath, 'data', '*.h5')))
def is_test_data_present():
    """Return True when at least one ``*.h5`` file was found under ``data/``."""
    return bool(test_datasets)
def get_filtered_data_from_dataset(dataset, n_samples=10000):
    """Read the leading time samples of the first dynamic spectrum and bandpass-filter them.

    Parameters
    ----------
    dataset
        Dataset returned by ``open_dataset``. It is indexed as
        ``dataset[dynspec]['DATA']`` / ``['COORDINATES'][...].attrs``
        (presumably an ``h5py.File`` -- confirm).
    n_samples : int, optional
        Number of leading time samples (rows along axis 0) to keep.
        Defaults to 10000, the value this helper always used.

    Returns
    -------
    tuple
        ``(filtered_data, S4_60s_window_in_samples, averaging_window_in_samples)``:
        the first element of ``apply_bandpass``'s result, the number of time
        samples in 60 seconds, and the number of samples in the (>= 1 s)
        averaging window.
    """
    metadata = extract_metadata(dataset)
    # Only the first dynamic spectrum listed in the metadata is used.
    dynspec, *_ = metadata.keys()
    dynspec_group = dataset[dynspec]

    # Slice the time range while reading, so only ``n_samples`` rows are loaded
    # instead of materialising the whole time axis and discarding most of it.
    # Index 0 on the last axis selects the same plane as before.
    subset = dynspec_group['DATA'][0:n_samples, :, 0]
    frequency = dynspec_group['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
    time_delta, *_ = decode_str(dynspec_group['COORDINATES']['TIME'].attrs['INCREMENT'])

    # Frequencies are scaled by 1e-6 (presumably Hz -> MHz; confirm) and spread
    # linearly over the channel axis (axis 1).
    start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
    frequency_axis = numpy.linspace(start_frequency, end_frequency, subset.shape[1])

    # Smallest whole number of samples covering at least one second.
    averaging_window_in_samples = int(numpy.ceil(1 / time_delta))
    averaging_window_in_seconds = averaging_window_in_samples * time_delta
    # Number of samples spanning 3 seconds.
    sample_rate = int(averaging_window_in_samples * 3. / averaging_window_in_seconds)
    S4_60s_window_in_samples = int(60. / time_delta)

    filtered_data, flags, flux, bandpass = apply_bandpass(
        subset, frequency_axis,
        freqaxis=1, timeaxis=0, target=metadata[dynspec]["TARGET"],
        sample_rate=sample_rate,  # sample every 3 seconds
        filter_length=600,  # window size 30 minutes
        # window size in time to prevent flagging scintillation
        rfi_filter_length=averaging_window_in_samples // 2,
        flag_value=8, replace_value=1)
    return filtered_data, S4_60s_window_in_samples, averaging_window_in_samples
class TestS4Generation(unittest.TestCase):
    """Smoke tests: the reference and Numba S4 implementations run on real data.

    The S4 values are not checked; each test only verifies that repeated
    calls on bandpass-filtered test data complete without raising.
    """

    # How many times each S4 implementation is invoked per test.
    repetitions = 10

    def _load_filtered_data(self):
        """Open the first test dataset and return ``get_filtered_data_from_dataset``'s result."""
        dataset = open_dataset(test_datasets[0])
        # Release the file handle when the test finishes (even on failure), if
        # the returned object supports it.
        # NOTE(review): open_dataset presumably returns an h5py.File; confirm.
        close = getattr(dataset, 'close', None)
        if callable(close):
            self.addCleanup(close)
        return get_filtered_data_from_dataset(dataset)

    def _run_s4_repeatedly(self, s4_function):
        """Call *s4_function* ``repetitions`` times with the shared test parameters."""
        filtered_data, s4_window_in_samples, averaging_window_in_samples = self._load_filtered_data()
        for _ in range(self.repetitions):
            s4_function(filtered_data,
                        window_size=s4_window_in_samples,  # 60 seconds
                        # create S4 every averaging time
                        skip_sample_time=averaging_window_in_samples,
                        axis=0, has_nan=False)

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_reference_implementation_succeed(self):
        self._run_s4_repeatedly(getS4)

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_numba_implementation_succeed(self):
        self._run_s4_repeatedly(getS4Numba)
if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()