Skip to content
Snippets Groups Projects
Commit a5158a3a authored by Jan David Mol's avatar Jan David Mol
Browse files

Merge branch 'multiprocessing-casacore' into 'master'

Use multiprocessing to parallellise casacore

See merge request !923
parents f644fc29 1abb6788
No related branches found
No related tags found
1 merge request!923Use multiprocessing to parallellise casacore
...@@ -161,6 +161,7 @@ Next change the version in the following places: ...@@ -161,6 +161,7 @@ Next change the version in the following places:
# Release Notes # Release Notes
* 0.37.0 Run casacore in separate processes, increasing beam-tracking performance
* 0.36.2 Fix polling 2D attributes * 0.36.2 Fix polling 2D attributes
Harden periodic tasks against exceptions Harden periodic tasks against exceptions
* 0.36.1 Fix tile beamforming * 0.36.1 Fix tile beamforming
......
0.36.2 0.37.0
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import datetime import datetime
import logging
from functools import lru_cache from functools import lru_cache
import threading from concurrent.futures import ProcessPoolExecutor
from typing import TypedDict from typing import TypedDict
import casacore.measures import casacore.measures
...@@ -26,12 +27,11 @@ is part of the 'casacore-tools' package. ...@@ -26,12 +27,11 @@ is part of the 'casacore-tools' package.
The measures The measures
""" """
logger = logging.getLogger()
# Where to store the measures table sets # Where to store the measures table sets
IERS_ROOTDIR = "/opt/IERS" IERS_ROOTDIR = "/opt/IERS"
# Compute lock to prevent thrashing when multithreading
compute_lock = threading.Lock()
def get_IERS_timestamp() -> datetime.datetime: def get_IERS_timestamp() -> datetime.datetime:
"""Return the date of the currently installed IERS tables.""" """Return the date of the currently installed IERS tables."""
...@@ -94,95 +94,168 @@ def pointing_to_str(pointing: tuple[str, str, str]) -> str: ...@@ -94,95 +94,168 @@ def pointing_to_str(pointing: tuple[str, str, str]) -> str:
) )
@lru_cache
def is_valid_pointing(pointing: tuple[str, str, str]) -> bool:
"""Check validity of the direction measure"""
try:
if pointing[0] == "None":
# uninitialised direction
return True
measure = casacore.measures.measures()
# if this raises, the direction is invalid
_ = measure.direction(*pointing)
return True
except (RuntimeError, TypeError, KeyError, IndexError) as e:
return False
def subtract(a, b) -> numpy.ndarray: def subtract(a, b) -> numpy.ndarray:
return numpy.array([x - y for x, y in zip(a, b)]) return numpy.array([x - y for x, y in zip(a, b)])
class Delays: class Delays:
def __init__(self, itrf: tuple[float, float, float]): # process-local casacore.measures.measures instance
"""Create a measure object, configured for the specified terrestrial location.""" measures = None
measure = casacore.measures.measures() @staticmethod
frame_location = measure.position("ITRF", *[f"{x}m" for x in itrf]) def _init_process():
"""Initialise the process that will be used to query measures."""
if not measure.do_frame(frame_location): Delays.measure = casacore.measures.measures()
raise ValueError(f"measure.do_frame failed for ITRF location {itrf}")
self.reference_itrf = itrf def __init__(self, reference_itrf: tuple[float, float, float] | None = None):
self.measure = measure self.pool = None
self.measure_time = None
def set_measure_time(self, utc_time: datetime.datetime): logger.debug("Starting Delays process pool")
"""Configure the measure object for the specified time."""
# start a dedicated process
self.pool = ProcessPoolExecutor(
max_workers=1,
initializer=self._init_process,
)
try:
# process a command to make sure the pool process is initialised
_ = self.set_measure_time(datetime.datetime.now())
logger.info("Started Delays process pool")
# set default reference position
if reference_itrf is not None:
_ = self.set_reference_itrf(reference_itrf)
except Exception as ex:
# destruct properly
self.stop()
raise
@property
def running(self) -> bool:
return self.pool is not None
def __del__(self):
if self.running:
logger.warning(
"stop() was never called. __del__ now will but this cannot be counted on."
)
self.stop()
def stop(self):
if not self.running:
return
logger.debug("Stopping Delays process pool")
self.pool.shutdown()
self.pool = None
logger.info("Stopped Delays process pool")
@staticmethod
def _set_measure_time(utc_time: datetime.datetime):
utc_time_str = utc_time.isoformat(" ") utc_time_str = utc_time.isoformat(" ")
frame_time = self.measure.epoch("UTC", utc_time_str) frame_time = Delays.measure.epoch("UTC", utc_time_str)
if not self.measure.do_frame(frame_time): if not Delays.measure.do_frame(frame_time):
raise ValueError(f"measure.do_frame failed for UTC time {utc_time_str}") raise ValueError(f"measure.do_frame failed for UTC time {utc_time_str}")
def set_measure_time(self, utc_time: datetime.datetime):
"""Configure the measure object for the specified time."""
assert self.running, "Pool not running."
future = self.pool.submit(self._set_measure_time, utc_time)
_ = future.result()
@staticmethod
def _set_reference_itrf(reference_itrf: tuple[float, float, float]):
frame_location = Delays.measure.position(
"ITRF", *[f"{x}m" for x in reference_itrf]
)
if not Delays.measure.do_frame(frame_location):
raise ValueError(
f"measure.do_frame failed for ITRF location {reference_itrf}"
)
def set_reference_itrf(self, reference_itrf: tuple[float, float, float]):
"""Configure the measure object for the specified reference location."""
assert self.running, "Pool not running."
future = self.pool.submit(self._set_reference_itrf, reference_itrf)
_ = future.result()
def get_direction_vector(self, pointing: list[str]) -> numpy.ndarray: def get_direction_vector(self, pointing: list[str]) -> numpy.ndarray:
"""Compute direction vector for a given pointing, relative to the measure.""" """Compute direction vector for a given pointing, relative to the measure."""
return self.get_direction_vector_bulk([pointing]).flatten() return self.get_direction_vector_bulk([pointing]).flatten()
def get_direction_vector_bulk(self, pointings: list[list[str]]) -> numpy.ndarray: @staticmethod
"""Compute direction vectors for the given pointings, relative to the measure.""" def _get_direction_vector_bulk(pointings: list[list[str]]):
"""Process that provides casacore functions. Multiprocessing is needed
to avoid casacore from blocking itself in a multithreaded environment."""
angles0 = numpy.empty(len(pointings)) angles = numpy.zeros((2, len(pointings)), dtype=numpy.float64)
angles1 = numpy.empty(len(pointings))
for idx, pointing in enumerate(pointings): for idx, pointing in enumerate(pointings):
direction: CasacoreMDirection | None = self._pointing_to_direction(pointing) if pointing[0] == "None":
# uninitialised direction
if direction is None: angles[:, idx] = (0, 0)
# uninitialised pointings
angles0[idx] = 0
angles1[idx] = 0
else: else:
angles = self.measure.measure(direction, "ITRF") direction = Delays.measure.direction(*pointing)
angles0[idx] = angles["m0"]["value"] m_angles = Delays.measure.measure(direction, "ITRF")
angles1[idx] = angles["m1"]["value"] angles[:, idx] = (m_angles["m0"]["value"], m_angles["m1"]["value"])
# Convert polar to carthesian coordinates # Convert polar to carthesian coordinates
# see also https://github.com/casacore/casacore/blob/e793b3d5339d828a60339d16476bf688a19df3ec/casa/Quanta/MVDirection.cc#L67 # see also https://github.com/casacore/casacore/blob/e793b3d5339d828a60339d16476bf688a19df3ec/casa/Quanta/MVDirection.cc#L67
direction_vectors = numpy.array( direction_vectors = numpy.array(
[ [
numpy.cos(angles0) * numpy.cos(angles1), numpy.cos(angles[0]) * numpy.cos(angles[1]),
numpy.sin(angles0) * numpy.cos(angles1), numpy.sin(angles[0]) * numpy.cos(angles[1]),
numpy.sin(angles1), numpy.sin(angles[1]),
] ]
) )
# Return array [directions][angles]
return direction_vectors.T return direction_vectors.T
def _pointing_to_direction( def get_direction_vector_bulk(self, pointings: list[list[str]]) -> numpy.ndarray:
self, pointing: tuple[str, str, str] """Compute direction vectors for the given pointings, relative to the measure."""
) -> CasacoreMDirection | None:
try:
if pointing[0] == "None":
# uninitialised direction
return None
return self.measure.direction(*pointing)
except (RuntimeError, TypeError, KeyError, IndexError) as e:
raise ValueError(f"Invalid pointing: {pointing}") from e
def is_valid_pointing(self, pointing: tuple[str, str, str]) -> bool: assert self.running, "Pool not running."
"""Check validity of the direction measure"""
try:
_ = self._pointing_to_direction(pointing)
except ValueError as e:
return False
return True future = self.pool.submit(self._get_direction_vector_bulk, pointings)
return future.result()
def delays( def delays(
self, self,
pointing: list[str], pointing: list[str],
antenna_absolute_itrf: list[tuple[float, float, float]], antenna_relative_itrf: list[tuple[float, float, float]],
) -> numpy.ndarray: ) -> numpy.ndarray:
"""Get the delays for a direction and *absolute* antenna positions. """Get the delays for a direction and relative antenna positions.
These are the delays that have to be applied to the signal chain in order to line up the signal. These are the delays that have to be applied to the signal chain in order to line up the signal.
Positions closer to the source will result in a positive delay. Positions closer to the source will result in a positive delay.
...@@ -191,7 +264,7 @@ class Delays: ...@@ -191,7 +264,7 @@ class Delays:
return self.delays_bulk( return self.delays_bulk(
numpy.array([pointing]), numpy.array([pointing]),
numpy.array(antenna_absolute_itrf) - self.reference_itrf, numpy.array(antenna_relative_itrf),
).flatten() ).flatten()
def delays_bulk( def delays_bulk(
...@@ -204,7 +277,6 @@ class Delays: ...@@ -204,7 +277,6 @@ class Delays:
Returns delays[antenna][direction].""" Returns delays[antenna][direction]."""
with compute_lock:
# obtain the direction vector for each pointing # obtain the direction vector for each pointing
direction_vectors = self.get_direction_vector_bulk(pointings) direction_vectors = self.get_direction_vector_bulk(pointings)
......
...@@ -26,7 +26,8 @@ class TileBeamManager(AbstractBeamManager): ...@@ -26,7 +26,8 @@ class TileBeamManager(AbstractBeamManager):
super().__init__() super().__init__()
self.HBAT_antenna_positions = [] self.HBAT_antenna_positions = []
self.HBAT_delay_calculators = [] self.HBAT_reference_positions = []
self.delay_calculator = None
self.nr_tiles = 0 self.nr_tiles = 0
self.device = device self.device = device
...@@ -43,15 +44,18 @@ class TileBeamManager(AbstractBeamManager): ...@@ -43,15 +44,18 @@ class TileBeamManager(AbstractBeamManager):
delays = numpy.zeros((self.nr_tiles, N_elements), dtype=numpy.float64) delays = numpy.zeros((self.nr_tiles, N_elements), dtype=numpy.float64)
d = self.delay_calculator
for tile in range(self.nr_tiles): for tile in range(self.nr_tiles):
# initialise delay calculator # initialise delay calculator for this tile
d = self.HBAT_delay_calculators[tile] d.set_reference_itrf(self.HBAT_reference_positions[tile])
d.set_measure_time(timestamp) d.set_measure_time(timestamp)
# calculate the delays based on the set reference position, the set time and # calculate the delays based on the set reference position, the set time and
# now the set direction and antenna positions # now the set direction and antenna positions
delays[tile] = d.delays( delays[tile] = d.delays(
pointing_direction[tile], self.HBAT_antenna_positions[tile] pointing_direction[tile],
self.HBAT_antenna_positions[tile] - self.HBAT_reference_positions[tile],
) )
return delays return delays
......
...@@ -26,7 +26,11 @@ from tango import ( ...@@ -26,7 +26,11 @@ from tango import (
# PyTango imports # PyTango imports
from tango.server import attribute, command, device_property from tango.server import attribute, command, device_property
from tangostationcontrol.beam.delays import Delays, pointing_to_str, get_IERS_timestamp from tangostationcontrol.beam.delays import (
pointing_to_str,
get_IERS_timestamp,
is_valid_pointing,
)
from tangostationcontrol.beam.managers import AbstractBeamManager from tangostationcontrol.beam.managers import AbstractBeamManager
from tangostationcontrol.common.constants import MAX_POINTINGS, N_point_prop from tangostationcontrol.common.constants import MAX_POINTINGS, N_point_prop
from tangostationcontrol.common.device_decorators import ( from tangostationcontrol.common.device_decorators import (
...@@ -65,7 +69,6 @@ class BeamDevice(AsyncDevice): ...@@ -65,7 +69,6 @@ class BeamDevice(AsyncDevice):
def __init__(self, cl, name): def __init__(self, cl, name):
self._beam_manager = None self._beam_manager = None
self._num_pointings = None self._num_pointings = None
self.generic_delay_calculator = None
# Super must be called after variable assignment due to executing init_device! # Super must be called after variable assignment due to executing init_device!
super().__init__(cl, name) super().__init__(cl, name)
...@@ -212,7 +215,7 @@ class BeamDevice(AsyncDevice): ...@@ -212,7 +215,7 @@ class BeamDevice(AsyncDevice):
) )
for pointing in value: for pointing in value:
if not self.generic_delay_calculator.is_valid_pointing(pointing): if not is_valid_pointing(tuple(pointing)):
raise ValueError(f"Invalid direction: {pointing}") raise ValueError(f"Invalid direction: {pointing}")
# store the new values # store the new values
...@@ -249,9 +252,6 @@ class BeamDevice(AsyncDevice): ...@@ -249,9 +252,6 @@ class BeamDevice(AsyncDevice):
# thread to perform beam tracking # thread to perform beam tracking
self.Beam_tracker = None self.Beam_tracker = None
# generic delay calculator to ask about validity of settings
self.generic_delay_calculator = Delays([0, 0, 0])
# Derived classes will override this with a non-parameterised # Derived classes will override this with a non-parameterised
# version that lofar_device will call. # version that lofar_device will call.
@log_exceptions() @log_exceptions()
......
...@@ -163,8 +163,6 @@ class DigitalBeam(BeamDevice): ...@@ -163,8 +163,6 @@ class DigitalBeam(BeamDevice):
def __init__(self, cl, name): def __init__(self, cl, name):
self.parent = None self.parent = None
self.beamlet = None self.beamlet = None
self.delay_calculator = None
self.relative_antenna_positions = None
# Super must be called after variable assignment due to executing init_device! # Super must be called after variable assignment due to executing init_device!
super().__init__(cl, name) super().__init__(cl, name)
...@@ -195,3 +193,10 @@ class DigitalBeam(BeamDevice): ...@@ -195,3 +193,10 @@ class DigitalBeam(BeamDevice):
# relative positions of each antenna # relative positions of each antenna
self._beam_manager.relative_antenna_positions = antenna_itrf - reference_itrf self._beam_manager.relative_antenna_positions = antenna_itrf - reference_itrf
def configure_for_off(self):
# Turn off BeamTracker first
super().configure_for_off()
if self._beam_manager and self._beam_manager.delay_calculator:
self._beam_manager.delay_calculator.stop()
...@@ -61,13 +61,21 @@ class TileBeam(BeamDevice): ...@@ -61,13 +61,21 @@ class TileBeam(BeamDevice):
"HBAT_antenna_itrf_offsets_R" "HBAT_antenna_itrf_offsets_R"
).reshape(self._beam_manager.nr_tiles, N_elements, N_xyz) ).reshape(self._beam_manager.nr_tiles, N_elements, N_xyz)
# a delay calculator for each tile # a delay calculator
self._beam_manager.HBAT_delay_calculators = [ self._beam_manager.delay_calculator = Delays()
Delays(reference_itrf) for reference_itrf in antenna_reference_itrf
] # absolute positions of each tile
self._beam_manager.HBAT_reference_positions = antenna_reference_itrf
# absolute positions of each antenna element # absolute positions of each antenna element
self._beam_manager.HBAT_antenna_positions = [ self._beam_manager.HBAT_antenna_positions = [
antenna_reference_itrf[tile] + hbat_antenna_itrf_offsets[tile] antenna_reference_itrf[tile] + hbat_antenna_itrf_offsets[tile]
for tile in range(self._beam_manager.nr_tiles) for tile in range(self._beam_manager.nr_tiles)
] ]
def configure_for_off(self):
# Turn off BeamTracker first
super().configure_for_off()
if self._beam_manager and self._beam_manager.delay_calculator:
self._beam_manager.delay_calculator.stop()
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import datetime import datetime
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch, call
from test import base from test import base
import numpy.testing import numpy.testing
...@@ -33,14 +33,18 @@ class TestTileBeamManager(base.TestCase): ...@@ -33,14 +33,18 @@ class TestTileBeamManager(base.TestCase):
sut = TileBeamManager(device_mock) sut = TileBeamManager(device_mock)
sut.nr_tiles = DEFAULT_N_HBA_TILES sut.nr_tiles = DEFAULT_N_HBA_TILES
sut.HBAT_delay_calculators = [] sut.delay_calculator = MagicMock()
sut.delay_calculator.delays.side_effect = [
for i in range(DEFAULT_N_HBA_TILES): numpy.full(N_elements, i, dtype=numpy.float64)
mock = MagicMock() for i in range(DEFAULT_N_HBA_TILES)
mock.delays.return_value = numpy.full(N_elements, i, dtype=numpy.float64) ]
sut.HBAT_delay_calculators.append(mock)
sut.HBAT_antenna_positions = range(DEFAULT_N_HBA_TILES, 0, -1) sut.HBAT_reference_positions = numpy.array(
[[-x] for x in range(DEFAULT_N_HBA_TILES, 0, -1)]
)
sut.HBAT_antenna_positions = numpy.array(
[[x] for x in range(DEFAULT_N_HBA_TILES, 0, -1)]
)
pointings = [i**2 for i in range(DEFAULT_N_HBA_TILES)] pointings = [i**2 for i in range(DEFAULT_N_HBA_TILES)]
...@@ -55,9 +59,12 @@ class TestTileBeamManager(base.TestCase): ...@@ -55,9 +59,12 @@ class TestTileBeamManager(base.TestCase):
] ]
), ),
) )
for i, mock in enumerate(sut.HBAT_delay_calculators):
mock.set_measure_time.assert_called_with(_dt) for i in range(DEFAULT_N_HBA_TILES):
mock.delays.assert_called_with(pointings[i], DEFAULT_N_HBA_TILES - i) sut.delay_calculator.set_measure_time.assert_has_calls([call(_dt)])
sut.delay_calculator.delays.assert_has_calls(
[call(pointings[i], (DEFAULT_N_HBA_TILES - i) * 2)]
)
@patch.object(_tilebeam, "create_device_proxy") @patch.object(_tilebeam, "create_device_proxy")
@patch.object(device_decorators, "get_current_device") @patch.object(device_decorators, "get_current_device")
......
...@@ -12,8 +12,7 @@ import numpy ...@@ -12,8 +12,7 @@ import numpy
import numpy.testing import numpy.testing
import threading import threading
from tangostationcontrol.beam import delays from tangostationcontrol.beam.delays import Delays, is_valid_pointing
from tangostationcontrol.beam.delays import Delays
from tangostationcontrol.common.constants import MAX_ANTENNA, N_beamlets_ctrl from tangostationcontrol.common.constants import MAX_ANTENNA, N_beamlets_ctrl
from test import base from test import base
...@@ -30,7 +29,7 @@ class TestDelays(base.TestCase): ...@@ -30,7 +29,7 @@ class TestDelays(base.TestCase):
] # CS002LBA, in ITRF2005 epoch 2012.5 ] # CS002LBA, in ITRF2005 epoch 2012.5
d = Delays(reference_itrf) d = Delays(reference_itrf)
self.assertIsNotNone(d) d.stop()
def test_init_fails(self): def test_init_fails(self):
"""Test do_measure returning false is correctly caught""" """Test do_measure returning false is correctly caught"""
...@@ -41,29 +40,28 @@ class TestDelays(base.TestCase): ...@@ -41,29 +40,28 @@ class TestDelays(base.TestCase):
self.assertRaises(ValueError, Delays, [0, 0, 0]) self.assertRaises(ValueError, Delays, [0, 0, 0])
def test_is_valid_pointing(self): def test_is_valid_pointing(self):
d = Delays([0, 0, 0])
# should accept base use cases # should accept base use cases
self.assertTrue(d.is_valid_pointing(("J2000", "0rad", "0rad"))) self.assertTrue(is_valid_pointing(("J2000", "0rad", "0rad")))
self.assertTrue(d.is_valid_pointing(("J2000", "4.712389rad", "1.570796rad"))) self.assertTrue(is_valid_pointing(("J2000", "4.712389rad", "1.570796rad")))
self.assertTrue(d.is_valid_pointing(("AZELGEO", "0rad", "0rad"))) self.assertTrue(is_valid_pointing(("AZELGEO", "0rad", "0rad")))
self.assertTrue(d.is_valid_pointing(("AZELGEO", "4.712389rad", "1.570796rad"))) self.assertTrue(is_valid_pointing(("AZELGEO", "4.712389rad", "1.570796rad")))
self.assertTrue(d.is_valid_pointing(("SUN", "0rad", "0rad"))) self.assertTrue(is_valid_pointing(("SUN", "0rad", "0rad")))
self.assertTrue(d.is_valid_pointing(("None", "", ""))) self.assertTrue(is_valid_pointing(("None", "", "")))
# should not throw, and return False, on bad uses # should not throw, and return False, on bad uses
self.assertFalse(d.is_valid_pointing([])) self.assertFalse(is_valid_pointing(()))
self.assertFalse(d.is_valid_pointing(("", "", ""))) self.assertFalse(is_valid_pointing(("", "", "")))
self.assertFalse(d.is_valid_pointing(("J2000", "0rad", "0rad", "0rad", "0rad"))) self.assertFalse(is_valid_pointing(("J2000", "0rad", "0rad", "0rad", "0rad")))
self.assertFalse(d.is_valid_pointing((1, 2, 3))) self.assertFalse(is_valid_pointing((1, 2, 3)))
self.assertFalse(d.is_valid_pointing("foo")) self.assertFalse(is_valid_pointing("foo"))
self.assertFalse(d.is_valid_pointing(None)) self.assertFalse(is_valid_pointing(None))
def test_sun(self): def test_sun(self):
# # create a frame tied to the reference position # # create a frame tied to the reference position
reference_itrf = [3826577.066, 461022.948, 5064892.786] reference_itrf = [3826577.066, 461022.948, 5064892.786]
d = Delays(reference_itrf) d = Delays(reference_itrf)
try:
for i in range(24): for i in range(24):
# set the time to the day of the winter solstice 2021 (21 december 16:58) as this is the time with the least change in sunlight # set the time to the day of the winter solstice 2021 (21 december 16:58) as this is the time with the least change in sunlight
timestamp = datetime.datetime(2021, 12, 21, i, 58, 0) timestamp = datetime.datetime(2021, 12, 21, i, 58, 0)
...@@ -88,6 +86,8 @@ class TestDelays(base.TestCase): ...@@ -88,6 +86,8 @@ class TestDelays(base.TestCase):
z_direction = direction[2] z_direction = direction[2]
self.assertAlmostEqual(z_at_solstice, z_direction, 4) self.assertAlmostEqual(z_at_solstice, z_direction, 4)
finally:
d.stop()
def test_identical_location(self): def test_identical_location(self):
# # create a frame tied to the reference position # # create a frame tied to the reference position
...@@ -98,10 +98,9 @@ class TestDelays(base.TestCase): ...@@ -98,10 +98,9 @@ class TestDelays(base.TestCase):
] # CS002LBA, in ITRF2005 epoch 2012.5 ] # CS002LBA, in ITRF2005 epoch 2012.5
d = Delays(reference_itrf) d = Delays(reference_itrf)
try:
# set the antenna position identical to the reference position # set the antenna position identical to the reference position
antenna_itrf = [ relative_antenna_itrf = [[0.0, 0.0, 0.0]]
[reference_itrf[0], reference_itrf[1], reference_itrf[2]]
] # CS001LBA, in ITRF2005 epoch 2012.5
# # set the timestamp to solve for # # set the timestamp to solve for
timestamp = datetime.datetime(2000, 1, 1, 0, 0, 0) timestamp = datetime.datetime(2000, 1, 1, 0, 0, 0)
...@@ -113,9 +112,11 @@ class TestDelays(base.TestCase): ...@@ -113,9 +112,11 @@ class TestDelays(base.TestCase):
direction = "J2000", "0rad", "0rad" direction = "J2000", "0rad", "0rad"
# calculate the delays based on the set reference position, the set time and now the set direction and antenna positions. # calculate the delays based on the set reference position, the set time and now the set direction and antenna positions.
delays = d.delays(direction, antenna_itrf) delays = d.delays(direction, relative_antenna_itrf)
self.assertListEqual(delays.tolist(), [0.0], msg=f"delays = {delays}") self.assertListEqual(delays.tolist(), [0.0], msg=f"delays = {delays}")
finally:
d.stop()
def test_regression(self): def test_regression(self):
reference_itrf = [ reference_itrf = [
...@@ -125,10 +126,13 @@ class TestDelays(base.TestCase): ...@@ -125,10 +126,13 @@ class TestDelays(base.TestCase):
] # CS002LBA, in ITRF2005 epoch 2012.5 ] # CS002LBA, in ITRF2005 epoch 2012.5
d = Delays(reference_itrf) d = Delays(reference_itrf)
try:
# set the antenna position identical to the reference position # set the antenna position identical to the reference position
antenna_itrf = [ antenna_itrf = numpy.array(
[3826923.503, 460915.488, 5064643.517] [[3826923.503, 460915.488, 5064643.517]]
] # CS001LBA, in ITRF2005 epoch 2012.5 ) # CS001LBA, in ITRF2005 epoch 2012.5
relative_antenna_itrf = antenna_itrf - reference_itrf
# # set the timestamp to solve for # # set the timestamp to solve for
timestamp = datetime.datetime(2000, 1, 1, 0, 0, 0) timestamp = datetime.datetime(2000, 1, 1, 0, 0, 0)
...@@ -138,10 +142,12 @@ class TestDelays(base.TestCase): ...@@ -138,10 +142,12 @@ class TestDelays(base.TestCase):
direction = "J2000", "0rad", "1.570796rad" direction = "J2000", "0rad", "1.570796rad"
# calculate the delays based on the set reference position, the set time and now the set direction and antenna positions. # calculate the delays based on the set reference position, the set time and now the set direction and antenna positions.
delays = d.delays(direction, antenna_itrf) delays = d.delays(direction, relative_antenna_itrf)
# check for regression # check for regression
self.assertAlmostEqual(-8.31467564781444e-07, delays[0], delta=1.0e-10) self.assertAlmostEqual(-8.31467564781444e-07, delays[0], delta=1.0e-10)
finally:
d.stop()
def test_light_second_delay(self): def test_light_second_delay(self):
""" """
...@@ -151,11 +157,12 @@ class TestDelays(base.TestCase): ...@@ -151,11 +157,12 @@ class TestDelays(base.TestCase):
reference_itrf = [0, 0, 0] reference_itrf = [0, 0, 0]
d = Delays(reference_itrf) d = Delays(reference_itrf)
try:
# set the antenna position 0.1 lightsecond in the Z direction of the ITRF, # set the antenna position 0.1 lightsecond in the Z direction of the ITRF,
# which is aligned with the North Pole, see # which is aligned with the North Pole, see
# https://en.wikipedia.org/wiki/Earth-centered,_Earth-fixed_coordinate_system#Structure # https://en.wikipedia.org/wiki/Earth-centered,_Earth-fixed_coordinate_system#Structure
speed_of_light = 299792458.0 speed_of_light = 299792458.0
antenna_itrf = [[0, 0, 0.1 * speed_of_light]] relative_antenna_itrf = [[0, 0, 0.1 * speed_of_light]]
# We need to point along the same direction in order to have the delay reflect the distance. # We need to point along the same direction in order to have the delay reflect the distance.
# #
...@@ -168,18 +175,27 @@ class TestDelays(base.TestCase): ...@@ -168,18 +175,27 @@ class TestDelays(base.TestCase):
direction = "J2000", "0rad", "1.570796rad" direction = "J2000", "0rad", "1.570796rad"
# calculate the delays based on the set reference position, the set time and now the set direction and antenna positions. # calculate the delays based on the set reference position, the set time and now the set direction and antenna positions.
delays = d.delays(direction, antenna_itrf) delays = d.delays(direction, relative_antenna_itrf)
self.assertAlmostEqual(0.1, delays[0], 6, f"delays[0] = {delays[0]}") self.assertAlmostEqual(0.1, delays[0], 6, f"delays[0] = {delays[0]}")
finally:
d.stop()
class TestDelaysBulk(base.TestCase): class TestDelaysBulk(base.TestCase):
def setUp(self): @staticmethod
self.d = Delays([0, 0, 0]) def makeDelays():
d = Delays([0, 0, 0])
timestamp = datetime.datetime( timestamp = datetime.datetime(
2022, 3, 1, 0, 0, 0 2022, 3, 1, 0, 0, 0
) # timestamp does not actually matter, but casacore doesn't know that. ) # timestamp does not actually matter, but casacore doesn't know that.
self.d.set_measure_time(timestamp) d.set_measure_time(timestamp)
return d
def setUp(self):
self.d = self.makeDelays()
self.addCleanup(self.d.stop)
# generate different positions and directions # generate different positions and directions
self.positions = numpy.array([[i, 2, 3] for i in range(MAX_ANTENNA)]) self.positions = numpy.array([[i, 2, 3] for i in range(MAX_ANTENNA)])
...@@ -223,12 +239,19 @@ class TestDelaysBulk(base.TestCase): ...@@ -223,12 +239,19 @@ class TestDelaysBulk(base.TestCase):
duration_results_ms: list[float] = [] duration_results_ms: list[float] = []
def run_delays_bulk(): def run_delays_bulk():
# make sure each thread runs its own instance to allow parallellism
# between instances
d = self.makeDelays()
try:
before = time.monotonic_ns() before = time.monotonic_ns()
for _ in range(count): for _ in range(count):
_ = self.d.delays_bulk(self.directions, self.positions) _ = d.delays_bulk(self.directions, self.positions)
after = time.monotonic_ns() after = time.monotonic_ns()
duration_results_ms.append((after - before) / count / 1e6) duration_results_ms.append((after - before) / count / 1e6)
finally:
d.stop()
# measure single-threaded performance first # measure single-threaded performance first
for _ in range(nr_threads): for _ in range(nr_threads):
...@@ -250,25 +273,9 @@ class TestDelaysBulk(base.TestCase): ...@@ -250,25 +273,9 @@ class TestDelaysBulk(base.TestCase):
) )
# as we have a real time system and sufficient cores, we care # as we have a real time system and sufficient cores, we care
# most about the worst case for a thread. We assume the worst # most about the worst case for a thread. We test for full
# case thread is slowed down by all the work of the other threads. # parallellism, plus some tolerance.
self.assertLess( self.assertLess(
max(duration_results_ms) / nr_threads, max(duration_results_ms),
single_thread_execution_time * 1.50, # 50% tolerance single_thread_execution_time * 1.50, # 50% tolerance
) )
# report the performance if we remove the compute lock around casacore, to detect
# on which systems it matters.
with mock.patch.object(delays, "compute_lock") as m_lock:
# compare against multi-threaded performance
duration_results_ms = []
threads = [
threading.Thread(target=run_delays_bulk) for _ in range(nr_threads)
]
[t.start() for t in threads]
[t.join() for t in threads]
logging.error(
f"delays bulk multi-threaded averages WITHOUT LOCK are {[int(d) for d in duration_results_ms]} ms per call to convert 488 directions for 96 antennas {count} times using {nr_threads} threads."
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment