Skip to content
Snippets Groups Projects
Commit 10bfd447 authored by Maaijke Mevius's avatar Maaijke Mevius
Browse files

Merge branch 'numba_implementation' into 'main'

Try numba implementation

See merge request !5
parents 20b37162 09f0e0ef
Branches
No related tags found
1 merge request!5Try numba implementation
Pipeline #39049 failed
...@@ -2,3 +2,5 @@ h5py ...@@ -2,3 +2,5 @@ h5py
numpy numpy
astropy astropy
matplotlib matplotlib
numba
scipy
\ No newline at end of file
import numpy
import numpy as np import numpy as np
from scipy.ndimage import median_filter,gaussian_filter1d
from scipy.ndimage.filters import uniform_filter1d
from scipy.interpolate import interp1d
from astropy.convolution import interpolate_replace_nans from astropy.convolution import interpolate_replace_nans
from scipy.interpolate import interp1d
from scipy.ndimage import median_filter
from scipy.ndimage.filters import uniform_filter1d
def model_flux(calibrator, frequency): def model_flux(calibrator, frequency):
...@@ -58,6 +59,7 @@ def model_flux(calibrator, frequency): ...@@ -58,6 +59,7 @@ def model_flux(calibrator, frequency):
flux_model = 10 ** flux_model # because at first the flux is in log10 flux_model = 10 ** flux_model # because at first the flux is in log10
return flux_model return flux_model
def filter_rfi(data, nsigma=5, ntpts=50, nchans=10): def filter_rfi(data, nsigma=5, ntpts=50, nchans=10):
''' '''
RFI mitigation strategy for dynamic spectra: RFI mitigation strategy for dynamic spectra:
...@@ -73,7 +75,6 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10): ...@@ -73,7 +75,6 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
Output: An array with flags Output: An array with flags
''' '''
ntimepts, nfreqs = data.shape ntimepts, nfreqs = data.shape
# flatten = median_filter(data,(ntpts,1),mode='constant',cval=1) # flatten = median_filter(data,(ntpts,1),mode='constant',cval=1)
...@@ -94,7 +95,8 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10): ...@@ -94,7 +95,8 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
return maskdata return maskdata
def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample_rate= 180, filter_length = 300, rfi_filter_length =10, flag_value=10, replace_value=1): def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target="Cas A", sample_rate=180, filter_length=300,
rfi_filter_length=10, flag_value=10, replace_value=1):
'''apply a bandpass correction to the data, scaling the data to the nominal flux of the calibrator, also flag outliers and replace flagged data with replace_value '''apply a bandpass correction to the data, scaling the data to the nominal flux of the calibrator, also flag outliers and replace flagged data with replace_value
Input: Input:
data: the numpy array with the data data: the numpy array with the data
...@@ -118,7 +120,8 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample ...@@ -118,7 +120,8 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample
data = np.swapaxes(data, 0, timeaxis) data = np.swapaxes(data, 0, timeaxis)
# medfilter = gaussian_filter1d(data[::sample_rate], axis=0 , sigma=filter_length) #or use median_filter? # medfilter = gaussian_filter1d(data[::sample_rate], axis=0 , sigma=filter_length) #or use median_filter?
medfilter = median_filter(data[::sample_rate], size=(filter_length,)+(1,)*(len(data.shape)-1)) #or use gaussian_filter? medfilter = median_filter(data[::sample_rate],
size=(filter_length,) + (1,) * (len(data.shape) - 1)) # or use gaussian_filter?
xnew = np.arange(data.shape[0]) xnew = np.arange(data.shape[0])
x = xnew[::sample_rate] x = xnew[::sample_rate]
f = interp1d(x, medfilter, axis=0, fill_value='extrapolate') # interpolate to all skipped samples f = interp1d(x, medfilter, axis=0, fill_value='extrapolate') # interpolate to all skipped samples
...@@ -141,6 +144,7 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample ...@@ -141,6 +144,7 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample
return np.swapaxes(tmpdata, 0, timeaxis), np.swapaxes(flags, 0, timeaxis), flux, np.swapaxes(bandpass, 0, timeaxis) return np.swapaxes(tmpdata, 0, timeaxis), np.swapaxes(flags, 0, timeaxis), flux, np.swapaxes(bandpass, 0, timeaxis)
def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False): def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis, Could be fast if it works with numba '''Calculate S4 value for data along axis, Could be fast if it works with numba
Input: Input:
...@@ -165,8 +169,10 @@ def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False): ...@@ -165,8 +169,10 @@ def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
if has_nan: if has_nan:
for i in range(idx.shape[0]): for i in range(idx.shape[0]):
for j in range(window_size): for j in range(window_size):
avgsqdata = np.sum(tmpdata[idx[i]]**2,axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0)) avgsqdata = np.sum(tmpdata[idx[i]] ** 2, axis=0) / (
avgdatasq = (np.sum(tmpdata[idx[i]],axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0)))**2 window_size - np.sum(np.isnan(tmpdata[idx[i]]), axis=0))
avgdatasq = (np.sum(tmpdata[idx[i]], axis=0) / (
window_size - np.sum(np.isnan(tmpdata[idx[i]]), axis=0))) ** 2
S4[i] = np.sqrt((avgsqdata - avgdatasq) / avgdatasq) S4[i] = np.sqrt((avgsqdata - avgdatasq) / avgdatasq)
else: else:
for i in range(idx.shape[0]): for i in range(idx.shape[0]):
...@@ -184,6 +190,7 @@ def window_stdv_mean(arr, window): ...@@ -184,6 +190,7 @@ def window_stdv_mean(arr, window):
c2 = uniform_filter1d(arr * arr, window, mode='constant', axis=0) c2 = uniform_filter1d(arr * arr, window, mode='constant', axis=0)
return (np.abs(c2 - c1 * c1) ** .5)[window // 2:-window // 2 + 1], c1[window // 2:-window // 2 + 1] return (np.abs(c2 - c1 * c1) ** .5)[window // 2:-window // 2 + 1], c1[window // 2:-window // 2 + 1]
def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False): def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis '''Calculate S4 value for data along axis
Input: Input:
...@@ -202,6 +209,91 @@ def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False): ...@@ -202,6 +209,91 @@ def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
S4 = stddata[::skip_sample_time] / avgdata[::skip_sample_time] S4 = stddata[::skip_sample_time] / avgdata[::skip_sample_time]
return np.swapaxes(S4, 0, axis) return np.swapaxes(S4, 0, axis)
from numba import guvectorize
@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
def move_mean(a, window_arr, out):
    """Running mean of ``a`` with window length ``window_arr[0]``.

    The first ``window`` outputs are cumulative means over the samples seen
    so far; from index ``window`` onward each output is the mean of the last
    ``window`` samples, maintained as a rolling sum.
    """
    window = window_arr[0]
    running = 0.0
    seen = 0
    # Warm-up: cumulative mean while the window is still filling.
    for j in range(window):
        running += a[j]
        seen += 1
        out[j] = running / seen
    # Steady state: slide the window one sample per step (seen == window here).
    for j in range(window, len(a)):
        running += a[j] - a[j - window]
        out[j] = running / seen
@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
def move_mean_sq(a, window_arr, out):
    """Running mean of the squared samples of ``a`` with window ``window_arr[0]``.

    Same windowing scheme as ``move_mean``: cumulative mean while the window
    fills, then a rolling-sum mean of the last ``window`` squared samples.
    """
    window_width = window_arr[0]
    asum = 0.0
    count = 0
    for i in range(window_width):
        # a[i] * a[i] instead of numpy.power(a[i], 2): identical result for a
        # scalar square, cheaper inside the compiled kernel.
        asum += a[i] * a[i]
        count += 1
        out[i] = asum / count
    for i in range(window_width, len(a)):
        asum += a[i] * a[i] - a[i - window_width] * a[i - window_width]
        out[i] = asum / count
def window_stdv_mean_numba(arr: np.ndarray, window):
    """Windowed standard deviation and mean along axis 0, numba-accelerated.

    Input:
        arr: 2-D array (time on axis 0); internally transposed so the numba
             kernels vectorize along the last axis
        window: int, window length in samples
    Output:
        (std, mean) arrays, trimmed by window//2 samples at both ends
    """
    # Kernels operate over the last axis, so move the sampled axis there.
    arr = np.swapaxes(arr, 0, 1)
    start_index = window // 2
    mean_arr = move_mean(arr, window)
    # var = E[x^2] - E[x]^2; abs() guards against tiny negative values from
    # float32 round-off before the sqrt.
    std_arr = np.sqrt(np.abs(move_mean_sq(arr, window) - np.power(mean_arr, 2)))
    mean_arr = np.swapaxes(mean_arr, 0, 1)
    std_arr = np.swapaxes(std_arr, 0, 1)
    # NOTE(review): for odd windows this trims one sample fewer at the end than
    # the reference window_stdv_mean ([window//2 : -window//2 + 1]) — confirm
    # which trimming is intended.
    return std_arr[start_index: -start_index + 1], mean_arr[start_index: -start_index + 1]
def getS4Numba(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
    '''Calculate S4 value for data along axis using the numba running-window kernels
    Input:
    data: numpy array with data (maximum resolution but RFI flagged)
    window_size: int: the window to calculate S4, typically 60s and 180s
    skip_sample_time: int: only calculate S4 every skip_sample_time (in samples, typically 1s)
    has_nan: accepted for interface compatibility; this implementation does not
             treat nans specially
    output:
    numpy array with S4 values
    '''
    # Work with the sampled axis first.
    swapped = np.swapaxes(data, 0, axis)
    std, mean = window_stdv_mean_numba(swapped, window_size)
    # S4 = sigma / mu, evaluated every skip_sample_time samples.
    s4 = std[::skip_sample_time] / mean[::skip_sample_time]
    return np.swapaxes(s4, 0, axis)
def getS4Naive(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
    '''Calculate S4 value for data along axis
    Input:
    data: numpy array with data (maximum resolution but RFI flagged)
    window_size: int: the window to calculate S4, typically 60s and 180s
    skip_sample_time: int: only calculate S4 every skip_sample_time (in samples, typically 1s)
    has_nan: boolean, accepted for interface compatibility; not used by this
             implementation
    output:
    numpy array with S4 values
    new times axis array
    '''
    # make sure axis to sample = 0-th axis:
    # float32 copy so the running-std routine gets contiguous writable data.
    tmpdata = np.swapaxes(data, 0, axis).astype(np.float32).copy()
    ntimes, n_freq = tmpdata.shape
    # Output buffers filled in place by naive_running_std (defined elsewhere
    # in this module, not visible here — presumably a per-sample running
    # mean/std over window_size samples; verify against its definition).
    mean_arr = np.zeros((ntimes, n_freq), dtype=np.float32)
    std_arr = np.zeros((ntimes, n_freq), dtype=np.float32)
    naive_running_std(tmpdata, window_size, mean_arr, std_arr)
    # S4 = sigma / mu at every skip_sample_time-th sample.
    S4 = std_arr[::skip_sample_time] / mean_arr[::skip_sample_time]
    return np.swapaxes(S4, 0, axis)
def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False): def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis '''Calculate S4 value for data along axis
Input: Input:
...@@ -227,6 +319,7 @@ def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False): ...@@ -227,6 +319,7 @@ def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
S4 = np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq) S4 = np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq)
return np.swapaxes(S4, 0, axis) return np.swapaxes(S4, 0, axis)
def getS4_medium(data, window_size, skip_sample_time=65, axis=1, has_nan=False): def getS4_medium(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
'''Calculate S4 value for data along axis '''Calculate S4 value for data along axis
Input: Input:
...@@ -284,7 +377,3 @@ def getMedian(data,Nsamples,axis=0, flags=None, bandpass = None, has_nan=False): ...@@ -284,7 +377,3 @@ def getMedian(data,Nsamples,axis=0, flags=None, bandpass = None, has_nan=False):
bandpass = np.average(bandpass[:cutoff].reshape((-1, Nsamples) + avgdata.shape[1:]), axis=1) bandpass = np.average(bandpass[:cutoff].reshape((-1, Nsamples) + avgdata.shape[1:]), axis=1)
return np.swapaxes(avgdata, 0, axis), np.swapaxes(flags, 0, axis), np.swapaxes(bandpass, 0, axis) # swap back return np.swapaxes(avgdata, 0, axis), np.swapaxes(flags, 0, axis), np.swapaxes(bandpass, 0, axis) # swap back
import os
import unittest
from glob import glob
import numpy
import numpy.random
from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
from scintillation.averaging import open_dataset, extract_metadata, decode_str
basepath = os.path.dirname(__file__)
test_datasets = glob(os.path.join(basepath, 'data', '*.h5'))
def is_test_data_present():
    """Return True when at least one HDF5 test dataset was found on disk."""
    return bool(test_datasets)
def get_filtered_data_from_dataset(dataset):
    """Bandpass-correct the first 10000 samples of a dynamic-spectrum dataset.

    Returns (filtered_data, 60-second S4 window in samples, averaging window
    in samples) for use by the S4 timing tests below.
    """
    metadata = extract_metadata(dataset)
    # First metadata key is taken as the dynspec group name — TODO confirm
    # ordering is guaranteed by extract_metadata.
    dynspec, *_ = metadata.keys()
    data_array = dataset[dynspec]['DATA'][:, :, 0]
    frequency = dataset[dynspec]['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
    # Time increment decoded from an HDF5 string attribute; presumably seconds
    # per sample — verify against decode_str.
    time_delta, *_ = decode_str(dataset[dynspec]['COORDINATES']['TIME'].attrs['INCREMENT'])
    # Only the first 10000 time samples are processed to keep the test fast.
    subset = data_array[0:10000, :]
    # Hz -> MHz for the frequency-axis endpoints.
    start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
    frequency_axis = numpy.linspace(start_frequency, end_frequency, data_array.shape[1])
    # Number of samples per second, rounded up.
    averaging_window_in_samples = int(numpy.ceil(1 / time_delta))
    averaging_window_in_seconds = averaging_window_in_samples * time_delta
    sample_rate = int(averaging_window_in_samples * 3. / averaging_window_in_seconds)
    S4_60s_window_in_samples = int(60. / time_delta)
    filtered_data, flags, flux, bandpass = apply_bandpass(subset, frequency_axis,
                                                          freqaxis=1, timeaxis=0, target=metadata[dynspec]["TARGET"],
                                                          sample_rate=sample_rate,
                                                          # sample every 3 seconds
                                                          filter_length=600,  # window size 30 minutes
                                                          rfi_filter_length=averaging_window_in_samples // 2,
                                                          # window size in time to prevent flagging scintillation
                                                          flag_value=8, replace_value=1)
    return filtered_data, S4_60s_window_in_samples, averaging_window_in_samples
class TestS4Generation(unittest.TestCase):
    """Smoke/timing tests: run each S4 implementation ten times on real data."""

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_reference_implementation_succeed(self):
        dataset = open_dataset(test_datasets[0])
        filtered, window_60s, avg_window = get_filtered_data_from_dataset(dataset)
        print(avg_window, window_60s, filtered.shape)
        for _ in range(10):
            # 60-second S4, sampled once per averaging interval.
            _ = getS4(filtered,
                      window_size=window_60s,
                      skip_sample_time=avg_window,
                      axis=0, has_nan=False)

    @unittest.skipUnless(is_test_data_present(), 'missing test data')
    def test_numba_implementation_succeed(self):
        dataset = open_dataset(test_datasets[0])
        filtered, window_60s, avg_window = get_filtered_data_from_dataset(dataset)
        print(avg_window, window_60s, filtered.shape)
        for _ in range(10):
            # Same parameters as the reference run, numba implementation.
            _ = getS4Numba(filtered,
                           window_size=window_60s,
                           skip_sample_time=avg_window,
                           axis=0, has_nan=False)
if __name__ == '__main__':
    # Allow running this test module directly with `python <file>`.
    unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment