Skip to content
Snippets Groups Projects
Commit b3e005c7 authored by Wouter Klijn's avatar Wouter Klijn
Browse files

Task #3139: Add unit tests for 70% of the code

parent 459cc15a
Branches
Tags
No related merge requests found
import numpy
import cmath
# Untested copy-paste of Jon Swinbank's code
class ComplexArray(object):
......
......@@ -9,17 +9,26 @@ import os
import shutil
import sys
import tempfile
import numpy
from lofarpipe.support.lofarnode import LOFARnodeTCP
from lofarpipe.support.pipelinelogging import CatchLog4CPlus
from lofarpipe.support.pipelinelogging import log_time
from lofarpipe.support.utilities import read_initscript, create_directory
from lofarpipe.support.utilities import catch_segfaults
from lofarpipe.support.lofarexceptions import PipelineRecipeFailed
from lofarpipe.recipes.helpers.WritableParmDB import WritableParmDB, list_stations
from lofarpipe.recipes.helpers.ComplexArray import ComplexArray, RealImagArray, AmplPhaseArray
class ParmExportCal(LOFARnodeTCP):
<<<<<<< .mine
def run(self, infile, outfile, executable, initscript, sigma):
=======
def run(self, infile, outfile, executable, initscript):
>>>>>>> .r21050
# Time execution of this job
with log_time(self.logger):
if os.path.exists(infile):
......@@ -60,7 +69,142 @@ class ParmExportCal(LOFARnodeTCP):
finally:
shutil.rmtree(temp_dir)
<<<<<<< .mine
#From here new parmdb implementation!!
self._filter_stations_parmdb(infile, outfile)
=======
return 1 #return 1 to allow rerunning of this script
>>>>>>> .r21050
return 1
def _filter_stations_parmdb(self, infile, outfile, sigma=1.0):
    """
    Copy the parmdb at infile to outfile and filter the gain outliers in
    the copy.

    For each station returned by list_stations() the polarisation data is
    read from the copy, outliers are swapped with the median and the
    corrected data is written back to the copy.

    infile  -- path of the raw parmdb, copied with shutil.copytree
    outfile -- path of the filtered parmdb; an existing tree at this
               location is removed first
    sigma   -- threshold handed to _swap_outliers_with_median
               NOTE(review): the default of 1.0 only keeps two-argument
               callers working; confirm the intended value and pass it
               explicitly from run()
    """
    # Delete the target location, copytree needs a non existing target.
    # Only remove when present: rmtree raises on a missing path
    if os.path.exists(outfile):
        shutil.rmtree(outfile)
    self.logger.debug("cleared target path for filtered parmdb: \n {0}".format(
        outfile))

    # Create copy of the input file
    shutil.copytree(infile, outfile)
    self.logger.debug("Copied raw parmdb to target locations: \n {0}".format(
        infile))

    # Create a local WritableParmDB on the copy, the input stays untouched
    parmdb = WritableParmDB(outfile)

    # Process all stations in the parmdb
    for station in list_stations(parmdb):
        self.logger.debug("Processing station {0}".format(station))

        polarization_data, type_pair = \
            self._read_polarisation_data_and_type_from_db(parmdb, station)

        corrected_data = self._swap_outliers_with_median(polarization_data,
                                                         type_pair, sigma)

        self._write_corrected_data(parmdb, station,
                                   polarization_data, corrected_data)
def _read_polarisation_data_and_type_from_db(self, parmdb, station):
    """
    Read the gain data of a single station from the parmdb.

    The parmdb entries are selected with the name pattern
    "Gain:*:*:*:<station>". The second and third name field form the
    polarisation (eg: 1:1), the fourth field is the data type.

    Returns a tuple (polarisation_data, type_pair):
    polarisation_data -- dict mapping each polarisation to a list with
                         the getValuesGrid result of each data type, in
                         the order of type_pair
    type_pair         -- sorted list of the unique data types found

    Raises PipelineRecipeFailed when the found data types do not match
    the keys of RealImagArray or AmplPhaseArray.
    """
    all_matching_names = parmdb.getNames("Gain:*:*:*:{0}".format(station))
    split_names = [name.split(":") for name in all_matching_names]

    # get the polarisation_data eg: 1:1
    # This is based on the 1 through 3rd entry in the parmdb name entry
    pols = set(":".join(fields[1:3]) for fields in split_names)

    # Get the im or re name, eg: real. Each type is found once for every
    # polarisation: use a set to get the unique types. Sort for we need
    # a known order
    type_pair = sorted(set(fields[3] for fields in split_names))

    # Check if the retrieved types are valid
    sorted_valid_type_pairs = [sorted(RealImagArray.keys),
                               sorted(AmplPhaseArray.keys)]

    if type_pair not in sorted_valid_type_pairs:
        self.logger.error("The parsed parmdb contained an invalid array_type:")
        self.logger.error("{0}".format(type_pair))
        self.logger.error("valid data pairs are: {0}".format(
            sorted_valid_type_pairs))
        raise PipelineRecipeFailed(
            "Invalid data type retrieved from parmdb: {0}".format(
                type_pair))

    polarisation_data = dict()
    # for all polarisation_data in the parmdb
    for polarization in pols:
        data = []
        # for the two types
        for key in type_pair:
            query = "Gain:{0}:{1}:{2}".format(polarization, key, station)
            # getValuesGrid returns a mapping keyed on the queried name
            data.append(parmdb.getValuesGrid(query)[query])
        polarisation_data[polarization] = data

    # return the raw data and the type of the data
    return polarisation_data, type_pair
def _swap_outliers_with_median(self, polarization_data, type_pair, sigma):
corrected_polarization_data = dict()
for pol, data in polarization_data.iteritems():
# Convert the raw data to the correct complex array type
complex_array = self._convert_data_to_ComplexArray(data, type_pair)
# get the data as amplitude from the amplitude array, skip last entry
amplitudes = complex_array.amp[:-1]
# calculate the statistics
median = numpy.median(amplitudes)
stddev = numpy.std(amplitudes)
# Swap outliers with median version of the data
corrected = numpy.where(
numpy.abs(amplitudes - median) > sigma * stddev,
median,
amplitudes
)
# assign the corect data back to the complex_array
complex_array.amp = numpy.concatenate((corrected, complex_array.amp[-1:]))
# collect all corrected data
corrected_polarization_data[pol] = complex_array
return corrected_polarization_data
def _convert_data_to_ComplexArray(self, data, type_pair):
    """
    Wrap the raw data in the complex array class matching type_pair.

    data      -- list with two entries, the "values" entry of both is used
    type_pair -- names of the two data types; compared in sorted order
                 with the keys of RealImagArray and AmplPhaseArray

    Returns a RealImagArray or an AmplPhaseArray.
    Raises PipelineRecipeFailed when type_pair matches neither class.
    """
    ordered_pair = sorted(type_pair)

    if ordered_pair == sorted(RealImagArray.keys):
        # The type_pair is in alphabetical order: Imag on index 0
        return RealImagArray(data[1]["values"], data[0]["values"])

    if ordered_pair == sorted(AmplPhaseArray.keys):
        return AmplPhaseArray(data[0]["values"], data[1]["values"])

    # Neither of the known array types: report and fail
    self.logger.error("Incorrect data type pair provided: {0}".format(
        type_pair))
    raise PipelineRecipeFailed(
        "Invalid data type retrieved from parmdb: {0}".format(type_pair))
def _write_corrected_data(self, parmdb, station, polarization_data,
corected_data):
for pol, data in polarization_data.iteritems():
if not pol in corected_data:
error_message = "Requested polarisation type is unknown:" \
"{0} \n valid polarisations: {1}".format(pol, corected_data.keys())
self.logger.error(error_message)
raise PipelineRecipeFailed(error_message)
corrected_data = corected_data[pol]
#get the "complex" converted data from the complex array
for component, value in corrected_data.writeable.iteritems():
#Collect all the data needed to write an array
name = "Gain:{0}:{1}:{2}".format(pol, component, station)
freqscale = data[0]['freqs'][0]
freqstep = data[0]['freqwidths'][0]
timescale = data[0]['times'][0]
timestep = data[0]['timewidths'][0]
#call the write function on the parmdb
parmdb.setValues(name, value, freqscale, freqstep, timescale,
timestep)
if __name__ == "__main__":
......
......@@ -4,11 +4,15 @@ import unittest
import tempfile
import sys
import shutil
import numpy
from argparse import ArgumentTypeError
from lofarpipe.support.utilities import create_directory #@UnresolvedImport
from lofarpipe.support.lofarexceptions import PipelineRecipeFailed
from lofarpipe.recipes.nodes.parmexportcal import ParmExportCal
from lofarpipe.recipes.helpers.ComplexArray import ComplexArray, RealImagArray, AmplPhaseArray
from lofarpipe.recipes.helpers.WritableParmDB import WritableParmDB
#import from fixtures:
from logger import logger
......@@ -16,10 +20,9 @@ class ParmExportCalWrapper(ParmExportCal):
"""
The test wrapper allows overwriting of function with muck functionality
"""
def __init__(self, name):
def __init__(self):
"""
"""
super(ParmExportCalWrapper, self).__init__(name)
self.logger = logger()
class ParmExportCalTest(unittest.TestCase):
......@@ -29,16 +32,156 @@ class ParmExportCalTest(unittest.TestCase):
def setUp(self):
    """
    Create a temporary directory and store its path in self.tempDir
    """
    scratch_dir = tempfile.mkdtemp()
    self.tempDir = scratch_dir
def tearDown(self):
    """
    Remove the directory stored in self.tempDir, including its content
    """
    scratch_dir = self.tempDir
    shutil.rmtree(scratch_dir)
def test_convert_data_to_ComplexArray_real_imag(self):
    """
    An ["Imag", "Real"] type pair should be converted to an array with the
    same real and imag content as a RealImagArray created from the values
    """
    # Stale leftovers removed: assignments that were overwritten before
    # use and a list_stations check, a name not imported in this module
    data = [{"values": [1]}, {"values": [1]}]
    type_pair = ["Imag", "Real"] # Order is alphabetical
    parmExportCal = ParmExportCalWrapper()
    complex_array = parmExportCal._convert_data_to_ComplexArray(data, type_pair)

    goal_array = RealImagArray([1], [1])
    self.assertTrue(complex_array.real == goal_array.real)
    self.assertTrue(complex_array.imag == goal_array.imag)
def test_convert_data_to_ComplexArray_amp_phase(self):
    """
    An ["Ampl", "Phase"] type pair should be converted to an array with
    the same amp and phase content as an AmplPhaseArray created from the
    values
    """
    raw_grids = [{"values": [1]}, {"values": [1]}]
    ampl_phase_pair = ["Ampl", "Phase"] # Order is alphabetical
    node = ParmExportCalWrapper()
    converted = node._convert_data_to_ComplexArray(raw_grids,
                                                   ampl_phase_pair)

    # Reference array created directly from the same values
    reference = AmplPhaseArray([1], [1])
    self.assertTrue(converted.amp == reference.amp)
    self.assertTrue(converted.phase == reference.phase)
def test_convert_data_to_ComplexArray_incorrect_pair(self):
    """
    A type pair that does not match a known array type should result in
    a PipelineRecipeFailed
    """
    raw_grids = [{"values": [1]}, {"values": [1]}]
    unknown_pair = ["spam", "spam"]
    node = ParmExportCalWrapper()
    converter = node._convert_data_to_ComplexArray

    self.assertRaises(PipelineRecipeFailed, converter,
                      raw_grids, unknown_pair)
def test_write_corrected_data(self):
    """
    Writing corrected data should result in a delete followed by an add
    on the parmdb for each component, with the correct arguments
    """
    # define input data
    station = "station"
    input_polarization_data = {"pol1":[{'freqs':[11],
                                        'freqwidths':[12],
                                        'times':[13],
                                        'timewidths':[14]}]}
    input_corected_data = {"pol1":RealImagArray([[1], [1]], [[2], [2]]),
                           "pol22":RealImagArray([[3], [3]], [[4], [4]])}
    # This object will be taken from the fixture: it is a recorder muck
    parmdb = WritableParmDB("parmdb")

    # call function
    parmExportCal = ParmExportCalWrapper()
    parmExportCal._write_corrected_data(parmdb, station,
        input_polarization_data, input_corected_data)

    recorded = parmdb.called_functions_and_parameters

    def assert_arguments_match(index, expected):
        # Compare the recorded arguments one by one: numpy arrays need
        # the special compare function
        result = recorded[index]
        # zip stops at the shortest sequence: a missing or additional
        # argument would go unnoticed without this check
        self.assertEqual(len(result[1]), len(expected[1]),
            "\nresult({0}) != \nexpected({1}) \n"
            "-> number of arguments differs".format(result, expected))
        for left, right in zip(result[1], expected[1]):
            error_message = "\nresult({0}) != \nexpected({1}) \n"\
                "-> {2} != {3}".format(result, expected, left, right)
            if isinstance(left, numpy.ndarray) or \
                    isinstance(right, numpy.ndarray):
                self.assertTrue(numpy.array_equal(left, right),
                                error_message)
            else:
                self.assertTrue(left == right, error_message)

    # test output: (the calls to parmdb)
    # there is one polarization, containing a single complex array
    # when writing this should result in, 2 times 2 function calls
    # first delete the REAL entry
    expected = ['deleteValues', ['Gain:pol1:Real:station']]
    self.assertTrue(recorded[0] == expected,
        "result({0}) != expected({1})".format(recorded[0], expected))

    # then the new values should be added, with the correct values
    expected = ['addValues', ['Gain:pol1:Real:station',
                              numpy.array([[1.], [1.]],),
                              11, 11 + 12, 13, 13 + 2 * 14, False]] #stat + steps*size
    assert_arguments_match(1, expected)

    # now delete the imag entry: Remember these are on the 3rd and 4th
    # array position
    expected = ['deleteValues', ['Gain:pol1:Imag:station']]
    self.assertTrue(recorded[2] == expected,
        "result({0}) != expected({1})".format(recorded[2], expected))

    # then the new values should be added, with the correct values
    expected = ['addValues', ['Gain:pol1:Imag:station',
                              numpy.array([[2.], [2.]],),
                              11, 11 + 12, 13, 13 + 2 * 14, False]] #stat + steps*size
    assert_arguments_match(3, expected)
def test_write_corrected_data_does_not_contain_pol(self):
    """
    A polarisation without corrected data should result in a
    PipelineRecipeFailed
    """
    station = "station"
    input_polarization_data = {"unknownPolarisation":[{'freqs':[11],
                                                       'freqwidths':[12],
                                                       'times':[13],
                                                       'timewidths':[14]}]}
    input_corected_data = {"pol1":RealImagArray([[1], [1]], [[2], [2]]),
                           "pol2":RealImagArray([[3], [3]], [[4], [4]])}
    # This object will be taken from the fixture: it is a recorder muck
    parmdb = WritableParmDB("parmdb")

    # call function: a single wrapper instance is sufficient
    parmExportCal = ParmExportCalWrapper()
    self.assertRaises(PipelineRecipeFailed,
                      parmExportCal._write_corrected_data,
                      parmdb, station,
                      input_polarization_data, input_corected_data)
def test_swap_outliers_with_median(self):
    """
    The outlier in the data should be swapped with the median, the last
    entry is omitted from the filtering and should stay unchanged
    """
    data = {"pol1":[{"values": [1., 1., 1., 1., 100., 100.]},
                    {"values": [1., 1., 1., 1., 100., 100.]}]
            }
    type_pair = ["Imag", "Real"] # Order is alphabetical
    # omit the last entry, so swap the 5th entry with the median (1)
    goal_filtered_array = numpy.array([1., 1., 1., 1., 1., 100.])

    parmExportCal = ParmExportCalWrapper()
    corrected_polarisation = \
        parmExportCal._swap_outliers_with_median(data, type_pair, 2.0)
    corrected_real = corrected_polarisation['pol1'].real

    # zip stops at the shortest sequence: check the length explicitly
    self.assertEqual(len(corrected_real), len(goal_filtered_array),
        "Number of corrected values differs from the expected number: "
        "{0} != {1}".format(len(corrected_real), len(goal_filtered_array)))

    # Compare rounded to 3 decimals. Truncation with int() would report
    # 0.9999999 and 1.0 as different values
    for left, right in zip(corrected_real, goal_filtered_array):
        message = "Comparison of float values in the array did not " \
            "result in about the same value: {0} != {1}".format(left, right)
        self.assertAlmostEqual(left, right, places=3, msg=message)
import re
import copy
class parmdb(object):
    """
    Mock parmdb interface:
    Allows basic checking of called function and parameters
    """
    def __init__(self, dbname, create=True, names=None):
        """
        dbname -- stored in self._basename
        create -- accepted but not used by this mock
        names  -- the parameter names served by getNames; a default list
                  of five names is used when None
        """
        self._basename = dbname
        # Identity test: an == comparison can be overloaded by the
        # supplied names object
        if names is not None:
            self.names = names
        else:
            self.names = ["1:1:Real:name1",
                          "1:1:Real:name2",
                          "1:1:Real:name3",
                          "1:1:Real:name4",
                          "Gain:1:2:Real:station1"]
        # Record of all calls: list of [function name, [arguments]]
        self.called_functions_and_parameters = []

    def getNames(self, parmnamepattern=''):
        """
        Return the names containing parmnamepattern, a * in the pattern
        matches any text. All names are returned for an empty pattern.
        """
        if parmnamepattern == '':
            return self.names

        # convert the pattern to regexp: only the * is a wildcard, escape
        # the other parts to match them literally
        listed_pattern = [re.escape(part)
                          for part in parmnamepattern.split("*")]
        regexp_pattern = ".*{0}.*".format(".*".join(listed_pattern))

        # create regexp matcher!
        prog = re.compile(regexp_pattern)

        # only return the names with a regexp match on the pattern
        return [s for s in self.names if prog.match(s)]

    def deleteValues(self, *args):
        """Record the call and its arguments"""
        self.called_functions_and_parameters.append(
            ['deleteValues', list(args)])

    def addValues(self, *args):
        """Record the call and its arguments"""
        self.called_functions_and_parameters.append(
            ['addValues', list(args)])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment