diff --git a/requirements.txt b/requirements.txt
index 6beb006c14cee3c3a4b845f1138324bf70b39168..d6929c0011d179e6f215668f6dde60fb64c8f938 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,6 @@
 h5py
 numpy
 astropy
-matplotlib
\ No newline at end of file
+matplotlib
+numba
+scipy
\ No newline at end of file
diff --git a/scintillation/Calibrationlib.py b/scintillation/Calibrationlib.py
index 0c9974047edc7d48046baa9b8b5f9aef776c2cae..1e55130fe161f551c064773acd5ce1d75dd07990 100644
--- a/scintillation/Calibrationlib.py
+++ b/scintillation/Calibrationlib.py
@@ -1,8 +1,9 @@
 import numpy as np
-from scipy.ndimage import median_filter,gaussian_filter1d
-from scipy.ndimage.filters import uniform_filter1d
-from scipy.interpolate import interp1d
 from astropy.convolution import interpolate_replace_nans
+from numba import guvectorize
+from scipy.interpolate import interp1d
+from scipy.ndimage import median_filter
+from scipy.ndimage.filters import uniform_filter1d
 
 
 def model_flux(calibrator, frequency):
@@ -17,48 +18,49 @@ def model_flux(calibrator, frequency):
     '''
     parameters = []
 
-    Cal_dict = {'J0133-3629':[1.0440,-0.662,-0.225],
-                '3C48': [1.3253,-0.7553,-0.1914,0.0498],
-                'For A': [2.218,-0.661],
-                'ForA': [2.218,-0.661],
-                '3C123':[1.8017,-0.7884,-0.1035,-0.0248,0.0090],
-                'J0444-2809':[0.9710,-0.894,-0.118],
-                '3C138':[1.0088,-0.4981,-0.155,-0.010,0.022,],
-                'Pic A':[1.9380,-0.7470,-0.074],
-                'Tau A':[2.9516,-0.217,-0.047,-0.067],
-                'PicA':[1.9380,-0.7470,-0.074],
-                'TauA':[2.9516,-0.217,-0.047,-0.067],
-                '3C147':[1.4516,-0.6961,-0.201,0.064,-0.046,0.029],
-                '3C196':[1.2872,-0.8530,-0.153,-0.0200,0.0201],
-                'Hyd A':[1.7795,-0.9176,-0.084,-0.0139,0.030],
-                'Vir A':[2.4466,-0.8116,-0.048],
-                'HydA':[1.7795,-0.9176,-0.084,-0.0139,0.030],
-                'VirA':[2.4466,-0.8116,-0.048],
-                '3C286':[1.2481 ,-0.4507 ,-0.1798 ,0.0357 ],
-                '3C295':[1.4701,-0.7658,-0.2780,-0.0347,0.0399],
-                'Her A':[1.8298,-1.0247,-0.0951],
-                'HerA':[1.8298,-1.0247,-0.0951],
-                '3C353':[1.8627,-0.6938,-0.100,-0.032],
-                '3C380':[1.2320,-0.791,0.095,0.098,-0.18,-0.16],
-                'Cyg A':[3.3498,-1.0022,-0.225,0.023,0.043],
-                'CygA':[3.3498,-1.0022,-0.225,0.023,0.043],
-                '3C444':[3.3498,-1.0022,-0.22,0.023,0.043],
-                'Cas A':[3.3584,-0.7518,-0.035,-0.071],
-                'CasA':[3.3584,-0.7518,-0.035,-0.071]}
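+    # each entry holds log-polynomial coefficients: log10(S/Jy) = sum_j p_j * log10(freq/GHz)**j
+    # (the values appear to follow the Perley & Butler 2017 flux scale; reference assumed)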
+    Cal_dict = {'J0133-3629': [1.0440, -0.662, -0.225],
+                '3C48': [1.3253, -0.7553, -0.1914, 0.0498],
+                'For A': [2.218, -0.661],
+                'ForA': [2.218, -0.661],
+                '3C123': [1.8017, -0.7884, -0.1035, -0.0248, 0.0090],
+                'J0444-2809': [0.9710, -0.894, -0.118],
+                '3C138': [1.0088, -0.4981, -0.155, -0.010, 0.022],
+                'Pic A': [1.9380, -0.7470, -0.074],
+                'Tau A': [2.9516, -0.217, -0.047, -0.067],
+                'PicA': [1.9380, -0.7470, -0.074],
+                'TauA': [2.9516, -0.217, -0.047, -0.067],
+                '3C147': [1.4516, -0.6961, -0.201, 0.064, -0.046, 0.029],
+                '3C196': [1.2872, -0.8530, -0.153, -0.0200, 0.0201],
+                'Hyd A': [1.7795, -0.9176, -0.084, -0.0139, 0.030],
+                'Vir A': [2.4466, -0.8116, -0.048],
+                'HydA': [1.7795, -0.9176, -0.084, -0.0139, 0.030],
+                'VirA': [2.4466, -0.8116, -0.048],
+                '3C286': [1.2481, -0.4507, -0.1798, 0.0357],
+                '3C295': [1.4701, -0.7658, -0.2780, -0.0347, 0.0399],
+                'Her A': [1.8298, -1.0247, -0.0951],
+                'HerA': [1.8298, -1.0247, -0.0951],
+                '3C353': [1.8627, -0.6938, -0.100, -0.032],
+                '3C380': [1.2320, -0.791, 0.095, 0.098, -0.18, -0.16],
+                'Cyg A': [3.3498, -1.0022, -0.225, 0.023, 0.043],
+                'CygA': [3.3498, -1.0022, -0.225, 0.023, 0.043],
+                '3C444': [3.3498, -1.0022, -0.22, 0.023, 0.043],
+                'Cas A': [3.3584, -0.7518, -0.035, -0.071],
+                'CasA': [3.3584, -0.7518, -0.035, -0.071]}
     if calibrator in Cal_dict.keys():
         parameters = Cal_dict[calibrator]
     else:
-        parameters = [1.,0.]
-        #raise ValueError(calibrator, "is not in the calibrators list")
-        
+        parameters = [1., 0.]
+        # raise ValueError(calibrator, "is not in the calibrators list")
+
     flux_model = 0
-    freqs = frequency*1.e-3# convert from MHz to GHz
-    for j,p in enumerate(parameters):
-        flux_model += p*np.log10(freqs)**j
-    flux_model = 10**flux_model # because at first the flux is in log10
+    freqs = frequency * 1.e-3  # convert from MHz to GHz
+    for j, p in enumerate(parameters):
+        flux_model += p * np.log10(freqs) ** j
+    flux_model = 10 ** flux_model  # because at first the flux is in log10
     return flux_model
 
-def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
+
+def filter_rfi(data, nsigma=5, ntpts=50, nchans=10):
     '''
     RFI mitigation strategy for dynamic spectra:
     - median filter the data, in frequency dimension only
@@ -73,28 +75,28 @@ def filter_rfi(data,nsigma=5,ntpts=50,nchans=10):
     Output: An array with flags
     '''
 
+    ntimepts, nfreqs = data.shape
 
-    ntimepts,nfreqs = data.shape
-
-    #flatten = median_filter(data,(ntpts,1),mode='constant',cval=1)
-    #faster to median filter every Nsamples and interpolate
-    cutoff = ntimepts - ntimepts%ntpts
-    medfilter = np.median(data[:cutoff].reshape((-1,ntpts,nfreqs)),axis=1)
+    # flatten = median_filter(data,(ntpts,1),mode='constant',cval=1)
+    # faster to median filter every Nsamples and interpolate
+    cutoff = ntimepts - ntimepts % ntpts
+    medfilter = np.median(data[:cutoff].reshape((-1, ntpts, nfreqs)), axis=1)
     xnew = np.arange(ntimepts)
-    x = xnew [:cutoff][::ntpts]
-    f = interp1d(x,medfilter,axis=0,fill_value='extrapolate')  #interpolate to all skipped samples
-    flatten =  f(xnew)
-    flatten = median_filter(flatten,(1,nchans),mode='constant',cval=1)
-    flatdata = data/flatten
-    nanmed =  np.nanmedian(flatdata)
+    x = xnew[:cutoff][::ntpts]
+    f = interp1d(x, medfilter, axis=0, fill_value='extrapolate')  # interpolate to all skipped samples
+    flatten = f(xnew)
+    flatten = median_filter(flatten, (1, nchans), mode='constant', cval=1)
+    flatdata = data / flatten
+    nanmed = np.nanmedian(flatdata)
     diff = abs(flatdata - nanmed)
     sd = np.nanmedian(diff)  # Median absolute deviation
-    maskdata = np.where(diff>sd*nsigma,1,0)
-        
+    maskdata = np.where(diff > sd * nsigma, 1, 0)
+
     return maskdata
 
 
-def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample_rate= 180, filter_length = 300, rfi_filter_length =10, flag_value=10, replace_value=1):
+def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target="Cas A", sample_rate=180, filter_length=300,
+                   rfi_filter_length=10, flag_value=10, replace_value=1):
     '''apply a bandpass correction to the data, scaling the data to the nominal flux of the calibrator, also flag outliers and replace flagged data with replace_value
     Input:
     data: the numpy array with the data
@@ -113,33 +115,35 @@ def apply_bandpass(data, freqs, freqaxis=0, timeaxis=1, target = "Cas A", sample
     array of flux values
     
     '''
-    
+
     flux = model_flux(target, freqs)
-    data = np.swapaxes(data,0,timeaxis)
+    data = np.swapaxes(data, 0, timeaxis)
 
-    #medfilter = gaussian_filter1d(data[::sample_rate], axis=0 , sigma=filter_length) #or use median_filter?
-    medfilter = median_filter(data[::sample_rate], size=(filter_length,)+(1,)*(len(data.shape)-1)) #or use gaussian_filter?
+    # medfilter = gaussian_filter1d(data[::sample_rate], axis=0 , sigma=filter_length) #or use median_filter?
+    medfilter = median_filter(data[::sample_rate],
+                              size=(filter_length,) + (1,) * (len(data.shape) - 1))  # or use gaussian_filter?
     xnew = np.arange(data.shape[0])
-    x = xnew [::sample_rate]
-    f = interp1d(x,medfilter,axis=0,fill_value='extrapolate')  #interpolate to all skipped samples
+    x = xnew[::sample_rate]
+    f = interp1d(x, medfilter, axis=0, fill_value='extrapolate')  # interpolate to all skipped samples
     bandpass = f(xnew)
-    newshape = [1,]*len(data.shape)
+    newshape = [1, ] * len(data.shape)
     if freqaxis == 0:
         newshape[timeaxis] = flux.shape[0]
     else:
         newshape[freqaxis] = flux.shape[0]
-    bandpass[bandpass==0]=1
-    data/=bandpass
-    flags = filter_rfi(data,nsigma=flag_value,ntpts=rfi_filter_length,nchans= 10)
+    bandpass[bandpass == 0] = 1
+    data /= bandpass
+    flags = filter_rfi(data, nsigma=flag_value, ntpts=rfi_filter_length, nchans=10)
     tmpdata = np.copy(data)
-    if replace_value=="interpolate":
-            tmpdata [ flags>0] = np.nan
-            tmpdata = interpolate_replace_nans(tmpdata,np.ones((21,21)))
+    if replace_value == "interpolate":
+        tmpdata[flags > 0] = np.nan
+        tmpdata = interpolate_replace_nans(tmpdata, np.ones((21, 21)))
     else:
-        
-       tmpdata [flags>0] = replace_value
-
-    return np.swapaxes(tmpdata,0,timeaxis),np.swapaxes(flags,0,timeaxis),flux,np.swapaxes(bandpass,0,timeaxis)
+        tmpdata[flags > 0] = replace_value
+
+    return np.swapaxes(tmpdata, 0, timeaxis), np.swapaxes(flags, 0, timeaxis), flux, np.swapaxes(bandpass, 0, timeaxis)
+
 
 def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     '''Calculate S4 value for data along axis, Could be fast if it works with numba
@@ -154,35 +158,38 @@ def getS4_fast(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     new times axis array
     '''
     # make sure axis to sample = 0-th axis:
-    tmpdata = np.swapaxes(data,0,axis)
+    tmpdata = np.swapaxes(data, 0, axis)
     ntimes = tmpdata.shape[0]
     indices = np.arange(window_size)
 
-    idx_step = np.arange(0,ntimes-window_size,skip_sample_time)
-    
-    idx = idx_step.reshape((-1,1)) + indices.reshape((1,-1))
-    S4=np.zeros((idx.shape[0],)+tmpdata.shape[1:],dtype=tmpdata.dtype)
+    idx_step = np.arange(0, ntimes - window_size, skip_sample_time)
+
+    idx = idx_step.reshape((-1, 1)) + indices.reshape((1, -1))
+    S4 = np.zeros((idx.shape[0],) + tmpdata.shape[1:], dtype=tmpdata.dtype)
     if has_nan:
         for i in range(idx.shape[0]):
-            for j in range(window_size):
-                avgsqdata = np.sum(tmpdata[idx[i]]**2,axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0))
-                avgdatasq = (np.sum(tmpdata[idx[i]],axis=0)/(window_size - np.sum(np.isnan(tmpdata[idx[i]]),axis=0)))**2
-            S4[i] = np.sqrt((avgsqdata-avgdatasq)/avgdatasq)
+            # np.nansum ignores flagged (NaN) samples; dividing by the
+            # per-channel count of valid samples keeps the means unbiased
+            nvalid = window_size - np.sum(np.isnan(tmpdata[idx[i]]), axis=0)
+            avgsqdata = np.nansum(tmpdata[idx[i]] ** 2, axis=0) / nvalid
+            avgdatasq = (np.nansum(tmpdata[idx[i]], axis=0) / nvalid) ** 2
+            S4[i] = np.sqrt((avgsqdata - avgdatasq) / avgdatasq)
     else:
         for i in range(idx.shape[0]):
-            avgsqdata = np.sum(tmpdata[idx[i]]**2,axis=0)/window_size
-            avgdatasq = np.sum(tmpdata[idx[i]],axis=0)**2/window_size**2
-            S4[i] = np.sqrt((avgsqdata-avgdatasq)/avgdatasq)
-    
-    return np.swapaxes(S4,0,axis)
+            avgsqdata = np.sum(tmpdata[idx[i]] ** 2, axis=0) / window_size
+            avgdatasq = np.sum(tmpdata[idx[i]], axis=0) ** 2 / window_size ** 2
+            S4[i] = np.sqrt((avgsqdata - avgdatasq) / avgdatasq)
+
+    return np.swapaxes(S4, 0, axis)
 
 
 def window_stdv_mean(arr, window):
-    #Cool trick: you can compute the standard deviation 
-    #given just the sum of squared values and the sum of values in the window.   
-    c1 = uniform_filter1d(arr, window, mode='constant',axis=0)
-    c2 = uniform_filter1d(arr*arr, window, mode='constant',axis=0)
-    return (np.abs(c2 - c1*c1)**.5)[window//2:-window//2+1],c1[window//2:-window//2+1]
+    # Cool trick: you can compute the standard deviation
+    # given just the sum of squared values and the sum of values in the window.
+    c1 = uniform_filter1d(arr, window, mode='constant', axis=0)
+    c2 = uniform_filter1d(arr * arr, window, mode='constant', axis=0)
+    return (np.abs(c2 - c1 * c1) ** .5)[window // 2:-window // 2 + 1], c1[window // 2:-window // 2 + 1]
+
 
 def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     '''Calculate S4 value for data along axis
@@ -197,10 +204,95 @@ def getS4(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     new times axis array
     '''
     # make sure axis to sample = 0-th axis:
-    tmpdata = np.swapaxes(data,0,axis)
-    stddata,avgdata = window_stdv_mean(tmpdata,window_size)
-    S4 = stddata[::skip_sample_time]/avgdata[::skip_sample_time]
-    return np.swapaxes(S4,0,axis)
+    tmpdata = np.swapaxes(data, 0, axis)
+    stddata, avgdata = window_stdv_mean(tmpdata, window_size)
+    S4 = stddata[::skip_sample_time] / avgdata[::skip_sample_time]
+    return np.swapaxes(S4, 0, axis)
+
+
+@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
+def move_mean(a, window_arr, out):
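+    # trailing moving average via a running sum: the first loop warms up on
+    # partial windows, the second slides the full window one sample at a time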
+    window_width = window_arr[0]
+    asum = 0.0
+    count = 0
+    for i in range(window_width):
+        asum += a[i]
+        count += 1
+        out[i] = asum / count
+    for i in range(window_width, len(a)):
+        asum += a[i] - a[i - window_width]
+        out[i] = asum / count
+
+
+@guvectorize(['void(float32[:], intp[:], float32[:])'], '(n),() -> (n)')
+def move_mean_sq(a, window_arr, out):
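+    # the same running-sum trick applied to squared samples; together with
+    # move_mean this gives the windowed variance as E[x^2] - E[x]^2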
+    window_width = window_arr[0]
+    asum = 0.0
+    count = 0
+    for i in range(window_width):
+        asum += a[i] * a[i]
+        count += 1
+        out[i] = asum / count
+    for i in range(window_width, len(a)):
+        asum += a[i] * a[i] - a[i - window_width] * a[i - window_width]
+        out[i] = asum / count
+
+
+def window_stdv_mean_numba(arr: np.ndarray, window):
+    arr = np.swapaxes(arr, 0, 1)
+    mean_arr = move_mean(arr, window)
+    std_arr = np.sqrt(np.abs(move_mean_sq(arr, window) - mean_arr * mean_arr))
+
+    mean_arr = np.swapaxes(mean_arr, 0, 1)
+    std_arr = np.swapaxes(std_arr, 0, 1)
+
+    # the running windows are trailing, so they are only fully populated from
+    # index window - 1 onwards; slicing there lines the output up
+    # sample-for-sample with window_stdv_mean's centred result
+    return std_arr[window - 1:], mean_arr[window - 1:]
+
+
+def getS4Numba(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
+    '''Calculate S4 value for data along axis
+    Input:
+    data: numpy array with data (maximum resolution but RFI flagged)
+    window_size: int: the window to calculate S4, typically 60s and 180s
+    skip_sample_time: int: only calculate S4 every skip_sample_time (in samples, typically 1s)
+    has_nan: boolean, accepted for API compatibility; NaNs are not handled by this implementation
+    output:
+    numpy array with S4 values
+    new times axis array
+    '''
+    # make sure axis to sample = 0-th axis:
+    tmpdata = np.swapaxes(data, 0, axis).astype(np.float32)  # the kernels above are compiled for float32 only
+    stddata, avgdata = window_stdv_mean_numba(tmpdata, window_size)
+    S4 = stddata[::skip_sample_time] / avgdata[::skip_sample_time]
+    return np.swapaxes(S4, 0, axis)
+
+
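+def naive_running_std(data, window_size, mean_arr, std_arr):
+    # getS4Naive below calls this helper, but its definition was missing from
+    # the patch; this is a minimal sketch that assumes a trailing running
+    # window (partial at the start, matching move_mean), filling mean_arr and
+    # std_arr in place as the preallocated buffers suggest
+    ntimes = data.shape[0]
+    for i in range(ntimes):
+        window = data[max(0, i - window_size + 1):i + 1]
+        mean_arr[i] = window.mean(axis=0)
+        std_arr[i] = window.std(axis=0)
+
+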
+def getS4Naive(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
+    '''Calculate S4 value for data along axis
+    Input:
+    data: numpy array with data (maximum resolution but RFI flagged)
+    window_size: int: the window to calculate S4, typically 60s and 180s
+    skip_sample_time: int: only calculate S4 every skip_sample_time (in samples, typically 1s)
+    has_nan: boolean, accepted for API compatibility; NaNs are not handled by this implementation
+    output:
+    numpy array with S4 values
+    new times axis array
+    '''
+    # make sure axis to sample = 0-th axis:
+    tmpdata = np.swapaxes(data, 0, axis).astype(np.float32).copy()
+    ntimes, n_freq = tmpdata.shape
+    mean_arr = np.zeros((ntimes, n_freq), dtype=np.float32)
+    std_arr = np.zeros((ntimes, n_freq), dtype=np.float32)
+    naive_running_std(tmpdata, window_size, mean_arr, std_arr)
+    S4 = std_arr[::skip_sample_time] / mean_arr[::skip_sample_time]
+    return np.swapaxes(S4, 0, axis)
+
 
 def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     '''Calculate S4 value for data along axis
@@ -215,18 +307,19 @@ def getS4_slow(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     new times axis array
     '''
     # make sure axis to sample = 0-th axis:
-    tmpdata = np.swapaxes(data,0,axis)
+    tmpdata = np.swapaxes(data, 0, axis)
     ntimes = tmpdata.shape[0]
-    slides = sliding_window_view(tmpdata,window_size,axis=0)[::skip_sample_time]
+    slides = sliding_window_view(tmpdata, window_size, axis=0)[::skip_sample_time]
     if has_nan:
-        avgsqdata = np.nanmean(slides**2,axis=-1)
-        avgdatasq = np.nanmean(slides,axis=-1)**2
+        avgsqdata = np.nanmean(slides ** 2, axis=-1)
+        avgdatasq = np.nanmean(slides, axis=-1) ** 2
     else:
-        avgsqdata = np.mean(slides**2,axis=-1)
-        avgdatasq = np.mean(slides,axis=-1)**2
-    S4 = np.sqrt(np.abs(avgsqdata-avgdatasq)/avgdatasq)
-    return np.swapaxes(S4,0,axis)
- 
+        avgsqdata = np.mean(slides ** 2, axis=-1)
+        avgdatasq = np.mean(slides, axis=-1) ** 2
+    S4 = np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq)
+    return np.swapaxes(S4, 0, axis)
+
+
 def getS4_medium(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     '''Calculate S4 value for data along axis
     Input:
@@ -240,51 +333,47 @@ def getS4_medium(data, window_size, skip_sample_time=65, axis=1, has_nan=False):
     new times axis array
     '''
     # make sure axis to sample = 0-th axis:
-    tmpdata = np.swapaxes(data,0,axis)
+    tmpdata = np.swapaxes(data, 0, axis)
     ntimes = tmpdata.shape[0]
     indices = np.arange(window_size)
 
-    idx_step = np.arange(0,ntimes-window_size,skip_sample_time)
-    
-    idx = idx_step[:,np.newaxis]+indices[np.newaxis]
-    
-    S4=[]
+    idx_step = np.arange(0, ntimes - window_size, skip_sample_time)
+
+    idx = idx_step[:, np.newaxis] + indices[np.newaxis]
+
+    S4 = []
     if has_nan:
         for i in range(idx.shape[0]):
-            avgsqdata = np.nanmean(tmpdata[idx[i]]**2,axis=0)
-            avgdatasq = np.nanmean(tmpdata[idx[i]],axis=0)**2
-            S4.append(np.sqrt(np.abs(avgsqdata-avgdatasq)/avgdatasq))
+            avgsqdata = np.nanmean(tmpdata[idx[i]] ** 2, axis=0)
+            avgdatasq = np.nanmean(tmpdata[idx[i]], axis=0) ** 2
+            S4.append(np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq))
     else:
         for i in range(idx.shape[0]):
-            avgsqdata = np.mean(tmpdata[idx[i]]**2,axis=0)
-            avgdatasq = np.mean(tmpdata[idx[i]],axis=0)**2
-            S4.append(np.sqrt(np.abs(avgsqdata-avgdatasq)/avgdatasq))
-    
+            avgsqdata = np.mean(tmpdata[idx[i]] ** 2, axis=0)
+            avgdatasq = np.mean(tmpdata[idx[i]], axis=0) ** 2
+            S4.append(np.sqrt(np.abs(avgsqdata - avgdatasq) / avgdatasq))
+
     S4 = np.array(S4)
-    return np.swapaxes(S4,0,axis)
+    return np.swapaxes(S4, 0, axis)
 
 
-def getMedian(data,Nsamples,axis=0, flags=None, bandpass = None, has_nan=False):
+def getMedian(data, Nsamples, axis=0, flags=None, bandpass=None, has_nan=False):
     '''average the data using median over Nsamples samples, ignore last samples <Nsamples. If has_nan, use the much slower nanmedian function'''
-    #flags is None or has same shape as data
-    #make sure average axis is first
-    tmpdata = np.swapaxes(data,0,axis)
-    cutoff = (tmpdata.shape[0]//Nsamples)*Nsamples
-    
-    tmpdata = tmpdata[:cutoff].reshape((-1,Nsamples)+tmpdata.shape[1:])
+    # flags is None or has same shape as data
+    # make sure average axis is first
+    tmpdata = np.swapaxes(data, 0, axis)
+    cutoff = (tmpdata.shape[0] // Nsamples) * Nsamples
+
+    tmpdata = tmpdata[:cutoff].reshape((-1, Nsamples) + tmpdata.shape[1:])
     if has_nan:
-        avgdata = np.nanmedian(tmpdata,axis=1)
+        avgdata = np.nanmedian(tmpdata, axis=1)
     else:
-        avgdata = np.median(tmpdata,axis=1)
+        avgdata = np.median(tmpdata, axis=1)
     if not flags is None:
-        flags = np.swapaxes(flags,0,axis).astype(float)
-        flags = np.average(flags[:cutoff].reshape((-1,Nsamples)+avgdata.shape[1:]),axis=1)
+        flags = np.swapaxes(flags, 0, axis).astype(float)
+        flags = np.average(flags[:cutoff].reshape((-1, Nsamples) + avgdata.shape[1:]), axis=1)
     if not bandpass is None:
-        bandpass = np.swapaxes(bandpass,0,axis).astype(float)
-        bandpass = np.average(bandpass[:cutoff].reshape((-1,Nsamples)+avgdata.shape[1:]),axis=1)
-        
-    return np.swapaxes(avgdata,0,axis),np.swapaxes(flags,0,axis),np.swapaxes(bandpass,0,axis)  #swap back
-        
-
-
+        bandpass = np.swapaxes(bandpass, 0, axis).astype(float)
+        bandpass = np.average(bandpass[:cutoff].reshape((-1, Nsamples) + avgdata.shape[1:]), axis=1)
 
+    if flags is not None:
+        flags = np.swapaxes(flags, 0, axis)
+    if bandpass is not None:
+        bandpass = np.swapaxes(bandpass, 0, axis)
+    return np.swapaxes(avgdata, 0, axis), flags, bandpass  # swap back; None passes through unchanged
diff --git a/tests/test_s4_computation.py b/tests/test_s4_computation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e8fc2bf0d9adb81f6d88bc8fa609d3c4d2144f
--- /dev/null
+++ b/tests/test_s4_computation.py
@@ -0,0 +1,75 @@
+import os
+import unittest
+from glob import glob
+
+import numpy
+from scintillation.Calibrationlib import apply_bandpass, getS4, getS4Numba
+from scintillation.averaging import open_dataset, extract_metadata, decode_str
+
+basepath = os.path.dirname(__file__)
+test_datasets = glob(os.path.join(basepath, 'data', '*.h5'))
+
+
+def is_test_data_present():
+    return len(test_datasets) > 0
+
+
+def get_filtered_data_from_dataset(dataset):
+    metadata = extract_metadata(dataset)
+    dynspec, *_ = metadata.keys()
+    data_array = dataset[dynspec]['DATA'][:, :, 0]
+    frequency = dataset[dynspec]['COORDINATES']['SPECTRAL'].attrs['AXIS_VALUE_WORLD']
+    time_delta, *_ = decode_str(dataset[dynspec]['COORDINATES']['TIME'].attrs['INCREMENT'])
+
+    subset = data_array[0:10000, :]
+    start_frequency, end_frequency = frequency[0] / 1.e6, frequency[-1] / 1.e6
+
+    frequency_axis = numpy.linspace(start_frequency, end_frequency, data_array.shape[1])
+
+    averaging_window_in_samples = int(numpy.ceil(1 / time_delta))
+    averaging_window_in_seconds = averaging_window_in_samples * time_delta
+    sample_rate = int(averaging_window_in_samples * 3. / averaging_window_in_seconds)
+    S4_60s_window_in_samples = int(60. / time_delta)
+
+    filtered_data, flags, flux, bandpass = apply_bandpass(subset, frequency_axis,
+                                                          freqaxis=1, timeaxis=0, target=metadata[dynspec]["TARGET"],
+                                                          sample_rate=sample_rate,
+                                                          # sample every 3 seconds
+                                                          filter_length=600,  # window size 30 minutes
+                                                          rfi_filter_length=averaging_window_in_samples // 2,
+                                                          # window size in time to prevent flagging scintillation
+                                                          flag_value=8, replace_value=1)
+    return filtered_data, S4_60s_window_in_samples, averaging_window_in_samples
+
+
+class TestS4Generation(unittest.TestCase):
+    @unittest.skipUnless(is_test_data_present(), 'missing test data')
+    def test_reference_implementation_succeed(self):
+        dataset = open_dataset(test_datasets[0])
+        filtered_data, S4_60s_window_in_samples, averaging_window_in_samples = get_filtered_data_from_dataset(dataset)
+        print(averaging_window_in_samples, S4_60s_window_in_samples, filtered_data.shape)
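+        # looping makes runtime differences visible when timing the test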
+        for i in range(10):
+            _ = getS4(filtered_data,
+                      window_size=S4_60s_window_in_samples,  # 60 seconds
+                      skip_sample_time=averaging_window_in_samples,  # create S4 every averaging time
+                      axis=0, has_nan=False)
+
+    @unittest.skipUnless(is_test_data_present(), 'missing test data')
+    def test_numba_implementation_succeed(self):
+        dataset = open_dataset(test_datasets[0])
+        filtered_data, S4_60s_window_in_samples, averaging_window_in_samples = get_filtered_data_from_dataset(dataset)
+        print(averaging_window_in_samples, S4_60s_window_in_samples, filtered_data.shape)
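+        # with explicit @guvectorize signatures numba compiles at import time,
+        # so every iteration below runs at steady-state speed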
+        for i in range(10):
+            _ = getS4Numba(filtered_data,
+                           window_size=S4_60s_window_in_samples,  # 60 seconds
+                           skip_sample_time=averaging_window_in_samples,  # create S4 every averaging time
+                           axis=0, has_nan=False)
+
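+    @unittest.skipUnless(is_test_data_present(), 'missing test data')
+    def test_numba_matches_reference(self):
+        # consistency sketch: with the trailing windows in window_stdv_mean_numba
+        # aligned to the first fully populated sample, both implementations
+        # should agree on the overlapping region; the loose tolerances (an
+        # assumption) absorb float32 accumulation differences
+        dataset = open_dataset(test_datasets[0])
+        filtered_data, window, step = get_filtered_data_from_dataset(dataset)
+        ref = getS4(filtered_data, window_size=window, skip_sample_time=step,
+                    axis=0, has_nan=False)
+        fast = getS4Numba(filtered_data, window_size=window, skip_sample_time=step,
+                          axis=0, has_nan=False)
+        n = min(ref.shape[0], fast.shape[0])
+        numpy.testing.assert_allclose(fast[:n], ref[:n], rtol=5e-3, atol=1e-3)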
+
+
+if __name__ == '__main__':
+    unittest.main()