Commit 1f04ff29 authored by Maik Nijhuis

Merge branch 'ast-490-python-formatting' into 'master'

AST-490 Format python files using black

See merge request !700
parents 2327d6ca 036459ce
Pipeline #34922 passed
Showing changed files with 437 additions and 304 deletions
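The changes below are mechanical: black normalizes string literals to double quotes, rewraps lines longer than its default 88-character limit, and inserts trailing commas and blank lines where its style requires them. A minimal sketch of that effect through black's in-process API (black.format_str and black.Mode are black's public entry points; the MR itself presumably ran the black CLI):

import black

src = "additional_commands = { 'foo': { 'flags': ['BAR', 'BAZ'] } }\n"
# black rewrites quoting and layout without changing semantics
print(black.format_str(src, mode=black.Mode()), end="")
# -> additional_commands = {"foo": {"flags": ["BAR", "BAZ"]}}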
@@ -4,8 +4,12 @@
with section("parse"):
# Specify structure for custom cmake functions
additional_commands = { 'foo': { 'flags': ['BAR', 'BAZ'],
'kwargs': {'DEPENDS': '*', 'HEADERS': '*', 'SOURCES': '*'}}}
additional_commands = {
"foo": {
"flags": ["BAR", "BAZ"],
"kwargs": {"DEPENDS": "*", "HEADERS": "*", "SOURCES": "*"},
}
}
# Override configurations per-command where available
override_spec = {}
@@ -41,7 +45,7 @@ with section("format"):
# 'use-space', fractional indentation is left as spaces (utf-8 0x20). If set
# to `round-up` fractional indentation is replaced with a single tab character
# (utf-8 0x09) effectively shifting the column to the next tabstop
fractional_tab_policy = 'use-space'
fractional_tab_policy = "use-space"
# If an argument group contains more than this many sub-groups (parg or kwarg
# groups) then force it to a vertical layout.
@@ -69,7 +73,7 @@ with section("format"):
# to this reference: `prefix`: the start of the statement, `prefix-indent`:
# the start of the statement, plus one indentation level, `child`: align to
# the column of the arguments
dangle_align = 'prefix'
dangle_align = "prefix"
# If the statement spelling length (including space and parenthesis) is
# smaller than this amount, then force reject nested layouts.
@@ -85,13 +89,13 @@ with section("format"):
max_lines_hwrap = 2
# What style line endings to use in the output.
line_ending = 'unix'
line_ending = "unix"
# Format command names consistently as 'lower' or 'upper' case
command_case = 'canonical'
command_case = "canonical"
# Format keywords consistently as 'lower' or 'upper' case
keyword_case = 'unchanged'
keyword_case = "unchanged"
# A list of command names which should always be wrapped
always_wrap = []
@@ -120,10 +124,10 @@ with section("format"):
with section("markup"):
# What character to use for bulleted lists
bullet_char = '*'
bullet_char = "*"
# What character to use as punctuation after numerals in an enumerated list
enum_char = '.'
enum_char = "."
# If comment markup is enabled, don't reflow the first comment block in each
# listfile. Use this to preserve formatting of your copyright/license
@@ -136,15 +140,15 @@ with section("markup"):
# Regular expression to match preformat fences in comments default=
# ``r'^\s*([`~]{3}[`~]*)(.*)$'``
fence_pattern = '^\\s*([`~]{3}[`~]*)(.*)$'
fence_pattern = "^\\s*([`~]{3}[`~]*)(.*)$"
# Regular expression to match rulers in comments default=
# ``r'^\s*[^\w\s]{3}.*[^\w\s]{3}$'``
ruler_pattern = '^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$'
ruler_pattern = "^\\s*[^\\w\\s]{3}.*[^\\w\\s]{3}$"
# If a comment line starts with this pattern then it is explicitly a
# trailing comment for the preceding argument. Default is '#<'
explicit_trailing_pattern = '#<'
explicit_trailing_pattern = "#<"
# If a comment line starts with at least this many consecutive hash
# characters, then don't lstrip() them off. This allows for lazy hash rulers
@@ -167,38 +171,38 @@ with section("lint"):
disabled_codes = []
# regular expression pattern describing valid function names
function_pattern = '[0-9a-z_]+'
function_pattern = "[0-9a-z_]+"
# regular expression pattern describing valid macro names
macro_pattern = '[0-9A-Z_]+'
macro_pattern = "[0-9A-Z_]+"
# regular expression pattern describing valid names for variables with global
# (cache) scope
global_var_pattern = '[A-Z][0-9A-Z_]+'
global_var_pattern = "[A-Z][0-9A-Z_]+"
# regular expression pattern describing valid names for variables with global
# scope (but internal semantic)
internal_var_pattern = '_[A-Z][0-9A-Z_]+'
internal_var_pattern = "_[A-Z][0-9A-Z_]+"
# regular expression pattern describing valid names for variables with local
# scope
local_var_pattern = '[a-z][a-z0-9_]+'
local_var_pattern = "[a-z][a-z0-9_]+"
# regular expression pattern describing valid names for private directory
# variables
private_var_pattern = '_[0-9a-z_]+'
private_var_pattern = "_[0-9a-z_]+"
# regular expression pattern describing valid names for public directory
# variables
public_var_pattern = '[A-Z][0-9A-Z_]+'
public_var_pattern = "[A-Z][0-9A-Z_]+"
# regular expression pattern describing valid names for function/macro
# arguments and loop variables.
argument_var_pattern = '[a-z][a-z0-9_]+'
argument_var_pattern = "[a-z][a-z0-9_]+"
# regular expression pattern describing valid names for keywords used in
# functions or macros
keyword_pattern = '[A-Z][0-9A-Z_]+'
keyword_pattern = "[A-Z][0-9A-Z_]+"
# In the heuristic for C0201, how many conditionals to match within a loop
# before considering the loop a parser.
@@ -224,11 +228,11 @@ with section("encode"):
emit_byteorder_mark = False
# Specify the encoding of the input file. Defaults to utf-8
input_encoding = 'utf-8'
input_encoding = "utf-8"
# Specify the encoding of the output file. Defaults to utf-8. Note that cmake
# only claims to support utf-8 so be careful when using anything else
output_encoding = 'utf-8'
output_encoding = "utf-8"
# -------------------------------------
# Miscellaneous configuration options.
......
@@ -12,6 +12,7 @@ import pyrap.tables as pt
import numpy
import pylab
def plotflags(tabnames):
"""Plot NDPPP Count results
@@ -27,15 +28,17 @@ def plotflags (tabnames):
"""
t = pt.table(tabnames)
if 'Frequency' in t.colnames():
t1 = t.sort ('Frequency')
pylab.plot (t1.getcol('Frequency'), t1.getcol('Percentage'))
elif 'Station' in t.colnames():
if "Frequency" in t.colnames():
t1 = t.sort("Frequency")
pylab.plot(t1.getcol("Frequency"), t1.getcol("Percentage"))
elif "Station" in t.colnames():
percs = []
names = []
for t1 in t.iter ('Station'):
percs.append (t1.getcol('Percentage').mean())
names.append (t1.getcell('Name', 0))
pylab.plot (numpy.array(percs), '+')
for t1 in t.iter("Station"):
percs.append(t1.getcol("Percentage").mean())
names.append(t1.getcell("Name", 0))
pylab.plot(numpy.array(percs), "+")
else:
raise RuntimeError('Table appears not to be a NDPPP Count result; it does not contain a Frequency or Station column')
raise RuntimeError(
"Table appears not to be a NDPPP Count result; it does not contain a Frequency or Station column"
)
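For context, plotflags is meant for interactive use; a hypothetical invocation (the table name is illustrative, not taken from this repository):

# Plot the flag percentages from an NDPPP count table, then show the figure.
plotflags("flagcounts.tab")  # hypothetical table name
pylab.show()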
@@ -11,79 +11,74 @@ import xml.dom.minidom
if sys.version_info[0] < 3:
b2s = lambda s: s
else:
b2s = lambda s: s.decode('utf-8')
b2s = lambda s: s.decode("utf-8")
def _has_subelement(element, subelement_name):
return len(element.getElementsByTagName(subelement_name)) > 0
def _has_error(test_case):
return _has_subelement(test_case, 'error')
return _has_subelement(test_case, "error")
def _has_failure(test_case):
return _has_subelement(test_case, 'failure')
return _has_subelement(test_case, "failure")
def get_git_sha():
if 'CI_COMMIT_SHA' in os.environ:
return os.environ['CI_COMMIT_SHA']
return b2s(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip())
if "CI_COMMIT_SHA" in os.environ:
return os.environ["CI_COMMIT_SHA"]
return b2s(subprocess.check_output(["git", "rev-parse", "HEAD"]).strip())
def get_coverage_metrics(cov_xml_file):
cov_dom = xml.dom.minidom.parse(cov_xml_file)
coverage = cov_dom.getElementsByTagName('coverage')[0]
coverage_line_rate = float(coverage.attributes['line-rate'].value)
return {
'percentage': coverage_line_rate * 100
}
coverage = cov_dom.getElementsByTagName("coverage")[0]
coverage_line_rate = float(coverage.attributes["line-rate"].value)
return {"percentage": coverage_line_rate * 100}
def get_tests_metrics(utests_xml_file):
utests_xml = xml.dom.minidom.parse(utests_xml_file)
test_cases = utests_xml.getElementsByTagName('testcase')
test_cases = utests_xml.getElementsByTagName("testcase")
errors = len(list(filter(_has_error, test_cases)))
failures = len(list(filter(_has_failure, test_cases)))
return {
'errors': errors,
'failed': failures,
'total': len(test_cases)
}
return {"errors": errors, "failed": failures, "total": len(test_cases)}
def get_lint_metrics(lint_xml_file):
return {
'errors': 0,
'failures': 0,
'tests': 0
}
return {"errors": 0, "failures": 0, "tests": 0}
def get_build_status(ci_metrics):
now = time.time()
test_metrics = ci_metrics['tests']
if test_metrics['errors'] > 0 or test_metrics['failed'] > 0:
last_build_status = 'failed'
test_metrics = ci_metrics["tests"]
if test_metrics["errors"] > 0 or test_metrics["failed"] > 0:
last_build_status = "failed"
else:
last_build_status = 'passed'
last_build_status = "passed"
return {
'last': {
'status': last_build_status,
'timestamp': now
},
'green': {
'timestamp': now
}
"last": {"status": last_build_status, "timestamp": now},
"green": {"timestamp": now},
}
def produce_ci_metrics(build_dir):
cov_xml_file = os.path.join(build_dir, 'code-coverage.xml')
utests_xml_file = os.path.join(build_dir, 'unit-tests.xml')
lint_xml_file = os.path.join(build_dir, 'linting.xml')
cov_xml_file = os.path.join(build_dir, "code-coverage.xml")
utests_xml_file = os.path.join(build_dir, "unit-tests.xml")
lint_xml_file = os.path.join(build_dir, "linting.xml")
ci_metrics = {
'commit_sha': get_git_sha(),
'coverage': get_coverage_metrics(cov_xml_file),
'tests': get_tests_metrics(utests_xml_file),
'lint': get_lint_metrics(lint_xml_file)
"commit_sha": get_git_sha(),
"coverage": get_coverage_metrics(cov_xml_file),
"tests": get_tests_metrics(utests_xml_file),
"lint": get_lint_metrics(lint_xml_file),
}
ci_metrics['build-status'] = get_build_status(ci_metrics)
ci_metrics["build-status"] = get_build_status(ci_metrics)
print(json.dumps(ci_metrics, indent=2))
if __name__ == '__main__':
if __name__ == "__main__":
produce_ci_metrics(sys.argv[1])
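For reference, the script prints a single JSON document to stdout; its shape follows directly from the functions above (values here are illustrative):

# {
#   "commit_sha": "<CI_COMMIT_SHA or `git rev-parse HEAD`>",
#   "coverage": {"percentage": 87.5},
#   "tests": {"errors": 0, "failed": 0, "total": 123},
#   "lint": {"errors": 0, "failures": 0, "tests": 0},
#   "build-status": {
#     "last": {"status": "passed", "timestamp": 1650000000.0},
#     "green": {"timestamp": 1650000000.0}
#   }
# }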
@@ -87,8 +87,8 @@ def count_junit_metrics(filename):
log = logging.getLogger(LOGGER_NAME)
try:
root_elem = etree.parse(filename).getroot()
if root_elem.tag not in ['testsuites', 'testsuite']:
raise ValueError('Invalid JUnit XML file.')
if root_elem.tag not in ["testsuites", "testsuite"]:
raise ValueError("Invalid JUnit XML file.")
stats = parse_junit_tree(root_elem)
result = dict(errors=0, failures=0, tests=0, skipped=0)
for key in result:
@@ -99,12 +99,16 @@ def count_junit_metrics(filename):
stats["testcase"][key],
)
else:
result[key] = max(stats["testsuite"][key],
stats["testcase"][key])
result[key] = max(
stats["testsuite"][key], stats["testcase"][key]
)
result["total"] = result["tests"]
del result["tests"]
except Exception as expt:
log.exception("Exception caught parsing '%s', returning 0 since the CI does not allow any linting errors/warnings", filename)
log.exception(
"Exception caught parsing '%s', returning 0 since the CI does not allow any linting errors/warnings",
filename,
)
result = dict(errors=0, failures=0, total=0, skipped=0)
return result
@@ -174,8 +178,9 @@ def main():
# latest_pipeline_id = str(pipeline["id"])
latest_build_date = pipeline["created_at"]
latest_build_timestamp = datetime.timestamp(
datetime.strptime(latest_build_date,
"%Y-%m-%dT%H:%M:%S.%fZ")
datetime.strptime(
latest_build_date, "%Y-%m-%dT%H:%M:%S.%fZ"
)
)
break
except Exception as err:
......
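The wrapped strptime call above parses GitLab's created_at field; "%f" consumes the fractional seconds and the trailing "Z" is matched literally. A self-contained check (the timestamp value is illustrative):

from datetime import datetime

ts = datetime.strptime("2022-01-31T12:34:56.789Z", "%Y-%m-%dT%H:%M:%S.%fZ")
print(datetime.timestamp(ts))  # seconds since the epoch, as used above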
@@ -37,7 +37,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -65,7 +69,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -91,7 +99,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -117,7 +129,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -141,7 +157,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -214,7 +234,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -240,7 +264,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
@@ -264,7 +292,11 @@ def main(): # pylint: disable=too-many-branches,too-many-statements
# Create badge
badge = anybadge.Badge(
label=label, value=value, default_color=color, value_prefix=" ", value_suffix=" "
label=label,
value=value,
default_color=color,
value_prefix=" ",
value_suffix=" ",
)
# Write badge
......
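All of the badge hunks above repeat one pattern; for context, a complete sketch of it (the output path is illustrative; write_badge is anybadge's file-output method):

import anybadge

badge = anybadge.Badge(
    label="coverage",
    value="87%",
    default_color="green",
    value_prefix=" ",
    value_suffix=" ",
)
badge.write_badge("build/badges/coverage.svg", overwrite=True)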
@@ -53,9 +53,15 @@ def create_skymodel():
f.write(
"FORMAT = Name, Type, Ra, Dec, I, MajorAxis, MinorAxis, PositionAngle, ReferenceFrequency='134e6', SpectralIndex='[0.0]'\r\n"
)
f.write("center, POINT, 16:38:28.205000, +63.44.34.314000, 1, , , , , \r\n")
f.write("ra_off, POINT, 16:58:28.205000, +63.44.34.314000, 1, , , , , \r\n")
f.write("radec_off, POINT, 16:38:28.205000, +65.44.34.314000, 1, , , , , \r\n")
f.write(
"center, POINT, 16:38:28.205000, +63.44.34.314000, 1, , , , , \r\n"
)
f.write(
"ra_off, POINT, 16:58:28.205000, +63.44.34.314000, 1, , , , , \r\n"
)
f.write(
"radec_off, POINT, 16:38:28.205000, +65.44.34.314000, 1, , , , , \r\n"
)
@pytest.fixture()
@@ -156,7 +162,11 @@ def test_only_predict(create_skymodel):
)
check_call(
[tcf.DP3EXE, "msout=PREDICT_DIR_1.MS", "predict.sources=[center, dec_off]"]
[
tcf.DP3EXE,
"msout=PREDICT_DIR_1.MS",
"predict.sources=[center, dec_off]",
]
+ common_args
+ predict_args
)
@@ -230,8 +240,14 @@ def test_uvwflagger(create_skymodel, create_corrupted_data_from_regular):
# When uvw flagging is disabled, there are only 9 NaNs in the solution file
expected_flagged_solutions = 54
assert np.count_nonzero(np.isnan(amplitude_solutions)) == expected_flagged_solutions
assert np.count_nonzero(np.isnan(phase_solutions)) == expected_flagged_solutions
assert (
np.count_nonzero(np.isnan(amplitude_solutions))
== expected_flagged_solutions
)
assert (
np.count_nonzero(np.isnan(phase_solutions))
== expected_flagged_solutions
)
# Only test a limited set of caltype + nchannels combinations, since testing
@@ -256,7 +272,9 @@ def test_uvwflagger(create_skymodel, create_corrupted_data_from_regular):
# "rotation+diagonal", # part of fulljones -> not implemented
],
)
def test_caltype(create_skymodel, create_corrupted_data_from_regular, caltype_nchan):
def test_caltype(
create_skymodel, create_corrupted_data_from_regular, caltype_nchan
):
"""Test calibration for different calibration types"""
caltype = caltype_nchan[:-1]
nchan = int(caltype_nchan[-1])
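test_caltype is generated from the caltype_nchan list above; other tests in this MR stack several parametrize decorators, in which case pytest runs the full cross product of the parameter sets. A minimal illustration of that mechanism:

import pytest

@pytest.mark.parametrize("a", [1, 2])
@pytest.mark.parametrize("b", ["x", "y"])
def test_cross(a, b):
    # pytest runs this four times: (1, "x"), (1, "y"), (2, "x"), (2, "y")
    assert a in (1, 2) and b in ("x", "y")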
@@ -282,13 +300,20 @@ def test_caltype(create_skymodel, create_corrupted_data_from_regular, caltype_nc
h5 = h5py.File("solutions.h5", "r")
if caltype in ["scalar", "diagonal", "scalaramplitude", "diagonalamplitude"]:
if caltype in [
"scalar",
"diagonal",
"scalaramplitude",
"diagonalamplitude",
]:
amplitude_solutions = h5["sol000/amplitude000/val"]
if caltype.startswith("scalar"):
assert amplitude_solutions.attrs["AXES"] == b"time,freq,ant,dir"
else:
assert amplitude_solutions.attrs["AXES"] == b"time,freq,ant,dir,pol"
assert (
amplitude_solutions.attrs["AXES"] == b"time,freq,ant,dir,pol"
)
if nchan == 0:
assert amplitude_solutions.shape[1] == 1
@@ -339,7 +364,12 @@ def test_subtract(create_skymodel, create_corrupted_data):
residual = float(
check_output(
[tcf.TAQLEXE, "-nopr", "-noph", "select gmax(abs(DATA)) from out.MS"]
[
tcf.TAQLEXE,
"-nopr",
"-noph",
"select gmax(abs(DATA)) from out.MS",
]
)
)
......
@@ -114,12 +114,17 @@ def create_corrupted_visibilities():
@pytest.mark.parametrize(
"caltype", ["complexgain", "scalarcomplexgain", "amplitudeonly", "scalaramplitude"]
"caltype",
["complexgain", "scalarcomplexgain", "amplitudeonly", "scalaramplitude"],
)
@pytest.mark.parametrize("solint", [0, 1, 2, 4])
@pytest.mark.parametrize("nchan", [1, 2, 5])
def test(
create_corrupted_visibilities, copy_data_to_model_data, caltype, solint, nchan
create_corrupted_visibilities,
copy_data_to_model_data,
caltype,
solint,
nchan,
):
# Subtract corrupted visibilities using multiple predict steps
check_call(
@@ -251,12 +256,13 @@ def test_calibration_with_dd_intervals(
>= reference_solutions["sol000/amplitude000/val"].shape[0]
):
corresponding_index = (
reference_solutions["sol000/amplitude000/val"].shape[0] - 1
reference_solutions["sol000/amplitude000/val"].shape[0]
- 1
)
values_reference = reference_solutions["sol000/amplitude000/val"][
corresponding_index, :, :, direction_index, 0
]
values_reference = reference_solutions[
"sol000/amplitude000/val"
][corresponding_index, :, :, direction_index, 0]
assert (abs(values_ddecal - values_reference) < 1.2).all()
@@ -331,12 +337,13 @@ def test_bug_ast_924(
>= reference_solutions["sol000/amplitude000/val"].shape[0]
):
corresponding_index = (
reference_solutions["sol000/amplitude000/val"].shape[0] - 1
reference_solutions["sol000/amplitude000/val"].shape[0]
- 1
)
values_reference = reference_solutions["sol000/amplitude000/val"][
corresponding_index, :, :, direction_index, 0
]
values_reference = reference_solutions[
"sol000/amplitude000/val"
][corresponding_index, :, :, direction_index, 0]
assert (abs(values_ddecal - values_reference) < 1.2).all()
@@ -347,7 +354,13 @@
)
@pytest.mark.parametrize(
"caltype",
["amplitudeonly", "scalaramplitude", "scalar", "diagonal", "diagonalamplitude"],
[
"amplitudeonly",
"scalaramplitude",
"scalar",
"diagonal",
"diagonalamplitude",
],
)
def test_subtract_with_dd_intervals(
create_corrupted_visibilities,
@@ -739,7 +752,9 @@ def test_bda_constaints():
phase_bda = f_bda["sol000/phase000/val"]
phase_no_bda = f_no_bda["sol000/phase000/val"]
np.testing.assert_allclose(ampl_bda, ampl_no_bda, rtol=0.05, atol=0, equal_nan=True)
np.testing.assert_allclose(
ampl_bda, ampl_no_bda, rtol=0.05, atol=0, equal_nan=True
)
np.testing.assert_allclose(
phase_bda, phase_no_bda, rtol=0.3, atol=0, equal_nan=True
)
@@ -101,7 +101,9 @@ def test_input_with_single_sources(source, offset):
except FileNotFoundError:
pytest.skip("WSClean not available")
check_call(["wsclean", "-use-idg", "-predict", "-name", f"{source}", f"{MSIN}"])
check_call(
["wsclean", "-use-idg", "-predict", "-name", f"{source}", f"{MSIN}"]
)
check_output(
[
tcf.TAQLEXE,
@@ -112,7 +114,10 @@ def test_input_with_single_sources(source, offset):
)
# Predict source: $source offset: $offset using IDG
if "polygon" in open(f"{tcf.DDECAL_RESOURCEDIR}/{source}-{offset}.reg").read():
if (
"polygon"
in open(f"{tcf.DDECAL_RESOURCEDIR}/{source}-{offset}.reg").read()
):
check_call(
[
tcf.DP3EXE,
......
@@ -9,32 +9,31 @@ import os
os.system("rm dummy-image.fits dummy-dirty.fits")
# -channel-range 0 1 ensures the reference frequency is from the first channel.
os.system("wsclean -size 512 512 -scale 0.01 -channel-range 0 1 -name dummy tDDECal.MS")
os.system(
"wsclean -size 512 512 -scale 0.01 -channel-range 0 1 -name dummy tDDECal.MS"
)
sources = {
"radec": (400, 64),
"ra": (400, 256),
"dec": (256, 64),
"center": ( 256, 256 )
}
brightness = {
"radec": 10,
"ra": 20,
"dec": 20,
"center": 10
"center": (256, 256),
}
brightness = {"radec": 10, "ra": 20, "dec": 20, "center": 10}
term_brightness = {0: 10, 1: 20000, 2: 30000}
fits_files = []
hdu = fits.open("dummy-image.fits")[0]
def write_fits(name):
filename = name + "-model.fits"
os.system("rm -rf " + filename)
hdu.writeto(filename)
fits_files.append(filename)
# Generate foursources.fits, which has all four sources.
hdu.data *= 0
......
@@ -80,6 +80,7 @@ RUN wget -nv -O /WSRT_Measures.ztar ftp://ftp.astron.nl/outgoing/Measures/WSRT_M
&& rm /WSRT_Measures.ztar
# Install pip dependencies
RUN pip3 install \
black \
cmake-format \
h5py \
sphinx \
......
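With black added to the CI image, the pipeline can reject unformatted code; a typical enforcement step (a sketch, the actual CI job definition is not part of this diff):

import subprocess

# "--check" makes black exit nonzero when any file would be reformatted,
# without modifying the tree.
subprocess.run(["black", "--check", "."], check=True)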
Subproject commit b422a1bcc4a51d12ac4f8a3f08c5885c776f4a6d
Subproject commit c126d20b28c8557375e12bc3039bb6a64c05f8b0
@@ -24,6 +24,7 @@ Script can be invoked in two ways:
MSIN = "tDemix.in_MS"
CWD = os.getcwd()
@pytest.fixture(autouse=True)
def source_env():
os.chdir(CWD)
@@ -40,18 +41,35 @@ def source_env():
os.chdir(CWD)
shutil.rmtree(tmpdir)
def test_skymodel_sourcedb_roundtrip():
"""Check that skymodel in default format is reproduced after makesourcedb and showsourcedb"""
# sky.txt is not in the default format, so create a skymodel in the default format by going through sourcedb
check_call([tcf.MAKESOURCEDBEXE, "in=tDemix_tmp/sky.txt", "out=sourcedb"])
# The first line of showsourcedb is the application's announcement
skymodel_defaultformat_input = check_output([tcf.SHOWSOURCEDBEXE, "in=sourcedb", "mode=skymodel"]).decode('utf-8').split('\n',1)[-1]
skymodel_defaultformat_input = (
check_output([tcf.SHOWSOURCEDBEXE, "in=sourcedb", "mode=skymodel"])
.decode("utf-8")
.split("\n", 1)[-1]
)
with open("tDemix_tmp/sky_defaultformat.txt", "w") as f:
f.write(skymodel_defaultformat_input)
# Now do the roundtrip test: make sourcedb and print the result in the default format
check_call([tcf.MAKESOURCEDBEXE, "in=tDemix_tmp/sky_defaultformat.txt", "out=sourcedb_defaultformat"])
skymodel_defaultformat_output = check_output([tcf.SHOWSOURCEDBEXE, "in=sourcedb_defaultformat", "mode=skymodel"]).decode('utf-8').split('\n',1)[-1]
check_call(
[
tcf.MAKESOURCEDBEXE,
"in=tDemix_tmp/sky_defaultformat.txt",
"out=sourcedb_defaultformat",
]
)
skymodel_defaultformat_output = (
check_output(
[tcf.SHOWSOURCEDBEXE, "in=sourcedb_defaultformat", "mode=skymodel"]
)
.decode("utf-8")
.split("\n", 1)[-1]
)
assert skymodel_defaultformat_input == skymodel_defaultformat_output
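The .split("\n", 1)[-1] idiom above drops everything up to and including the first newline, i.e. the announcement line; a self-contained example (the text is illustrative):

output = "showsourcedb: version x.y\nFORMAT = Name, Type, Ra, Dec, I\n"
print(output.split("\n", 1)[-1])  # everything from the FORMAT line onward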
@@ -12,12 +12,6 @@ SOURCE_DIR=$(dirname "$0")/..
#relative to SOURCE_DIR.
EXCLUDE_DIRS=(external build CMake)
#The patterns of the C++ source files, which clang-format should format.
CXX_SOURCES=(*.cc *.h)
#The patterns of the CMake source files, which cmake-format should format.
CMAKE_SOURCES=(CMakeLists.txt *.cmake)
#End script configuration.
#The common formatting script has further documentation.
......
@@ -109,7 +109,5 @@ def test_with_updateweights():
f"applybeam.updateweights=true",
]
)
taql_command = (
f"select from {MSIN} where all(near(WEIGHT_SPECTRUM, NEW_WEIGHT_SPECTRUM))"
)
taql_command = f"select from {MSIN} where all(near(WEIGHT_SPECTRUM, NEW_WEIGHT_SPECTRUM))"
assert_taql(taql_command)
@@ -54,9 +54,15 @@ def create_skymodel():
f.write(
"FORMAT = Name, Type, Ra, Dec, I, MajorAxis, MinorAxis, PositionAngle, ReferenceFrequency='134e6', SpectralIndex='[0.0]'\r\n"
)
f.write("center, POINT, 16:38:28.205000, + 63.44.34.314000, 10, , , , , \r\n")
f.write("ra_off, POINT, 16:38:28.205000, + 64.44.34.314000, 10, , , , , \r\n")
f.write("radec_off, POINT, 16:38:28.205000, +65.44.34.314000, 10, , , , , \r\n")
f.write(
"center, POINT, 16:38:28.205000, + 63.44.34.314000, 10, , , , , \r\n"
)
f.write(
"ra_off, POINT, 16:38:28.205000, + 64.44.34.314000, 10, , , , , \r\n"
)
f.write(
"radec_off, POINT, 16:38:28.205000, +65.44.34.314000, 10, , , , , \r\n"
)
check_call([tcf.MAKESOURCEDBEXE, "in=test.skymodel", "out=test.sourcedb"])
@@ -143,6 +149,6 @@ def test_regular_buffer_writing():
"checkparset=true",
"msin=regular_buffer.MS",
"msout=out.MS",
"steps=[]"
"steps=[]",
]
)
@@ -56,7 +56,12 @@ def create_skymodel():
)
check_call(
[tcf.MAKESOURCEDBEXE, "in=test.skymodel", "out=test.sourcedb", "append=false"]
[
tcf.MAKESOURCEDBEXE,
"in=test.skymodel",
"out=test.sourcedb",
"append=false",
]
)
@@ -66,10 +71,17 @@ def create_skymodel_in_phase_center():
f.write(
"FORMAT = Name, Type, Ra, Dec, I, MajorAxis, MinorAxis, PositionAngle, ReferenceFrequency='134e6', SpectralIndex='[0.0]'\r\n"
)
f.write(f"center, POINT, 01:37:41.299000, +33.09.35.132000, 10, , , , , \r\n")
f.write(
f"center, POINT, 01:37:41.299000, +33.09.35.132000, 10, , , , , \r\n"
)
check_call(
[tcf.MAKESOURCEDBEXE, "in=test.skymodel", "out=test.sourcedb", "append=false"]
[
tcf.MAKESOURCEDBEXE,
"in=test.skymodel",
"out=test.sourcedb",
"append=false",
]
)
......
@@ -42,6 +42,7 @@ common_args = [
skymodel_arg = "demix.skymodel='tDemix_tmp/{}'"
@pytest.fixture(autouse=True)
def source_env():
os.chdir(CWD)
@@ -50,7 +51,13 @@ def source_env():
os.chdir(tmpdir)
untar_ms(f"{tcf.RESOURCEDIR}/{MSIN}.tgz")
check_call([tcf.MAKESOURCEDBEXE, "in=tDemix_tmp/sky.txt", "out=tDemix_tmp/sourcedb"])
check_call(
[
tcf.MAKESOURCEDBEXE,
"in=tDemix_tmp/sky.txt",
"out=tDemix_tmp/sourcedb",
]
)
# Tests are executed here
yield
@@ -60,7 +67,7 @@ def source_env():
shutil.rmtree(tmpdir)
@pytest.mark.parametrize("skymodel", ['sky.txt', 'sourcedb'])
@pytest.mark.parametrize("skymodel", ["sky.txt", "sourcedb"])
def test_without_target(skymodel):
check_call(
[
@@ -73,13 +80,12 @@ def test_without_target(skymodel):
+ common_args
)
# Compare some columns of the output MS with the reference output.
taql_command = f"select from tDemix_out.MS t1, tDemix_tmp/tDemix_ref1.MS t2 where not all(near(t1.DATA,t2.DATA,1e-3) || (isnan(t1.DATA) && isnan(t2.DATA))) || not all(t1.FLAG = t2.FLAG) || not all(near(t1.WEIGHT_SPECTRUM, t2.WEIGHT_SPECTRUM)) || not all(t1.LOFAR_FULL_RES_FLAG = t2.LOFAR_FULL_RES_FLAG) || t1.ANTENNA1 != t2.ANTENNA1 || t1.ANTENNA2 != t2.ANTENNA2 || t1.TIME !~= t2.TIME"
assert_taql(taql_command)
@pytest.mark.parametrize("skymodel", ['sky.txt', 'sourcedb'])
@pytest.mark.parametrize("skymodel", ["sky.txt", "sourcedb"])
def test_with_target_projected_away(skymodel):
check_call(
[
@@ -92,13 +98,12 @@ def test_with_target_projected_away(skymodel):
+ common_args
)
# Compare some columns of the output MS with the reference output.
taql_command = f"select from tDemix_out.MS t1, tDemix_tmp/tDemix_ref2.MS t2 where not all(near(t1.DATA,t2.DATA,1e-3) || (isnan(t1.DATA) && isnan(t2.DATA))) || not all(t1.FLAG = t2.FLAG) || not all(near(t1.WEIGHT_SPECTRUM, t2.WEIGHT_SPECTRUM)) || not all(t1.LOFAR_FULL_RES_FLAG = t2.LOFAR_FULL_RES_FLAG) || t1.ANTENNA1 != t2.ANTENNA1 || t1.ANTENNA2 != t2.ANTENNA2 || t1.TIME !~= t2.TIME"
assert_taql(taql_command)
@pytest.mark.parametrize("skymodel", ['sky.txt', 'sourcedb'])
@pytest.mark.parametrize("skymodel", ["sky.txt", "sourcedb"])
def test_with_target(skymodel):
check_call(
[
@@ -131,7 +136,6 @@ def test_time_freq_resolution():
+ common_args
)
# Compare some columns of the output MS with the reference output.
taql_command = f"select from tDemix_out.MS t1, tDemix_tmp/tDemix_ref1.MS t2 where not all(near(t1.DATA,t2.DATA,1e-3) || (isnan(t1.DATA) && isnan(t2.DATA))) || not all(t1.FLAG = t2.FLAG) || not all(near(t1.WEIGHT_SPECTRUM, t2.WEIGHT_SPECTRUM)) || not all(t1.LOFAR_FULL_RES_FLAG = t2.LOFAR_FULL_RES_FLAG) || t1.ANTENNA1 != t2.ANTENNA1 || t1.ANTENNA2 != t2.ANTENNA2 || t1.TIME !~= t2.TIME"
assert_taql(taql_command)
@@ -141,7 +141,8 @@ def test_write_thread_enabled():
assert re.search(b"use thread: *true", result)
assert re.search(
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Creating task\n", result
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Creating task\n",
result,
)
assert re.search(
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Writing \\(threaded\\)\n",
@@ -169,10 +170,12 @@ def test_write_thread_disabled():
assert re.search(b"use thread: *false", result)
assert (
re.search(
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Creating task\n", result
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Creating task\n",
result,
)
== None
)
assert re.search(
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Writing\n", result
b"(1[0-9]| [ 0-9])[0-9]\\.[0-9]% \\([ 0-9]{5} [m ]s\\) Writing\n",
result,
)
@@ -19,6 +19,7 @@ from utils import untar_ms, get_taql_result, check_output
MSIN = "tNDPPP-generic.MS"
CWD = os.getcwd()
@pytest.fixture(autouse=True)
def source_env():
os.chdir(CWD)
@@ -56,12 +57,18 @@ def test_chunking():
# Each should have two timesteps:
taql_command = f"select unique TIME from chunktest-000.ms"
result = get_taql_result(taql_command)
assert result == 'Unit: s\n29-Mar-2013/13:59:53.007\n29-Mar-2013/14:00:03.021'
assert (
result == "Unit: s\n29-Mar-2013/13:59:53.007\n29-Mar-2013/14:00:03.021"
)
taql_command = f"select unique TIME from chunktest-001.ms"
result = get_taql_result(taql_command)
assert result == 'Unit: s\n29-Mar-2013/14:00:13.035\n29-Mar-2013/14:00:23.049'
assert (
result == "Unit: s\n29-Mar-2013/14:00:13.035\n29-Mar-2013/14:00:23.049"
)
taql_command = f"select unique TIME from chunktest-002.ms"
result = get_taql_result(taql_command)
assert result == 'Unit: s\n29-Mar-2013/14:00:33.063\n29-Mar-2013/14:00:43.076'
assert (
result == "Unit: s\n29-Mar-2013/14:00:33.063\n29-Mar-2013/14:00:43.076"
)
@@ -23,7 +23,9 @@ Script can be invoked in two ways:
"""
MSIN = "tNDPPP-generic.MS"
PARMDB_TGZ = "tApplyCal2.parmdb.tgz" # Note: This archive contains tApplyCal.parmdb.
PARMDB_TGZ = (
"tApplyCal2.parmdb.tgz" # Note: This archive contains tApplyCal.parmdb.
)
PARMDB = "tApplyCal.parmdb"
CWD = os.getcwd()
......