Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
message_handler.py 12.10 KiB
#!/usr/bin/env python3
#
# Copyright (C) 2023
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id$

from lofar.sas.tmss.client.tmss_http_rest_client import TMSSsession
from lofar.sas.tmss.client.tmssbuslistener import TMSSEventMessageHandler
from lofar.common.datetimeutils import round_to_second_precision
from lofar.sas.tmss.services.lobster.observation_pool import ObservationPool
from lofar.sas.tmss.services.lobster.specification_translation import extract_stations

from .config import (
    COBALT_HEADNODE,
    COBALT_PARSET_DIR_TEMPLATE,
    COBALT_PARSET_OVERRIDES,
    COBALT_STARTBGL_SCRIPT_TEMPLATE,
    COBALT_PARSET_FILENAME_PATTERN,
)

from copy import copy
from datetime import datetime, timedelta
from dateutil import parser
from os import system
from tempfile import TemporaryDirectory

import logging


logger = logging.getLogger(__name__)

# Scheduled observations are only enqueued when their scheduled start time lies
# at most this far in the future; it is also the window used when polling TMSS.
LOOK_AHEAD_WINDOW = timedelta(minutes=3)
# Minimum time between two consecutive polls of TMSS for upcoming scheduled observations.
POLL_INTERVAL     = timedelta(seconds=15)


def _system(cmdline: str):
    """Wraps os.system to raise a RuntimeError on failure."""

    exit_status = system(cmdline)
    if exit_status != 0:
        raise RuntimeError(f"Execution failed, got exit status {exit_status} when executing: {cmdline}")


class L2TMSSObservationControlMessageHandler(TMSSEventMessageHandler):
    """Handle TMSS subtask status events for observations that use LOFAR2 stations.

    'scheduled' observation subtasks are enqueued on the stations (via the
    observation pool) and on COBALT, taking the subtask through the statuses
    'queueing' -> 'queued'. 'queued' observation subtasks are started on the
    stations, taking the subtask through 'starting' -> 'started'. Any exception
    along the way is logged and the subtask is set to 'error'.
    """

    def __init__(
        self, observation_pool: ObservationPool, tmss_client_credentials_id: str = None
    ):
        """Create the handler.

        Args:
            observation_pool: pool in which the multi-station observations are created and looked up.
            tmss_client_credentials_id: id of the credentials used to create the TMSS session.
        """
        super().__init__()
        self.tmss_client = TMSSsession.create_from_dbcreds_for_ldap(tmss_client_credentials_id)
        # datetime.min guarantees that the first call to before_receive_message polls right away.
        self._last_poll_timestamp = datetime.min
        self.observation_pool = observation_pool

    def start_handling(self):
        """Upon startup, connect to TMSS and handle all (near)future scheduled observations."""
        self.tmss_client.open()
        self.check_upcoming_scheduled_observations()

    def stop_handling(self):
        """Upon shutdown, close the TMSS connection."""
        self.tmss_client.close()

    def onSubTaskStatusChanged(self, id: int, status: str):
        """Handle TMSS subtask status changes.

        Only the 'scheduled' and 'queued' statuses are acted upon; the subtask
        is fetched from TMSS only for those statuses.

        Args:
            id: id of the subtask whose status changed.
            status: the new status of the subtask.
        """
        logger.debug("subtask id=%s status changed to %s", id, status)

        if status == "scheduled":
            self.enqueue_scheduled_observation_subtask(self.tmss_client.get_subtask(id))
        elif status == "queued":
            self.start_queued_observation_subtask(self.tmss_client.get_subtask(id))

    def before_receive_message(self):
        """Use this method which is called in the event handler message loop to poll at regular intervals.

        In principle this should have no effect, as we are event driven and will automatically enqueue
        observations when the 'scheduled' event message is handled.
        On the other hand, it doesn't hurt either to poll. Low cost. High reliability.
        """
        now = datetime.utcnow()
        if now - self._last_poll_timestamp >= POLL_INTERVAL:
            self._last_poll_timestamp = now
            self.check_upcoming_scheduled_observations()

    def check_upcoming_scheduled_observations(self):
        """Fetch all scheduled observations in the upcoming lookahead window, and try to enqueue them."""

        logger.info("checking for upcoming scheduled observations in TMSS...")
        now = round_to_second_precision(datetime.utcnow())
        scheduled_observation_subtasks = self.tmss_client.get_subtasks(
            state="scheduled",
            subtask_type='observation',
            is_using_lofar2_stations=True,
            scheduled_start_time_greater_then=now,
            scheduled_start_time_less_then=now + LOOK_AHEAD_WINDOW
        )
        for subtask in scheduled_observation_subtasks:
            self.enqueue_scheduled_observation_subtask(subtask)

    def _enqueue_observation_on_stations(self, subtask: dict):
        """Create the multi-station observation for the given subtask in the observation pool.

        Args:
            subtask: the TMSS subtask; only its 'id' is used here.

        Raises:
            RuntimeError: if the station specifications contain no station/antenna field to take
                the observation id from, or if they contain more than one distinct observation id.
        """
        subtask_id = subtask['id']
        l2stationspecs = self.tmss_client.get_subtask_l2stationspecs(subtask_id,
                                                                     retry_count=5)

        # The observation id of the first antenna field is the reference for all others.
        try:
            observation_id = l2stationspecs['stations'][0]['antenna_fields'][0]['observation_id']
        except IndexError as e:
            # Give the operator a meaningful error reason instead of "list index out of range".
            raise RuntimeError(
                f"The station specifications of subtask id={subtask_id} contain no station or antenna field "
                "to determine the observation id from!"
            ) from e

        # To support different observation ids per subtask, the specifications would have to be
        # split per station + antenna field into sets per observation id prior to calling
        # create_multistationobservation. Until then, such specifications are rejected.
        for station in l2stationspecs['stations']:
            for antenna_field in station['antenna_fields']:
                if antenna_field['observation_id'] != observation_id:
                    raise RuntimeError("LOBSTER does not support specifications with multiple different observations "
                                       "ids!")

        stations = extract_stations(l2stationspecs)
        logger.info("subtask id=%s has observation=%s for stations=%s with specs=%s", subtask_id, observation_id,
                    stations, l2stationspecs)
        station_obs = self.observation_pool.create_multistationobservation(subtask_id=subtask_id,
                                                                           observation_id=observation_id,
                                                                           stations=stations,
                                                                           specifications=l2stationspecs)

        # Connection failures are only reported here; they do not abort the enqueueing.
        if not station_obs.all_connected:
            logger.warning("failed to connect to some of stations=%s during preparation!", stations)
            for station, connected in station_obs.is_connected.items():
                if not connected:
                    logger.warning("failed to connect to: %s", station)

    def _enqueue_observation_on_COBALT(self, subtask: dict):
        """Copy the parset of the given subtask to the COBALT head node and kickstart the observation there.

        Args:
            subtask: the TMSS subtask; its 'id' and the optional COBALT release in its
                'specifications_doc' are used.

        Raises:
            RuntimeError: if copying the parset or running the start script fails.
        """
        subtask_id = subtask['id']
        # The COBALT release defaults to 'current' when the specification does not name one.
        cobalt_release = subtask['specifications_doc'].get('COBALT', {}).get('release', 'current')
        cobalt_parset_dir = COBALT_PARSET_DIR_TEMPLATE % {'cobalt_release': cobalt_release}
        cobalt_startbgl_script = COBALT_STARTBGL_SCRIPT_TEMPLATE % {'cobalt_release': cobalt_release}

        # COBALT expects the parset to be on the disk of its head node, and
        # then for a script to kickstart the observation process. COBALT
        # will then spawn its own processes to prepare and start the observation.

        with TemporaryDirectory(prefix=f"tmp-cobalt-{subtask_id}-") as tmpdir:
            # write parset to disk so we can scp it
            parset_filename = COBALT_PARSET_FILENAME_PATTERN.format(subtask_id=subtask_id)
            local_parset_path = f"{tmpdir}/{parset_filename}"
            remote_parset_path = f"{cobalt_parset_dir}/{parset_filename}"

            parset = self.tmss_client.get_subtask_parset(subtask_id, retry_count=5)
            parset += COBALT_PARSET_OVERRIDES
            with open(local_parset_path, "w") as parset_file:
                parset_file.write(parset)

            # NOTE(review): both command lines below are built by string interpolation and are
            # executed through a shell; cobalt_release originates from the subtask specification.

            # copy it to COBALT
            # TODO(Corne): https://support.astron.nl/jira/browse/TMSS-2860
            _system(f"scp -v -o 'StrictHostKeyChecking=no' {local_parset_path} {COBALT_HEADNODE}:{remote_parset_path}")

            # kickstart the observation on COBALT to start and stop at the
            # times as provided in the parset.
            # first 3 parameters are historical and ignored
            # NB: This command returns "immediately", that is, COBALT will start the actual observation in the
            #     background.
            # TODO(Corne): https://support.astron.nl/jira/browse/TMSS-2860
            _system(f"ssh -v -o 'StrictHostKeyChecking=no' {COBALT_HEADNODE} '{cobalt_startbgl_script} 1 2 3 {remote_parset_path} "
                    f"{subtask_id}'")

    def enqueue_scheduled_observation_subtask(self, subtask: dict):
        """Enqueue a scheduled LOFAR2 observation subtask on the stations and on COBALT.

        The subtask is skipped when it is not a scheduled observation, when it does not use
        LOFAR2 stations, or when its scheduled start time is further ahead than LOOK_AHEAD_WINDOW.
        Otherwise its status goes 'queueing' -> 'queued', or to 'error' when anything fails.

        Args:
            subtask: the TMSS subtask to enqueue.
        """
        subtask_id = subtask['id']

        if subtask['subtask_type'] != 'observation' or subtask['state_value'] != 'scheduled':
            logger.debug("skipping %s %s subtask id=%s", subtask['state_value'], subtask['subtask_type'], subtask_id)
            return

        if not subtask['is_using_lofar2_stations']:
            logger.info("skipping %s %s subtask id=%s because it is not using lofar2 stations", subtask['state_value'],
                        subtask['subtask_type'], subtask_id)
            return

        # The timezone is dropped so the naive timestamp can be compared with datetime.utcnow().
        scheduled_start_time = parser.parse(subtask['scheduled_start_time'], ignoretz=True)
        time_to_start = scheduled_start_time - datetime.utcnow()

        if time_to_start > LOOK_AHEAD_WINDOW:
            logger.info("skipping scheduled %s subtask id=%s because the scheduled_start_time='%s' is too far ahead %s",
                        subtask['subtask_type'], subtask_id, scheduled_start_time, time_to_start)
            return

        try:
            logger.info("queueing scheduled %s subtask id=%s which starts in %d[sec] at scheduled_start_time='%s'",
                        subtask['subtask_type'], subtask_id, time_to_start.total_seconds(), scheduled_start_time)

            self.tmss_client.set_subtask_status(subtask_id, "queueing", retry_count=5)

            self._enqueue_observation_on_stations(subtask)
            self._enqueue_observation_on_COBALT(subtask)

            self.tmss_client.set_subtask_status(subtask_id, "queued", retry_count=5)
        except Exception as e:
            logger.exception(e)
            self.tmss_client.set_subtask_status(subtask_id, "error", error_reason=str(e), retry_count=5)

    def start_queued_observation_subtask(self, subtask: dict):
        """Start the queued observation subtask on the station(s).

        The subtask is skipped when it is not a queued observation or when it does not use
        LOFAR2 stations. Otherwise its status goes 'starting' -> 'started'. When starting failed
        on any host, the subtask is subsequently set to 'error' with the collected failures as reason.

        Args:
            subtask: the TMSS subtask to start.
        """
        subtask_id = subtask['id']

        if subtask['subtask_type'] != 'observation' or subtask['state_value'] != 'queued':
            logger.debug("skipping %s %s subtask id=%s", subtask['state_value'], subtask['subtask_type'], subtask_id)
            return

        if not subtask['is_using_lofar2_stations']:
            logger.debug("skipping %s %s subtask id=%s because it is not using lofar2 stations", subtask['state_value'],
                         subtask['subtask_type'], subtask_id)
            return

        try:
            self.tmss_client.set_subtask_status(subtask_id, "starting", retry_count=5)

            # For each observation get the MultiStationObservation and start,
            # ignore errors until done for each observation
            host_errors = {}
            observe_errors = set()
            # Iterate over a copy of the pool's mapping.
            # NOTE(review): presumably to be safe against changes to the pool while starting - confirm.
            for observation_id, multi_obs in copy(self.observation_pool.get_observations(subtask_id)).items():
                results = multi_obs.start()

                # Gather potential errors: a failed host is reported as an Exception instance in the results.
                for host, result in results.items():
                    if isinstance(result, Exception):
                        host_errors[host] = result
                        observe_errors.add(observation_id)

                # TODO(Corne): Schedule task to clean up observation when done

            # The start time is part of the specification, and the station will honour it.
            # But we won't be informed of that so we just go to STARTED now
            self.tmss_client.set_subtask_status(subtask_id, "started", retry_count=5)

            # Report all gathered failures at once. The message is formatted here, because
            # exceptions do not apply %-formatting to their arguments themselves.
            if observe_errors:
                raise RuntimeError(
                    "Failed to start observations=%s on hosts=%s with exceptions=%s"
                    % (observe_errors, list(host_errors.keys()), list(host_errors.values()))
                )

            # NB: COBALT starts itself just before the start time
        except Exception as e:
            logger.exception(e)
            self.tmss_client.set_subtask_status(subtask_id, "error", error_reason=str(e), retry_count=5)