#!/usr/bin/env python3

# Copyright (C) 2012-2015  ASTRON (Netherlands Institute for Radio Astronomy)
# P.O. Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.


import os
from optparse import OptionParser, OptionGroup
import logging
logger = logging.getLogger(__name__)

from lofar.sas.tmss.client.tmssbuslistener import *
from lofar.sas.tmss.client.tmss_http_rest_client import TMSSsession
from subprocess import call
from lofar.common.cep4_utils import *
from lofar.common.subprocess_utils import check_output_returning_strings

class TMSSCopyServiceEventMessageHandler(TMSSEventMessageHandler):
    '''
    Event message handler which queues and runs TMSS subtasks of type 'copy'.

    A 'scheduled' copy subtask is moved to 'queued'. A 'queued' copy subtask is started
    as soon as the destination reports more free disk space than the summed size of the
    subtask's input dataproducts. Queued subtasks which could not start are re-checked
    periodically (at most once per 60 seconds).
    '''

    # host names which refer to the machine this service runs on, and thus need no ssh wrapping
    _LOCAL_HOSTS = ('localhost', '127.0.0.1')

    def __init__(self, tmss_client_credentials_id: str="TMSSClient"):
        '''
        :param tmss_client_credentials_id: name of the credentials used to create the TMSS REST client.
        '''
        super().__init__()
        self.tmss_client = TMSSsession.create_from_dbcreds_for_ldap(tmss_client_credentials_id)
        # datetime.min guarantees that the very first periodic check is executed immediately
        self._last_df_check_timestamp = datetime.min

    def start_handling(self):
        '''Open the TMSS client session, then start handling messages.'''
        self.tmss_client.open()
        super().start_handling()

    def stop_handling(self):
        '''Stop handling messages, then close the TMSS client session.'''
        super().stop_handling()
        # NOTE(review): assumes TMSSsession offers close() as counterpart of open() -- confirm.
        self.tmss_client.close()

    def before_receive_message(self):
        '''
        Periodic hook: re-check the queued copy subtasks, rate limited to once per 60 seconds.

        NOTE(review): assumes the base handler calls before_receive_message() on each pass of
        its (1-sec) receive loop -- confirm against TMSSEventMessageHandler.
        '''
        super().before_receive_message()

        # use 1-sec event loop to poll queued subtasks (rate limited at 60sec)
        if datetime.utcnow() - self._last_df_check_timestamp > timedelta(seconds=60):
            self._last_df_check_timestamp = datetime.utcnow()
            self.check_and_run_queued_copy_subtask_if_enough_disk_space()

    def onSubTaskStatusChanged(self, id: int, status:str):
        '''
        Queue a copy subtask which became 'scheduled', or try to run one which became 'queued'.

        :param id: the id of the subtask whose status changed.
        :param status: the new status of the subtask.
        '''
        if status in ('scheduled', 'queued'):
            subtask = self.tmss_client.get_subtask(id)
            if subtask['subtask_type'] == 'copy':
                if status == 'scheduled':
                    self.queue_copy_subtask(subtask)
                elif status == 'queued':
                    self.run_copy_subtask_if_enough_disk_space(subtask)
            else:
                logger.info("skipping subtask id=%s status=%s type=%s", subtask['id'], subtask['state_value'], subtask['subtask_type'])

    def queue_copy_subtask(self, subtask):
        '''
        Move a 'scheduled' copy subtask via 'queueing' to 'queued'.
        Subtasks of another type or in another state are ignored.

        :param subtask: the subtask as dict, as returned by the TMSS client.
        '''
        if subtask['subtask_type'] != 'copy':
            return
        if subtask['state_value'] != 'scheduled':
            return
        self.tmss_client.set_subtask_status(subtask['id'], 'queueing')
        self.tmss_client.set_subtask_status(subtask['id'], 'queued')

    def check_and_run_queued_copy_subtask_if_enough_disk_space(self):
        '''Fetch all queued copy subtasks and start each one for which enough disk space is available.'''
        subtasks = self.tmss_client.get_subtasks(subtask_type='copy', state='queued')
        for subtask in subtasks:
            self.run_copy_subtask_if_enough_disk_space(subtask)

    @classmethod
    def _split_host_and_path(cls, location: str):
        '''
        Split a '[host:]path' location into a (host, path) tuple.
        The host is '' when absent, or when it refers to the local machine,
        so callers can skip unneeded ssh wrapping.
        '''
        host, separator, path = location.partition(':')
        if not separator:
            return '', location
        if host in cls._LOCAL_HOSTS:
            host = ''
        return host, path

    def _get_free_disk_space_in_bytes(self, subtask, destination: str) -> int:
        '''
        Run df (via ssh for a remote destination host) on the root dir of the destination
        and return the available disk space in bytes.
        Raises when the command fails or its output cannot be parsed.
        '''
        dst_host, dst_path = self._split_host_and_path(destination)
        dst_root_dir = '/' + dst_path.split('/')[1]

        # -P: one line per filesystem, -k: sizes in units of 1024 bytes
        df_cmd = ['df', '-P', '-k', dst_root_dir]
        if dst_host:
            df_cmd = wrap_command_in_ssh_call(df_cmd, dst_host)

        logger.info("checking free disk space for copy-subtask id=%s, executing: %s", subtask['id'], ' '.join(df_cmd))

        # the last line holds the filesystem figures; its 4th column is the available space
        df_result = check_output_returning_strings(df_cmd)
        df_result_line_parts = df_result.splitlines()[-1].split()
        return int(df_result_line_parts[3]) * 1024

    def run_copy_subtask_if_enough_disk_space(self, subtask):
        '''
        Run the given queued copy subtask if the destination has more free disk space
        than the total size of the subtask's input dataproducts.

        When the free disk space cannot be determined, the subtask is run anyway.
        Subtasks of another type or in another state are ignored.

        :param subtask: the subtask as dict, as returned by the TMSS client.
        '''
        if subtask['subtask_type'] != 'copy':
            return
        if subtask['state_value'] != 'queued':
            return

        # only the disk space check is guarded here; run_copy_subtask does its own error handling
        try:
            destination = subtask['specifications_doc']['destination']
            df_bytes = self._get_free_disk_space_in_bytes(subtask, destination)
            input_dp_sizes = self.tmss_client.get_url_as_json_object(subtask['url'].rstrip('/')+'/input_dataproducts?fields=size')
            # treat dataproducts with an unknown (None) size as empty
            total_size = sum((x['size'] or 0) for x in input_dp_sizes)
        except Exception as e:
            logger.exception(str(e))

            # try to run it anyway, maybe it fails on not enough disk space.
            # if it fails while running, then it results in an error status, and the user can take appropriate action.
            self.run_copy_subtask(subtask)
            return

        if df_bytes > total_size:
            logger.info("enough free disk space available for copy-subtask id=%s destination=%s df=%d needed=%d", subtask['id'], destination, df_bytes, total_size)
            self.run_copy_subtask(subtask)
        else:
            # leave the subtask queued; the periodic check retries it later
            msg = "not enough free disk space available to start copy-subtask id=%s df=%d needed=%d" % (subtask['id'], df_bytes, total_size)
            logger.warning(msg)

    def run_copy_subtask(self, subtask):
        '''
        Copy all input dataproducts of the given queued copy subtask to the destination.

        The subtask is moved through 'starting' and 'started', and ends in 'finished' when
        all dataproducts were copied, or in 'error' (with an error_reason) upon the first failure.
        Subtasks of another type or in another state are ignored.

        :param subtask: the subtask as dict, as returned by the TMSS client.
        '''
        if subtask['subtask_type'] != 'copy':
            return
        if subtask['state_value'] != 'queued':
            return

        try:
            self.tmss_client.set_subtask_status(subtask['id'], 'starting')
            self.tmss_client.set_subtask_status(subtask['id'], 'started')

            specifications_doc = subtask['specifications_doc']

            # cache to reduce rest-calls
            # maps producer url to dict with producing subtask id and cluster name
            _cache = {}
            # cache to reduce (ssh) mkdir calls. Only create parent dirs once.
            _created_dir_cache = set()

            # ToDo: maybe parallelize this? Are multiple parallel rsync's faster?
            for input_dataproduct in self.tmss_client.get_subtask_input_dataproducts(subtask['id']):
                filename = input_dataproduct['filename']

                # fetch producing subtask id and cluster name for cache if needed
                if input_dataproduct['producer'] not in _cache:
                    producer = self.tmss_client.get_url_as_json_object(input_dataproduct['producer'])
                    filesystem = self.tmss_client.get_url_as_json_object(producer['filesystem'])
                    cluster = self.tmss_client.get_url_as_json_object(filesystem['cluster'])
                    _cache[input_dataproduct['producer']] = {'producing_subtask_id': producer['subtask_id'], 'cluster_name': cluster['name'] }
                cached_producer = _cache[input_dataproduct['producer']]

                if specifications_doc.get('managed_output', False):
                    output_dataproduct = self.tmss_client.get_subtask_transformed_output_dataproduct(subtask['id'], input_dataproduct['id'])
                    output_dp_path = output_dataproduct['filepath']
                else:
                    # optionally group the copied files in an 'L<producing_subtask_id>' sub directory
                    group_dir = ('L' + str(cached_producer['producing_subtask_id'])) if specifications_doc.get('group_by_id', True) else ''
                    output_dp_path = os.path.join(specifications_doc['destination'].rstrip('/'), group_dir, filename)

                # split in host & parent_dir_path. A localhost host is dropped to prevent unneeded ssh calls.
                dst_host, dst_parent_dir_path = self._split_host_and_path(output_dp_path)
                # strip only the trailing filename, so equal substrings in the directory part are untouched
                if filename and dst_parent_dir_path.endswith(filename):
                    dst_parent_dir_path = dst_parent_dir_path[:-len(filename)]

                if (dst_host, dst_parent_dir_path) not in _created_dir_cache:
                    # create dst_parent_dir_path directories if needed
                    mkdir_cmd = ['mkdir', '-p', dst_parent_dir_path]
                    if dst_host:
                        mkdir_cmd = wrap_command_in_ssh_call(mkdir_cmd, dst_host)

                    logger.info("creating parent destination dir if needed for copy-subtask id=%s, executing: %s",subtask['id'], ' '.join(mkdir_cmd))

                    if call(mkdir_cmd) == 0:
                        _created_dir_cache.add((dst_host, dst_parent_dir_path))
                    else:
                        msg = "could not create parent destination dir '%s' for copy-subtask id=%s" % (dst_parent_dir_path,  subtask['id'])
                        logger.error(msg)
                        self.tmss_client.set_subtask_status(subtask['id'], 'error', error_reason=msg)
                        return

                # prepare the actual copy command
                # NOTE(review): the rsync target does not include dst_host, so for a remote destination
                # host the data is written to this path on the machine which executes rsync -- confirm intended.
                cmd = ['rsync', '-a', input_dataproduct['filepath'].rstrip('/'), dst_parent_dir_path]

                # wrap in cep4 ssh call if cep4
                if cached_producer['cluster_name'].lower() == 'cep4':
                    cmd = wrap_command_in_cep4_available_node_with_lowest_load_ssh_call(cmd, via_head=True)

                logger.info("copying dataproduct id=%s for copy-subtask id=%s, executing: %s", input_dataproduct['id'], subtask['id'],  ' '.join(cmd))

                if call(cmd) == 0:
                    logger.info("copied dataproduct id=%s for copy-subtask id=%s to '%s'", input_dataproduct['id'], subtask['id'], dst_parent_dir_path)
                else:
                    msg = "could not copy dataproduct id=%s for copy-subtask id=%s to '%s'" % (input_dataproduct['id'], subtask['id'], dst_parent_dir_path)
                    logger.error(msg)
                    self.tmss_client.set_subtask_status(subtask['id'], 'error', error_reason=msg)
                    return

            self.tmss_client.set_subtask_status(subtask['id'], 'finishing')
            self.tmss_client.set_subtask_status(subtask['id'], 'finished')
        except Exception as e:
            logger.exception(str(e))
            self.tmss_client.set_subtask_status(subtask['id'], 'error', error_reason=str(e))


def create_copy_service(exchange: str=DEFAULT_BUSNAME, broker: str=DEFAULT_BROKER, tmss_client_credentials_id: str="TMSSClient"):
    '''
    Create a TMSSBusListener which handles TMSS events with a TMSSCopyServiceEventMessageHandler.

    :param exchange: passed to the TMSSBusListener as its exchange.
    :param broker: passed to the TMSSBusListener as its broker.
    :param tmss_client_credentials_id: credentials name handed to the handler for its TMSS client.
    :return: the (not yet started) TMSSBusListener.
    '''
    handler_kwargs = {'tmss_client_credentials_id': tmss_client_credentials_id}
    listener = TMSSBusListener(handler_type=TMSSCopyServiceEventMessageHandler,
                               handler_kwargs=handler_kwargs,
                               exchange=exchange,
                               broker=broker)
    return listener


def main():
    '''Parse the command line options and run the copy service until interrupted.'''
    from lofar.common.util import waitForInterrupt

    # make sure we run in UTC timezone
    os.environ['TZ'] = 'UTC'
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)

    # Check the invocation arguments
    parser = OptionParser('%prog [options]',
                          description='run the tmss_copy_service which runs the copy-pipeline for scheduled copy-subtasks')

    messaging_group = OptionGroup(parser, 'Messaging options')
    messaging_group.add_option('-b', '--broker', dest='broker', type='string', default=DEFAULT_BROKER,
                               help='Address of the message broker, default: %default')
    messaging_group.add_option('-e', "--exchange", dest="exchange", type="string", default=DEFAULT_BUSNAME,
                               help="exchange where the TMSS event messages are published. [default: %default]")
    parser.add_option_group(messaging_group)

    django_group = OptionGroup(parser, 'Django options')
    django_group.add_option('-R', '--tmss_client_credentials_id', dest='tmss_client_credentials_id', type='string', default='TMSSClient',
                            help='TMSS django REST API credentials name, default: %default')
    parser.add_option_group(django_group)

    options, _ = parser.parse_args()

    # check TMSS is up and running via the client
    TMSSsession.check_connection_and_exit_on_error(options.tmss_client_credentials_id)

    # the service handles events in the background for as long as the context is open
    service = create_copy_service(options.exchange, options.broker, options.tmss_client_credentials_id)
    with service:
        waitForInterrupt()

# run the service only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()