#!/usr/bin/env python
#
# Copyright (C) 2016
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id$
"""
Daemon that starts/stops pipelines based on their status in OTDB.

The execution chains are as follows:

-----------------------------
  Starting a pipeline
-----------------------------

[SCHEDULED]          -> PipelineControl schedules

                           runPipeline.sh <obsid>

                        and

                           setOTDBTreeStatus -o <obsid> -s aborted

                        using two SLURM jobs, guaranteeing that the abort trigger
                        (setOTDBTreeStatus) is called in the following circumstances:

                          - runPipeline.sh wrapper cannot finish (bash bugs, etc)
                          - runPipeline.sh job is cancelled in the queue

                        State is set to [QUEUED].

(runPipeline.sh)     -> Calls
                          - state <- [ACTIVE]
                          - getParset
                          - (run pipeline)
                          - success:
                              - state <- [COMPLETING]
                              - (wrap up)
                              - state <- [FINISHED]
                          - failure:
                              - state <- [ABORTED]

(setOTDBTreeStatus)  -> Calls
                          - state <- [ABORTED]

-----------------------------
  Stopping a pipeline
-----------------------------

[ABORTED]            -> Cancels SLURM job associated with pipeline, causing
                        a cascade of job terminations of successor pipelines.
"""

from lofar.messaging import FromBus, ToBus, RPC, EventMessage
from lofar.parameterset import PyParameterValue
from lofar.sas.otdb.OTDBBusListener import OTDBBusListener
from lofar.sas.otdb.config import DEFAULT_OTDB_NOTIFICATION_BUSNAME, DEFAULT_OTDB_SERVICE_BUSNAME
from lofar.sas.otdb.otdbrpc import OTDBRPC
from lofar.common.util import waitForInterrupt
from lofar.messaging.RPC import RPCTimeoutException, RPCException
from lofar.sas.resourceassignment.resourceassignmentservice.rpc import RARPC
from lofar.sas.resourceassignment.resourceassignmentservice.config import DEFAULT_BUSNAME as DEFAULT_RAS_SERVICE_BUSNAME

import subprocess
import datetime
import pipes
import os
import re
from socket import getfqdn

import logging
logger = logging.getLogger(__name__)

def runCommand(cmdline, input=None):
  logger.info("runCommand starting: %s", cmdline)

  # Start command
  proc = subprocess.Popen(
    cmdline,
    stdin=subprocess.PIPE if input else open(os.devnull),
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    shell=True,
    universal_newlines=True
    )

  # Feed input and wait for termination
  logger.debug("runCommand input: %s", input)
  stdout, _ = proc.communicate(input)
  logger.debug("runCommand output: %s", stdout)

  # Check exit status, bail on error
  if proc.returncode != 0:
    logger.warn("runCommand(%s) had exit status %s with output: %s", cmdline, proc.returncode, stdout)
    raise subprocess.CalledProcessError(proc.returncode, cmdline)

  # Return output
  return stdout.strip()

""" Prefix that is common to all parset keys, depending on the exact source. """
PARSET_PREFIX="ObsSW."

class Parset(dict):
  def predecessors(self):
    """ Extract the list of predecessor obs IDs from the given parset. """

    key = PARSET_PREFIX + "Observation.Scheduler.predecessors"
    strlist = PyParameterValue(str(self[key]), True).getStringVector()

    # Key contains "Lxxxxx" values, we want to have "xxxxx" only
    result = [int(filter(str.isdigit,x)) for x in strlist]

    return result

  def isObservation(self):
    return self[PARSET_PREFIX + "Observation.processType"] == "Observation"

  def isPipeline(self):
    return not self.isObservation()

  def processingCluster(self):
    return self[PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.clusterName"] or "CEP2"

  def processingPartition(self):
    result = self[PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.clusterPartition"] or "cpu"
    if '/' in result:
        logger.error('clusterPartition contains invalid value: %s. Defaulting clusterPartition to \'cpu\'', result)
        return 'cpu'
    return result

  def processingNumberOfCoresPerTask(self):
    result = int(self[PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.numberOfCoresPerTask"] or 2)
    if result != 2:
        logger.warn('Invalid Observation.Cluster.ProcessingCluster.numberOfCoresPerTask: %s, defaulting to %s', result, 2)
    return 2

  def processingNumberOfTasks(self):
    """ Parse the number of nodes to allocate from "Observation.Cluster.ProcessingCluster.numberOfTasks",
        which can have either the format "{number}" or "{min}-{max}". """

    defaultValue = 244
    parsetValue = self[PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.numberOfTasks"].strip()

    # a "{min}-{max}" range allocates the maximum
    result = int(parsetValue.split("-")[-1]) * 10 # ScS expects to schedule 10 jobs/node, and still specifies #nodes

    # apply bound, warning if the parset value was out of range
    if result <= 0 or result > 50*24:
      logger.warn('Invalid Observation.Cluster.ProcessingCluster.numberOfTasks: %s, defaulting to %s', parsetValue, defaultValue)
      result = defaultValue

    return result

  @staticmethod
  def dockerRepository():
    return "nexus.cep4.control.lofar:18080"

  @staticmethod
  def defaultDockerImage():
    return "lofar-pipeline:latest"

  @staticmethod
  def defaultDockerTag():
    return "latest"

  def dockerImage(self):
    # Return the version set in the parset, and fall back to our own version.
    image = self[PARSET_PREFIX + "Observation.ObservationControl.PythonControl.softwareVersion"]

    if not image:
       return self.defaultDockerImage()

    if ":" in image:
       return image

    # Insert our tag by default
    return "%s:%s" % (image, self.defaultDockerTag())

  def otdbId(self):
    return int(self[PARSET_PREFIX + "Observation.otdbID"])

  def description(self):
    return "%s - %s" % (self.get(PARSET_PREFIX + "Observation.Campaign.name", 'unknown'),self.get(PARSET_PREFIX + "Observation.Scheduler.taskName", 'unknown'))
    

class Slurm(object):
  def __init__(self, headnode="head01.cep4.control.lofar"):
    self.headnode = headnode

    # TODO: Derive SLURM partition name
    self.partition = "cpu"

  def _runCommand(self, cmdline, input=None):
    cmdline = "ssh %s %s" % (self.headnode, cmdline)
    return runCommand(cmdline, input)

  def submit(self, jobName, cmdline, sbatch_params=None):
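    """ Submit `cmdline' to SLURM as a batch script named `jobName'.
        Returns the SLURM job id as a string, or None if it could not be
        parsed from the sbatch output. """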
    if sbatch_params is None:
      sbatch_params = []

    script = """#!/bin/bash -v
{cmdline}
""".format(cmdline = cmdline)

    stdout = self._runCommand("sbatch --partition=%s --job-name=%s %s" % (self.partition, jobName, " ".join(sbatch_params)), script)

    # Returns "Submitted batch job 3" -- extract ID
    match = re.search(r"Submitted batch job (\d+)", stdout)
    if not match:
      return None

    return match.group(1)

  def cancel(self, jobName):
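    """ Cancel all queued and running SLURM jobs named `jobName'. """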
    self._runCommand("scancel --jobname %s" % (jobName,))

  def isQueuedOrRunning(self, jobName):
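    """ Return whether sacct lists a job named `jobName' in a pending,
        running, or otherwise unfinished state. """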
    stdout = self._runCommand("sacct --starttime=2016-01-01 --noheader --parsable2 --format=jobid --name=%s --state=PENDING,CONFIGURING,RUNNING,RESIZING,COMPLETING,SUSPENDED" % (jobName,))

    return stdout != ""

class PipelineDependencies(object):
  class TaskNotFoundException(Exception):
    """ Raised when a task cannot be found in the RADB. """
    pass

  def __init__(self, ra_service_busname=DEFAULT_RAS_SERVICE_BUSNAME, otdb_service_busname=DEFAULT_OTDB_SERVICE_BUSNAME):
    self.rarpc = RARPC(busname=ra_service_busname)
    logger.info('PipelineDependencies otdb_service_busname=%s', otdb_service_busname)
    self.otdbrpc = OTDBRPC(busname=otdb_service_busname)

  def open(self):
    self.rarpc.open()
    self.otdbrpc.open()

  def close(self):
    self.rarpc.close()
    self.otdbrpc.close()

  def __enter__(self):
    self.open()
    return self

  def __exit__(self, type, value, tb):
    self.close()

  def getState(self, otdb_id):
    """
      Return the status of a single `otdb_id'.
    """
    return self.otdbrpc.taskGetStatus(otdb_id=otdb_id)

  def getPredecessorStates(self, otdb_id):
    """
      Return a dict of {otdb_id: status} pairs for all the predecessors of `otdb_id'.
    """
    radb_task = self.rarpc.getTask(otdb_id=otdb_id)

    if radb_task is None:
      raise PipelineDependencies.TaskNotFoundException("otdb_id %s not found in RADB" % (otdb_id,))

    predecessor_radb_ids = radb_task['predecessor_ids']
    predecessor_tasks = self.rarpc.getTasks(task_ids=predecessor_radb_ids)

    #get states from otdb in order to prevent race conditions between states in radb/otdb
    predecessor_otdb_ids = [t["otdb_id"] for t in predecessor_tasks]
    predecessor_states = { otdb_id:self.getState(otdb_id) for otdb_id in predecessor_otdb_ids }

    logger.debug("getPredecessorStates(%s) = %s", otdb_id, predecessor_states)

    return predecessor_states

  def getSuccessorIds(self, otdb_id):
    """
      Return a list of all the successors of `otdb_id'.
    """
    radb_task = self.rarpc.getTask(otdb_id=otdb_id)

    if radb_task is None:
      raise PipelineDependencies.TaskNotFoundException("otdb_id %s not found in RADB" % (otdb_id,))

    successor_radb_ids = radb_task['successor_ids']
    successor_tasks = self.rarpc.getTasks(task_ids=successor_radb_ids) if successor_radb_ids else []
    successor_otdb_ids = [t["otdb_id"] for t in successor_tasks]

    logger.debug("getSuccessorIds(%s) = %s", otdb_id, successor_otdb_ids)

    return successor_otdb_ids

  def canStart(self, otdbId):
    """
      Return whether `otdbId' can start, according to the status of the predecessors
      and its own status.
    """

    try:
      myState = self.getState(otdbId)
      predecessorStates = self.getPredecessorStates(otdbId)
    except PipelineDependencies.TaskNotFoundException as e:
      logger.error("canStart(%s): Error obtaining task states, not starting pipeline: %s", otdbId, e)
      return False

    startable = (myState == "scheduled" and all([x == "finished" for x in predecessorStates.values()]))
    logger.info("canStart(%s)? state = %s, predecessors = %s, canStart = %s", otdbId, myState, predecessorStates, startable)
    return startable

class PipelineControl(OTDBBusListener):
  def __init__(self, otdb_notification_busname=DEFAULT_OTDB_NOTIFICATION_BUSNAME, otdb_service_busname=DEFAULT_OTDB_SERVICE_BUSNAME, ra_service_busname=DEFAULT_RAS_SERVICE_BUSNAME, **kwargs):
    super(PipelineControl, self).__init__(busname=otdb_notification_busname, **kwargs)

    logger.info('PipelineControl otdb_service_busname=%s', otdb_service_busname)
    self.otdb_service_busname = otdb_service_busname
    self.otdbrpc = OTDBRPC(busname=otdb_service_busname)
    self.dependencies = PipelineDependencies(ra_service_busname=ra_service_busname, otdb_service_busname=otdb_service_busname)
    self.slurm = Slurm()

  def _setStatus(self, otdb_id, status):
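    """ Set the status of task `otdb_id' in OTDB. """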
    self.otdbrpc.taskSetStatus(otdb_id=otdb_id, new_status=status)

  def _getParset(self, otdbId):
    try:
      return Parset(self.otdbrpc.taskGetSpecification(otdb_id=otdbId)["specification"])
    except RPCException as e:
      # Parset not in OTDB, probably got deleted
      logger.error("Cannot retrieve parset of task %s: %s", otdbId, e)
      return None

  def start_listening(self, **kwargs):
    self.otdbrpc.open()
    self.dependencies.open()

    super(PipelineControl, self).start_listening(**kwargs)

    self._checkScheduledPipelines()

  def stop_listening(self, **kwargs):
    super(PipelineControl, self).stop_listening(**kwargs)

    self.dependencies.close()
    self.otdbrpc.close()

  def _checkScheduledPipelines(self):
    try:
      scheduled_pipelines = self.dependencies.rarpc.getTasks(task_status='scheduled', task_type='pipeline')
      logger.info("Checking %s scheduled pipelines if they can start.", len(scheduled_pipelines))

      for pipeline in scheduled_pipelines:
        logger.info("Checking if scheduled pipeline otdbId=%s can start.", pipeline['otdb_id'])
        try:
            otdbId = pipeline['otdb_id']
            parset = self._getParset(otdbId)
            if not parset or not self._shouldHandle(parset):
              continue

            # Maybe the pipeline can start already
            if self.dependencies.canStart(otdbId):
              self._startPipeline(otdbId, parset)
            else:
              logger.info("Job %s was set to scheduled, but cannot start yet.", otdbId)
        except Exception as e:
          logger.error(e)
    except Exception as e:
      logger.error(e)

  @staticmethod
  def _shouldHandle(parset):
    try:
      if not parset.isPipeline():
        logger.info("Not processing tree: is not a pipeline")
        return False

      if parset.processingCluster() == "CEP2":
        logger.info("Not processing tree: is a CEP2 pipeline")
        return False
    except KeyError as e:
      # Parset not complete
      logger.error("Parset incomplete, ignoring: %s", e)
      return False

    return True

  @staticmethod
  def _jobName(otdbId):
    return str(otdbId)

  def _startPipeline(self, otdbId, parset):
    """
      Schedule "docker-runPipeline.sh", which will fetch the parset and run the pipeline within
      a SLURM job.
    """

    # Avoid race conditions by checking whether we have already sent the job
    # to SLURM. Our QUEUED status update may still be being processed.
    if self.slurm.isQueuedOrRunning(otdbId):
      logger.info("Pipeline %s is already queued or running in SLURM.", otdbId)
      return

    logger.info("***** START Otdb ID %s *****", otdbId)

    # Determine SLURM parameters
    sbatch_params = [
                     # Only run job if all nodes are ready
                     "--wait-all-nodes=1",

                     # Enforce the dependencies, instead of creating lingering jobs
                     "--kill-on-invalid-dep=yes",

                     # Maximum run time for job (7 days)
                     "--time=7-0",

                     # Annotate the job
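                     # (quoted twice: the string passes through two shells,
                     # the local `shell=True' one and the one spawned by ssh)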
                     "--comment=%s" % pipes.quote(pipes.quote(parset.description())),

                     # Lower priority to drop below inspection plots
                     "--nice=1000",

                     "--partition=%s" % parset.processingPartition(),
                     "--ntasks=%s" % parset.processingNumberOfTasks(),
                     "--cpus-per-task=%s" % parset.processingNumberOfCoresPerTask(),

                     # Define better places to write the output
                     os.path.expandvars("--output=/data/log/pipeline-%s-%%j.log" % (otdbId,)),
                     ]

    def setStatus_cmdline(status):
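      """ Return a command line that updates the OTDB status of this task,
          executed by ssh'ing from the SLURM job back to this host. """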
      return (
      "ssh {myhostname} '"
        "source {lofarroot}/lofarinit.sh && "
        "setOTDBTreeStatus -o {obsid} -s {status} -B {status_bus}"
        "'"
      .format(
        myhostname = getfqdn(),
        lofarroot = os.environ.get("LOFARROOT", ""),
        obsid = otdbId,
        status = status,
        status_bus = self.otdb_service_busname,
      ))

    try:
        logger.info("Handing over pipeline %s to SLURM", otdbId)

        # Schedule runPipeline.sh
        slurm_job_id = self.slurm.submit(self._jobName(otdbId),
        """
# Run a command, but propagate SIGINT and SIGTERM
function runcmd {{
trap 'kill -s SIGTERM $PID' SIGTERM
trap 'kill -s SIGINT  $PID' SIGINT

"$@" &
PID=$!
wait $PID # returns the exit status of "wait" if interrupted
wait $PID # returns the exit status of $PID
CMDRESULT=$?

trap - SIGTERM SIGINT

return $CMDRESULT
}}

# print some info
echo Running on $SLURM_NODELIST

# notify OTDB that we're running
runcmd {setStatus_active}

# notify ganglia
wget -O - -q "http://ganglia.control.lofar/ganglia/api/events.php?action=add&start_time=now&summary=Pipeline {obsid} ACTIVE&host_regex="

# run the pipeline
runcmd docker-run-slurm.sh --rm --net=host \
    -e LOFARENV={lofarenv} \
    -v $HOME/.ssh:$HOME/.ssh:ro \
    -e SLURM_JOB_ID=$SLURM_JOB_ID \
    -v /data:/data \
    {image} \
runPipeline.sh -o {obsid} -c /opt/lofar/share/pipeline/pipeline.cfg.{cluster} -P {parset_dir}
RESULT=$?

# notify that we're tearing down
runcmd {setStatus_completing}

if [ $RESULT -eq 0 ]; then
    # wait for MoM to pick up feedback before we set finished status
    runcmd sleep 60

    # if we reached this point, the pipeline ran successfully
    runcmd {setStatus_finished}

    # notify ganglia
    wget -O - -q "http://ganglia.control.lofar/ganglia/api/events.php?action=add&start_time=now&summary=Pipeline {obsid} FINISHED&host_regex="
fi

# report status back to SLURM
echo "Pipeline exited with status $RESULT"
exit $RESULT
    """.format(
            lofarenv = os.environ.get("LOFARENV", ""),
            obsid = otdbId,
            parset_dir = "/data/parsets",
            repository = parset.dockerRepository(),
            image = parset.dockerImage(),
            cluster = parset.processingCluster(),

            setStatus_active = setStatus_cmdline("active"),
            setStatus_completing = setStatus_cmdline("completing"),
            setStatus_finished = setStatus_cmdline("finished"),
        ),

        sbatch_params=sbatch_params
        )
        logger.info("Scheduled SLURM job %s for otdb_id=%s", slurm_job_id, otdbId)

        # Schedule the abort trigger, which fires if the pipeline job fails or is cancelled
        logger.info("Scheduling SLURM job for the abort trigger")
        slurm_cancel_job_id = self.slurm.submit("%s-abort-trigger" % self._jobName(otdbId),
    """
# notify OTDB
{setStatus_aborted}

# notify ganglia
wget -O - -q "http://ganglia.control.lofar/ganglia/api/events.php?action=add&start_time=now&summary=Pipeline {obsid} ABORTED&host_regex="
    """
        .format(
            setStatus_aborted = setStatus_cmdline("aborted"),
            obsid = otdbId,
        ),

        sbatch_params=[
            "--partition=%s" % parset.processingPartition(),
            "--cpus-per-task=1",
            "--ntasks=1",
            "--dependency=afternotok:%s" % slurm_job_id,
            "--kill-on-invalid-dep=yes",
            "--requeue",
            "--output=/data/log/abort-trigger-%s.log" % (otdbId,),
        ]
        )
        logger.info("Scheduled SLURM job %s for abort trigger for otdb_id=%s", slurm_cancel_job_id, otdbId)

        logger.info("Handed over pipeline %s to SLURM, setting status to QUEUED", otdbId)
        self._setStatus(otdbId, "queued")
    except Exception as e:
        logger.error(str(e))
        self._setStatus(otdbId, "aborted")

  def _stopPipeline(self, otdbId):
    # Cancel the corresponding SLURM job, but cancel the abort-trigger first,
    # so that it does not fire (and set the status to ABORTED) as a side
    # effect of cancelling the pipeline job itself.

    if not self.slurm.isQueuedOrRunning(otdbId):
      logger.info("_stopPipeline: Job %s not running", otdbId)
      return

    def cancel(jobName):
        logger.info("Cancelling job %s", jobName)
        self.slurm.cancel(jobName)

    jobName = self._jobName(otdbId)
    cancel("%s-abort-trigger" % jobName)
    cancel(jobName)

  def _startSuccessors(self, otdbId):
    try:
      successor_ids = self.dependencies.getSuccessorIds(otdbId)
    except PipelineDependencies.TaskNotFoundException as e:
      logger.error("_startSuccessors(%s): Error obtaining task successors, not starting them: %s", otdbId, e)
      return

    for s in successor_ids:
      parset = self._getParset(s)
      if not parset or not self._shouldHandle(parset):
        continue

      if self.dependencies.canStart(s):
        self._startPipeline(s, parset)
      else:
        logger.info("Job %s still cannot start yet.", otdbId)

  def onObservationScheduled(self, otdbId, modificationTime):
    parset = self._getParset(otdbId)
    if not parset or not self._shouldHandle(parset):
      return

    # Maybe the pipeline can start already
    if self.dependencies.canStart(otdbId):
      self._startPipeline(otdbId, parset)
    else:
      logger.info("Job %s was set to scheduled, but cannot start yet.", otdbId)

  def onObservationFinished(self, otdbId, modificationTime):
    """ Check if any successors can now start. """

    logger.info("Considering to start successors of %s", otdbId)
    self._startSuccessors(otdbId)

  def onObservationAborted(self, otdbId, modificationTime):
    parset = self._getParset(otdbId)
    if parset and not self._shouldHandle(parset): # stop jobs even if there's no parset
      return

    logger.info("***** STOP Otdb ID %s *****", otdbId)
    self._stopPipeline(otdbId)

  """
    More statusses we want to abort on.
  """
  onObservationDescribed    = onObservationAborted
  onObservationApproved     = onObservationAborted
  onObservationPrescheduled = onObservationAborted
  onObservationConflict     = onObservationAborted
  onObservationHold         = onObservationAborted