Skip to content
Snippets Groups Projects
Commit 10bfd447 authored by Maaijke Mevius's avatar Maaijke Mevius
Browse files

Merge branch 'numba_implementation' into 'main'

Try numba implementation

See merge request !5
parents 20b37162 09f0e0ef
No related branches found
No related tags found
1 merge request!5Try numba implementation
Pipeline #39049 failed
......@@ -2,3 +2,5 @@ h5py
numpy
astropy
matplotlib
numba
scipy
\ No newline at end of file
import numpy
import numpy as np
from scipy.ndimage import median_filter,gaussian_filter1d
from scipy.ndimage.filters import uniform_filter1d
from scipy.interpolate import interp1d
from astropy.convolution import interpolate_replace_nans
from scipy.interpolate import interp1d
from scipy.ndimage import median_filter
from scipy.ndimage.filters import uniform_filter1d
def model_flux(calibrator, frequency):
......@@ -58,6 +59,7 @@ def model_flux(calibrator, frequency):
flux_model = 10 ** flux_model # because at first the flux is in log10
return flux_model
def filter_rfi(data, nsigma=5, ntpts=50, nchans=10):
'''
RFI mitigation strategy for dynamic spectra:
......@@ -73,7 +75,6 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
Output: An array with flags
'''
ntimepts, nfreqs = data.shape
# flatten = median_filter(data,(ntpts,1),mode='constant',cval=1)
......@@ -94,7 +95,8 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
return maskdata
def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample_rate= 180, filter_length = 300, rfi_filter_length =10, flag_value=10, replace_value=1):
def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target="Cas A", sample_rate=180, filter_length=300,
rfi_filter_length=10, flag_value=10, replace_value=1):
'''apply a bandpass correction to the data, scaling the data to the nominal flux of the calibrator, also flag outliers and replace flagged data with replace_value
Input:
data: the numpy array with the data
......@@ -118,7 +120,8 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample
data = np.swapaxes(data, 0, timeaxis)
# medfilter = gaussian_filter1d(data[::sample_rate], axis=0 , sigma=filter_length) #or use median_filter?
medfilter = median_filter(data[::sample_rate], size=(filter_length,)+(1,)*(len(data.shape)-1)) #or use gaussian_filter?
medfilter = median_filter(data[::sample_rate],
size=(filter_length,) + (1,) * (len(data.shape) - 1)) # or use gaussian_filter?
xnew = np.arange(data.shape[0])
x = xnew[::sample_rate]
f = interp1d(x, medfilter, axis=0, fill_value='extrapolate') # interpolate to all skipped samples
......@@ -141,6 +144,7 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample
return np.swapaxes(tmpdata, 0, timeaxis), np.swapaxes(flags, 0, timeaxis), flux, np.swapaxes(bandpass, 0, timeaxis)
def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis, Could be fast if it works with numba
Input:
......@@ -165,8 +169,10 @@ def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
if has_nan:
for i in range(idx.shape[0]):
for j in range(window_size):
avgsqdata = np.sum(tmpdata[idx[i]]**2,axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0))
avgdatasq = (np.sum(tmpdata[idx[i]],axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0)))**2
avgsqdata = np.sum(tmpdata[idx[i]] ** 2, axis=0) / (
window_size - np.sum(np.isnan(tmpdata[idx[i]]), axis=0))
avgdatasq = (np.sum(tmpdata[idx[i]], axis=0) / (
window_size - np.sum(np.isnan(tmpdata[idx[i]]), axis=0))) ** 2
S4[i] = np.sqrt((avgsqdata - avgdatasq) / avgdatasq)
else:
for i in range(idx.shape[0]):
......@@ -184,6 +190,7 @@ def window_stdv_mean(arr, window):
c2 = uniform_filter1d(arr * arr, window, mode='constant', axis=0)
return (np.abs(c2 - c1 * c1) ** .5)[window // 2:-window // 2 + 1], c1[window // 2:-window // 2 + 1]
def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis
Input:
......@@ -202,6 +209,91 @@ def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
S4 = stddata[::skip_sample_time] / avgdata[::skip_sample_time]
return np.swapaxes(S4, 0, axis)
from numba import guvectorize
@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
def move_mean(a, window_arr, out):
    """Running mean of ``a`` along its last axis, written into ``out``.

    The first ``window_width`` entries hold a cumulative ("warm-up") mean over
    the samples seen so far; from index ``window_width`` onward each entry is
    the mean of the trailing ``window_width`` samples.

    Parameters (dictated by the guvectorize layout '(n),() -> (n)'):
        a: float32 1-D input slice.
        window_arr: the window width; guvectorize passes the scalar as a
            one-element intp array, hence ``window_arr[0]``.
        out: float32 1-D output slice, same length as ``a``.
    """
    window_width = window_arr[0]
    n = len(a)
    asum = 0.0
    count = 0
    # Warm-up region: cumulative mean. Clamp to n so a window wider than the
    # input cannot read past the end of ``a`` (numba performs no bounds checks
    # in nopython mode, so an unclamped loop would read out of bounds).
    for i in range(min(window_width, n)):
        asum += a[i]
        count += 1
        out[i] = asum / count
    # Steady state: slide the window by adding the newest sample and dropping
    # the oldest; ``count`` equals ``window_width`` here.
    for i in range(window_width, n):
        asum += a[i] - a[i - window_width]
        out[i] = asum / count
@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
def move_mean_sq(a, window_arr, out):
    """Running mean of the squares of ``a`` along its last axis, into ``out``.

    Companion kernel to ``move_mean`` (same warm-up semantics): the first
    ``window_width`` entries are a cumulative mean of ``a[i]**2``; afterwards
    each entry is the mean of the trailing ``window_width`` squared samples.

    Parameters (dictated by the guvectorize layout '(n),() -> (n)'):
        a: float32 1-D input slice.
        window_arr: one-element intp array carrying the scalar window width.
        out: float32 1-D output slice, same length as ``a``.
    """
    window_width = window_arr[0]
    n = len(a)
    asum = 0.0
    count = 0
    # Warm-up region, clamped to n so a window wider than the input cannot
    # read past the end of ``a`` (no bounds checking in nopython mode).
    for i in range(min(window_width, n)):
        asum += numpy.power(a[i], 2)
        count += 1
        out[i] = asum / count
    # Steady state: add the newest squared sample, drop the oldest.
    for i in range(window_width, n):
        asum += numpy.power(a[i], 2) - numpy.power(a[i - window_width], 2)
        out[i] = asum / count
def window_stdv_mean_numba(arr: numpy.ndarray, window):
    """Running standard deviation and mean of ``arr`` along axis 0.

    Numba-backed counterpart of ``window_stdv_mean``: transposes so the time
    axis is last (the guvectorize kernels operate on the last axis), computes
    running mean and mean-of-squares, and derives std = sqrt(|E[x^2]-E[x]^2|).
    Both outputs are trimmed by half a window at each end.

    NOTE(review): for odd windows the tail trim (``-window//2 + 1``) keeps one
    sample more than window_stdv_mean's ``-window // 2 + 1`` on negative floor
    division — confirm which boundary is intended.
    """
    half = window // 2
    transposed = numpy.swapaxes(arr, 0, 1)
    running_mean = move_mean(transposed, window)
    # |E[x^2] - E[x]^2| guards against tiny negative values from round-off.
    running_var = numpy.abs(move_mean_sq(transposed, window) - numpy.power(running_mean, 2))
    std_arr = np.swapaxes(numpy.sqrt(running_var), 0, 1)
    mean_arr = np.swapaxes(running_mean, 0, 1)
    return std_arr[half: -half + 1], mean_arr[half: -half + 1]
def getS4Numba(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
    """Calculate the S4 scintillation index along ``axis`` (numba-backed).

    Input:
        data: numpy array with data (maximum resolution but RFI flagged)
        window_size: int, the window to calculate S4, typically 60s and 180s
        skip_sample_time: int, only emit S4 every skip_sample_time samples
        axis: int, the axis of ``data`` to window over
        has_nan: accepted for interface parity with getS4 — ignored by this
            implementation (it never switches to nan-aware statistics)
    Output:
        numpy array with S4 values, sampling axis restored to ``axis``
    """
    # Move the axis to be windowed to position 0 for the running statistics.
    samples_first = np.swapaxes(data, 0, axis)
    std_arr, mean_arr = window_stdv_mean_numba(samples_first, window_size)
    # S4 = sigma/mu, evaluated only every skip_sample_time-th sample.
    sub_std = std_arr[::skip_sample_time]
    sub_mean = mean_arr[::skip_sample_time]
    return np.swapaxes(sub_std / sub_mean, 0, axis)
def getS4Naive(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
    """Calculate the S4 scintillation index along ``axis`` (naive loop kernel).

    Input:
        data: 2-D numpy array with data (maximum resolution but RFI flagged)
        window_size: int, the window to calculate S4, typically 60s and 180s
        skip_sample_time: int, only emit S4 every skip_sample_time samples
        axis: int, the axis of ``data`` to window over
        has_nan: accepted for interface parity with getS4 — ignored here
    Output:
        numpy array with S4 values, sampling axis restored to ``axis``
    """
    # Work on a float32 copy with the windowed axis first; the running-std
    # kernel fills its outputs in place.
    work = np.swapaxes(data, 0, axis).astype(np.float32).copy()
    n_times, n_freqs = work.shape  # implicitly requires 2-D input
    mean_buf = np.zeros((n_times, n_freqs), dtype=np.float32)
    std_buf = np.zeros((n_times, n_freqs), dtype=np.float32)
    naive_running_std(work, window_size, mean_buf, std_buf)
    # S4 = sigma/mu, evaluated only every skip_sample_time-th sample.
    s4 = std_buf[::skip_sample_time] / mean_buf[::skip_sample_time]
    return np.swapaxes(s4, 0, axis)
def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis
Input:
......@@ -227,6 +319,7 @@ def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
S4 = np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq)
return np.swapaxes(S4, 0, axis)
def getS4_medium(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis
Input:
......@@ -284,7 +377,3 @@ def getMedian(data,Nsamples,axis=0, flags=None, bandpass = None, has_nan=False):
bandpass = np.average(bandpass[:cutoff].reshape((-1, Nsamples) + avgdata.shape[1:]), axis=1)
return np.swapaxes(avgdata, 0, axis), np.swapaxes(flags, 0, axis), np.swapaxes(bandpass, 0, axis) # swap back
import os
import unittest
from glob import glob
import numpy
import numpy.random
from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
from scintillation.averaging import open_dataset, extract_metadata, decode_str
# Collect every HDF5 test file shipped next to this module; the tests below
# are skipped when none are present.
basepath = os.path.dirname(__file__)
test_datasets = glob(os.path.join(basepath, 'data', '*.h5'))
def is_test_data_present():
    """Return True when at least one HDF5 test dataset was found on disk."""
    return bool(test_datasets)
def get_filtered_data_from_dataset(dataset):
    """Prepare bandpass-corrected data from an opened HDF5 dataset.

    Picks the first dynamic-spectrum group, runs apply_bandpass on the first
    10000 time samples, and returns a tuple of (filtered data, 60-second S4
    window in samples, one-second averaging window in samples).
    """
    metadata = extract_metadata(dataset)
    dynspec, *_ = metadata.keys()
    group = dataset[dynspec]
    raw = group['DATA'][:, :, 0]
    spectral_world = group['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
    time_delta, *_ = decode_str(group['COORDINATES']['TIME'].attrs['INCREMENT'])

    # Frequency axis in MHz, one point per channel of the data array.
    freq_lo = spectral_world[0] / 1.e6
    freq_hi = spectral_world[-1] / 1.e6
    frequency_axis = numpy.linspace(freq_lo, freq_hi, raw.shape[1])

    samples_per_second = int(numpy.ceil(1 / time_delta))
    seconds_per_window = samples_per_second * time_delta
    sample_rate = int(samples_per_second * 3. / seconds_per_window)  # sample every 3 seconds
    s4_window_60s = int(60. / time_delta)

    filtered, _flags, _flux, _bandpass = apply_bandpass(
        raw[0:10000, :], frequency_axis,
        freqaxis=1, timeaxis=0,
        target=metadata[dynspec]["TARGET"],
        sample_rate=sample_rate,
        filter_length=600,  # window size 30 minutes
        # half the averaging window, to prevent flagging scintillation
        rfi_filter_length=samples_per_second // 2,
        flag_value=8, replace_value=1)
    return filtered, s4_window_60s, samples_per_second
class TestS4Generation(unittest.TestCase):
    """Smoke tests: each S4 implementation must run on real filtered data."""

    def _run_s4(self, s4_function):
        # Shared driver: load the first test file, run the bandpass pipeline,
        # then call the given S4 implementation ten times in a row (the
        # repetition exercises warm-up/caching behaviour, e.g. numba JIT).
        dataset = open_dataset(test_datasets[0])
        filtered, window_60s, avg_window = get_filtered_data_from_dataset(dataset)
        print(avg_window, window_60s, filtered.shape)
        for _ in range(10):
            _ = s4_function(filtered,
                            window_size=window_60s,  # 60 seconds
                            # create S4 every averaging time
                            skip_sample_time=avg_window,
                            axis=0, has_nan=False)

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_reference_implementation_succeed(self):
        self._run_s4(getS4)

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_numba_implementation_succeed(self):
        self._run_s4(getS4Numba)
# Allow running this test module directly with `python`.
if __name__ == '__main__':
    unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment