Skip to content
Snippets Groups Projects
Commit b3e005c7 authored by Wouter Klijn's avatar Wouter Klijn
Browse files

Task #3139: Add of unittest for 70% of the code

parent 459cc15a
No related branches found
No related tags found
No related merge requests found
import numpy import numpy
import cmath
# Untested copu pasta of jon swinbanks code # Untested copu pasta of jon swinbanks code
class ComplexArray(object): class ComplexArray(object):
......
...@@ -9,17 +9,26 @@ import os ...@@ -9,17 +9,26 @@ import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
import numpy
from lofarpipe.support.lofarnode import LOFARnodeTCP from lofarpipe.support.lofarnode import LOFARnodeTCP
from lofarpipe.support.pipelinelogging import CatchLog4CPlus from lofarpipe.support.pipelinelogging import CatchLog4CPlus
from lofarpipe.support.pipelinelogging import log_time from lofarpipe.support.pipelinelogging import log_time
from lofarpipe.support.utilities import read_initscript, create_directory from lofarpipe.support.utilities import read_initscript, create_directory
from lofarpipe.support.utilities import catch_segfaults from lofarpipe.support.utilities import catch_segfaults
from lofarpipe.support.lofarexceptions import PipelineRecipeFailed
from lofarpipe.recipes.helpers.WritableParmDB import WritableParmDB, list_stations
from lofarpipe.recipes.helpers.ComplexArray import ComplexArray, RealImagArray, AmplPhaseArray
class ParmExportCal(LOFARnodeTCP): class ParmExportCal(LOFARnodeTCP):
<<<<<<< .mine
def run(self, infile, outfile, executable, initscript, sigma):
=======
def run(self, infile, outfile, executable, initscript): def run(self, infile, outfile, executable, initscript):
>>>>>>> .r21050
# Time execution of this job # Time execution of this job
with log_time(self.logger): with log_time(self.logger):
if os.path.exists(infile): if os.path.exists(infile):
...@@ -60,7 +69,142 @@ class ParmExportCal(LOFARnodeTCP): ...@@ -60,7 +69,142 @@ class ParmExportCal(LOFARnodeTCP):
finally: finally:
shutil.rmtree(temp_dir) shutil.rmtree(temp_dir)
<<<<<<< .mine
#From here new parmdb implementation!!
self._filter_stations_parmdb(infile, outfile)
=======
return 1 #return 1 to allow rerunning of this script return 1 #return 1 to allow rerunning of this script
>>>>>>> .r21050
return 1
def _filter_stations_parmdb(self, infile, outfile):
# Create copy of the input file
# delete target location
shutil.rmtree(outfile)
self.logger.debug("cleared target path for filtered parmdb: \n {0}".format(
outfile))
# copy
shutil.copytree(infile, outfile)
self.logger.debug("Copied raw parmdb to target locations: \n {0}".format(
infile))
# Create a local WritableParmDB
parmdb = WritetableParmdb(outfile)
#get all stations in the parmdb
stations = list_stations(parmdb)
for station in stations:
self.logger.debug("Processing station {0}".format(station))
# till here implemented
polarization_data, type_pair = \
self._read_polarisation_data_and_type_from_db(parmdb, station)
corected_data = self._swap_outliers_with_median(polarization_data,
type_pair, sigma)
self._write_corrected_data(parmdb, station,
polarization_data, corected_data)
def _read_polarisation_data_and_type_from_db(self, parmdb, station):
all_matching_names = parmdb.getNames("Gain:*:*:*:{0}".format(station))
# get the polarisation_data eg: 1:1
# This is based on the 1 trough 3th entry in the parmdb name entry
pols = set(":".join(x[1:3]) for x in (x.split(":") for x in names))
# Get the im or re name, eg: real. Sort for we need a known order
type_pair = sorted([x[3] for x in (x.split(":") for x in names)])
#Check if the retrieved types are valid
sorted_valid_type_pairs = [sorted(RealImagArray.keys),
sorted(AmplPhaseArray.keys)]
if not type_pair in sorted_valid_type_pairs:
self.logger.error("The parsed parmdb contained an invalid array_type:")
self.logger.error("{0}".format(type_pair))
self.logger.error("valid data pairs are: {0}".format(
sorted_valid_type_pairs))
raise PipelineRecipeFailed(
"Invalid data type retrieved from parmdb: {0}".format(
type_pair))
polarisation_data = dict()
#for all polarisation_data in the parmdb (2 times 2)
for polarization in pols:
data = []
#for the two types
for key in type_pair:
query = "Gain:{0}:{1}:{2}".format(polarization, key, station)
#append the retrieved data (resulting in dict to arrays
data.append(parmdb.getValuesGrid(query)[query])
polarisation_data[polarization] = data
#return the raw data and the type of the data
return polarisation_data, type_pair
def _swap_outliers_with_median(self, polarization_data, type_pair, sigma):
corrected_polarization_data = dict()
for pol, data in polarization_data.iteritems():
# Convert the raw data to the correct complex array type
complex_array = self._convert_data_to_ComplexArray(data, type_pair)
# get the data as amplitude from the amplitude array, skip last entry
amplitudes = complex_array.amp[:-1]
# calculate the statistics
median = numpy.median(amplitudes)
stddev = numpy.std(amplitudes)
# Swap outliers with median version of the data
corrected = numpy.where(
numpy.abs(amplitudes - median) > sigma * stddev,
median,
amplitudes
)
# assign the corect data back to the complex_array
complex_array.amp = numpy.concatenate((corrected, complex_array.amp[-1:]))
# collect all corrected data
corrected_polarization_data[pol] = complex_array
return corrected_polarization_data
def _convert_data_to_ComplexArray(self, data, type_pair):
if sorted(type_pair) == sorted(RealImagArray.keys):
# The type_pair is in alphabetical order: Imag on index 0
complex_array = RealImagArray(data[1]["values"], data[0]["values"])
elif sorted(type_pair) == sorted(AmplPhaseArray.keys):
complex_array = AmplPhaseArray(data[0]["values"], data[1]["values"])
else:
self.logger.error("Incorrect data type pair provided: {0}".format(
type_pair))
raise PipelineRecipeFailed(
"Invalid data type retrieved from parmdb: {0}".format(type_pair))
return complex_array
def _write_corrected_data(self, parmdb, station, polarization_data,
corected_data):
for pol, data in polarization_data.iteritems():
if not pol in corected_data:
error_message = "Requested polarisation type is unknown:" \
"{0} \n valid polarisations: {1}".format(pol, corected_data.keys())
self.logger.error(error_message)
raise PipelineRecipeFailed(error_message)
corrected_data = corected_data[pol]
#get the "complex" converted data from the complex array
for component, value in corrected_data.writeable.iteritems():
#Collect all the data needed to write an array
name = "Gain:{0}:{1}:{2}".format(pol, component, station)
freqscale = data[0]['freqs'][0]
freqstep = data[0]['freqwidths'][0]
timescale = data[0]['times'][0]
timestep = data[0]['timewidths'][0]
#call the write function on the parmdb
parmdb.setValues(name, value, freqscale, freqstep, timescale,
timestep)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -4,11 +4,15 @@ import unittest ...@@ -4,11 +4,15 @@ import unittest
import tempfile import tempfile
import sys import sys
import shutil import shutil
import numpy
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from lofarpipe.support.utilities import create_directory #@UnresolvedImport from lofarpipe.support.utilities import create_directory #@UnresolvedImport
from lofarpipe.support.lofarexceptions import PipelineRecipeFailed
from lofarpipe.recipes.nodes.parmexportcal import ParmExportCal from lofarpipe.recipes.nodes.parmexportcal import ParmExportCal
from lofarpipe.recipes.helpers.ComplexArray import ComplexArray, RealImagArray, AmplPhaseArray
from lofarpipe.recipes.helpers.WritableParmDB import WritableParmDB
#import from fixtures: #import from fixtures:
from logger import logger from logger import logger
...@@ -16,10 +20,9 @@ class ParmExportCalWrapper(ParmExportCal): ...@@ -16,10 +20,9 @@ class ParmExportCalWrapper(ParmExportCal):
""" """
The test wrapper allows overwriting of function with muck functionality The test wrapper allows overwriting of function with muck functionality
""" """
def __init__(self, name): def __init__(self):
""" """
""" """
super(ParmExportCalWrapper, self).__init__(name)
self.logger = logger() self.logger = logger()
class ParmExportCalTest(unittest.TestCase): class ParmExportCalTest(unittest.TestCase):
...@@ -29,16 +32,156 @@ class ParmExportCalTest(unittest.TestCase): ...@@ -29,16 +32,156 @@ class ParmExportCalTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.tempDir = tempfile.mkdtemp() self.tempDir = tempfile.mkdtemp()
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tempDir) shutil.rmtree(self.tempDir)
def test_convert_data_to_ComplexArray_real_imag(self): def test_convert_data_to_ComplexArray_real_imag(self):
data = [[0], [1]] data = [{"values": [1]}, {"values": [1]}]
type_pair = ["Image", "Real"] # Order is alphabetical type_pair = ["Imag", "Real"] # Order is alphabetical
parmExportCal = ParmExportCalWrapper() parmExportCal = ParmExportCalWrapper()
complex_array = parmExportCal._convert_data_to_ComplexArray(data, type_pair) complex_array = parmExportCal._convert_data_to_ComplexArray(data, type_pair)
list_of_names = list_stations("test") goal_array = RealImagArray([1], [1])
goal_set = ["name1", "name2", "name3", "name4", "station1"] self.assertTrue(complex_array.real == goal_array.real)
self.assertTrue(list_of_names == goal_set, "{0} != {1}".format( self.assertTrue(complex_array.imag == goal_array.imag)
list_of_names, goal_set))
def test_convert_data_to_ComplexArray_amp_phase(self):
data = [{"values": [1]}, {"values": [1]}]
type_pair = ["Ampl", "Phase"] # Order is alphabetical
parmExportCal = ParmExportCalWrapper()
complex_array = parmExportCal._convert_data_to_ComplexArray(data, type_pair)
goal_array = AmplPhaseArray([1], [1])
self.assertTrue(complex_array.amp == goal_array.amp)
self.assertTrue(complex_array.phase == goal_array.phase)
def test_convert_data_to_ComplexArray_incorrect_pair(self):
data = [{"values": [1]}, {"values": [1]}]
type_pair = ["spam", "spam"] # Order is alphabetical
parmExportCal = ParmExportCalWrapper()
self.assertRaises(PipelineRecipeFailed,
parmExportCal._convert_data_to_ComplexArray,
data, type_pair)
def test_write_corrected_data(self):
# define input data
name = "test"
station = "station"
parmExportCal = ParmExportCalWrapper()
input_polarization_data = {"pol1":[{'freqs':[11],
'freqwidths':[12],
'times':[13],
'timewidths':[14]}]}
input_corected_data = {"pol1":RealImagArray([[1], [1]], [[2], [2]]),
"pol22":RealImagArray([[3], [3]], [[4], [4]])}
# This object will be taken from the fixture: it is a recorder muck
parmdb = WritableParmDB("parmdb")
# call function
parmExportCal = ParmExportCalWrapper()
parmExportCal._write_corrected_data(parmdb, station,
input_polarization_data, input_corected_data)
# test output: (the calls to parmdb)
# there is one polarization, containing a single complex array
# when writing this should result in, 1 times 2 function calls
# first delete the REAL entry
expected = ['deleteValues', ['Gain:pol1:Real:station']]
self.assertTrue(parmdb.called_functions_and_parameters[0] == expected,
"result({0}) != expected({1})".format(
parmdb.called_functions_and_parameters[0], expected))
# then the new values should be added, with the correct values
expected = ['addValues', ['Gain:pol1:Real:station',
numpy.array([[1.], [1.]],),
11, 11 + 12, 13, 13 + 2 * 14, False]] #stat + steps*size
# Now scan the argument array: for numpy use special compare function
for left, right in zip(parmdb.called_functions_and_parameters[1][1],
expected[1]):
error_message = "\nresult({0}) != \nexpected({1}) \n"\
"-> {2} != {3}".format(
parmdb.called_functions_and_parameters[1], expected,
left, right)
try:
if not left == right:
self.assertTrue(False, error_message)
except ValueError:
if not numpy.array_equal(left, right):
self.assertTrue(False, error_message)
# now delete the imag entry: Rememder these are on the 2nd and 3th array
# position
expected = ['deleteValues', ['Gain:pol1:Imag:station']]
self.assertTrue(parmdb.called_functions_and_parameters[2] == expected,
"result({0}) != expected({1})".format(
parmdb.called_functions_and_parameters[2], expected))
# then the new values should be added, with the correct values
expected = ['addValues', ['Gain:pol1:Imag:station',
numpy.array([[2.], [2.]],),
11, 11 + 12, 13, 13 + 2 * 14, False]] #stat + steps*size
# Now scan the argument array: for numpy use special compare function
for left, right in zip(parmdb.called_functions_and_parameters[3][1],
expected[1]):
error_message = "\nresult({0}) != \nexpected({1}) \n"\
"-> {2} != {3}".format(
parmdb.called_functions_and_parameters[3], expected,
left, right)
try:
if not left == right:
self.assertTrue(False, error_message)
except ValueError:
if not numpy.array_equal(left, right):
self.assertTrue(False, error_message)
def test_write_corrected_data_does_not_contain_pol(self):
name = "test"
station = "station"
parmExportCal = ParmExportCalWrapper()
input_polarization_data = {"unknownPolarisation":[{'freqs':[11],
'freqwidths':[12],
'times':[13],
'timewidths':[14]}]}
input_corected_data = {"pol1":RealImagArray([[1], [1]], [[2], [2]]),
"pol2":RealImagArray([[3], [3]], [[4], [4]])}
# This object will be taken from the fixture: it is a recorder muck
parmdb = WritableParmDB("parmdb")
# call function
parmExportCal = ParmExportCalWrapper()
self.assertRaises(PipelineRecipeFailed,
parmExportCal._write_corrected_data,
parmdb, station,
input_polarization_data, input_corected_data)
def test_swap_outliers_with_median(self):
data = {"pol1":[{"values": [1., 1., 1., 1., 100., 100.]},
{"values": [1., 1., 1., 1., 100., 100.]}]
}
type_pair = ["Imag", "Real"] # Order is alphabetical
# omit the last entry do swap the 5th entry with the median (1)
goal_filtered_array = numpy.array([1., 1., 1., 1., 1., 100.])
parmExportCal = ParmExportCalWrapper()
corrected_polarisation = \
parmExportCal._swap_outliers_with_median(data, type_pair, 2.0)
#incredibly rough and incorrect float comparison of the values in the
for left, right in zip(corrected_polarisation['pol1'].real, goal_filtered_array):
message = "Comparison of float values in the array did not" \
"result in about the same value: {0}"
if not int(left) == int(right):
self.assertTrue(False, message.format(
"int value not the same: "
"{0} != {1}".format(int(left), int(right))))
precision = 1000
if not int(left * precision) == int(right * precision):
self.assertTrue(False, message.format(
"value not the same within current precision: "
"{0} != {1}".format(int(left * precision), int(right * precision))))
import re
import copy
class parmdb(object):
"""
Much parmdb interface:
Allows basic checking of called function and parameters
"""
def __init__ (self, dbname, create = True, names = None):
self._basename = dbname
if not names == None:
self.names = names
else:
self.names = ["1:1:Real:name1",
"1:1:Real:name2",
"1:1:Real:name3",
"1:1:Real:name4",
"Gain:1:2:Real:station1"]
self.called_functions_and_parameters = []
def getNames(self, parmnamepattern = ''):
if parmnamepattern == '':
return self.names
#convert the pattern to regexp
listed_pattern = parmnamepattern.split("*")
parmnamepattern = ".*{0}.*".format(".*".join(listed_pattern))
#create regexp matcher!
prog = re.compile(parmnamepattern)
# only return regexp matches the pattern in the string
return [s for s in self.names if prog.match(s)]
def deleteValues(self, *args):
self.called_functions_and_parameters.append(['deleteValues',
[arg for arg in args]])
def addValues(self, *args):
self.called_functions_and_parameters.append(
['addValues', [arg for arg in args]])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment