Skip to content
Snippets Groups Projects
PipelineControl.py 17.07 KiB
#!/usr/bin/env python
#
# Copyright (C) 2016
# ASTRON (Netherlands Institute for Radio Astronomy)
# P.O.Box 2, 7990 AA Dwingeloo, The Netherlands
#
# This file is part of the LOFAR software suite.
# The LOFAR software suite is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# The LOFAR software suite is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with the LOFAR software suite. If not, see <http://www.gnu.org/licenses/>.
#
# $Id$
"""
Daemon that starts/stops pipelines based on their status in OTDB.

The execution chains are as follows:

-----------------------------
  Starting a pipeline
-----------------------------

[SCHEDULED]          -> PipelineControl schedules

                           runPipeline.sh <obsid> || setOTDBTreeStatus -o <obsid> -s aborted

                        using two SLURM jobs (the pipeline job, plus an abort-trigger
                        job that depends on it with "afternotok"), guaranteeing that the
                        status is set to aborted in the following circumstances:

                          - runPipeline.sh exits with failure
                          - runPipeline.sh is killed by SLURM
                          - runPipeline.sh job is cancelled in the queue

                        State is set to [QUEUED].

(runPipeline.sh)     -> Calls
                          - state <- [ACTIVE]
                          - getParset
                          - (run pipeline)
                          - state <- [COMPLETING]
                          - (wrap up)
                          - state <- [FINISHED]

(setOTDBTreeStatus)  -> Calls
                          - state <- [ABORTED]

-----------------------------
  Stopping a pipeline
-----------------------------

[ABORTED]            -> Cancels SLURM job associated with pipeline, causing
                        a cascade of job terminations of successor pipelines.
"""

from lofar.messaging import FromBus, ToBus, RPC, EventMessage
from lofar.parameterset import PyParameterValue
from lofar.sas.otdb.OTDBBusListener import OTDBBusListener
from lofar.sas.otdb.config import DEFAULT_OTDB_NOTIFICATION_BUSNAME, DEFAULT_OTDB_SERVICE_BUSNAME
from lofar.sas.otdb.otdbrpc import OTDBRPC
from lofar.common.util import waitForInterrupt
from lofar.messaging.RPC import RPCTimeoutException
from lofar.sas.resourceassignment.resourceassignmentservice.rpc import RARPC
from lofar.sas.resourceassignment.resourceassignmentservice.config import DEFAULT_BUSNAME as DEFAULT_RAS_SERVICE_BUSNAME

import subprocess
import datetime
import os
import re
from socket import getfqdn

import logging
logger = logging.getLogger(__name__)

def runCommand(cmdline, input=None):
  """
    Run `cmdline' through the shell and return its output.

    :param cmdline: command line, executed with shell=True (so it must not
                    contain untrusted text).
    :param input:   optional text fed to the command's stdin. When falsy,
                    stdin is connected to the null device instead.
    :returns:       combined stdout/stderr, stripped of surrounding whitespace.
    :raises subprocess.CalledProcessError: if the command exits with a
                    non-zero status; the exception's `output' holds what the
                    command printed.
  """
  logger.info("runCommand starting: %s", cmdline)

  # Open the null device in a context manager so the handle is always closed
  # (file("/dev/null") leaked one descriptor per call and is Python 2 only).
  with open(os.devnull) as devnull:
    # Start command
    proc = subprocess.Popen(
      cmdline,
      stdin=subprocess.PIPE if input else devnull,
      stdout=subprocess.PIPE,
      stderr=subprocess.STDOUT,  # merge stderr into the returned output
      shell=True,
      universal_newlines=True
      )

    # Feed input and wait for termination
    logger.debug("runCommand input: %s", input)
    stdout, _ = proc.communicate(input)
    logger.debug("runCommand output: %s", stdout)

  # Check exit status, bail on error
  if proc.returncode != 0:
    logger.warning("runCommand(%s) had exit status %s with output: %s", cmdline, proc.returncode, stdout)
    raise subprocess.CalledProcessError(proc.returncode, cmdline, output=stdout)

  # Return output
  return stdout.strip()

# Prefix that is common to all parset keys, depending on the exact source.
PARSET_PREFIX="ObsSW."

class Parset(dict):
  """
    A parset (key -> value dictionary, as returned by OTDB) with accessors for
    the fields PipelineControl needs. All keys are looked up under PARSET_PREFIX.
    Optional fields fall back to a default when missing or empty.
  """

  def _intOrDefault(self, key, default):
    """
      Return the value of PARSET_PREFIX + `key' as an int, or `default' when the
      key is missing, empty, or evaluates to 0.

      :raises ValueError: if the value is present but not an integer.
    """
    return int(self.get(PARSET_PREFIX + key) or 0) or default

  def predecessors(self):
    """ Extract the list of predecessor obs IDs from the given parset. """

    key = PARSET_PREFIX + "Observation.Scheduler.predecessors"
    strlist = PyParameterValue(str(self[key]), True).getStringVector()

    # Key contains "Lxxxxx" values, we want to have "xxxxx" only. The digits are
    # joined explicitly: filter() on a str yields an iterator on Python 3.
    return [int("".join(c for c in x if c.isdigit())) for x in strlist]

  def isObservation(self):
    """ Return whether this parset describes an observation (processType == "Observation"). """
    return self[PARSET_PREFIX + "Observation.processType"] == "Observation"

  def isPipeline(self):
    """ Return whether this parset describes anything other than an observation. """
    return not self.isObservation()

  def processingCluster(self):
    """ Name of the processing cluster; "CEP2" when unset. """
    return self.get(PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.clusterName") or "CEP2"

  def processingPartition(self):
    """ Partition on the processing cluster; "cpu" when unset. """
    return self.get(PARSET_PREFIX + "Observation.Cluster.ProcessingCluster.clusterPartition") or "cpu"

  def processingNumberOfCoresPerTask(self):
    """ Number of cores per task (int); 20 when unset, empty or 0. """
    return self._intOrDefault("Observation.Cluster.ProcessingCluster.numberOfCoresPerTask", 20)

  def processingNumberOfTasks(self):
    """ Number of tasks (int); 24 when unset, empty or 0. """
    return self._intOrDefault("Observation.Cluster.ProcessingCluster.numberOfTasks", 24)

  @staticmethod
  def dockerRepository():
    """ Docker registry from which pipeline images are pulled. """
    return "nexus.cep4.control.lofar:18080"

  @staticmethod
  def defaultDockerImage():
    """ Image name obtained by feeding "lofar-pipeline:${LOFAR_TAG}" to the docker-template command. """
    return runCommand("docker-template", "lofar-pipeline:${LOFAR_TAG}")

  def dockerImage(self):
    """ Return the version set in the parset, and fall back to our own version. """
    return (self.get(PARSET_PREFIX + "Observation.ObservationControl.PythonControl.softwareVersion") or
            self.defaultDockerImage())

  def otdbId(self):
    """ OTDB tree ID of this parset, as an int. """
    return int(self[PARSET_PREFIX + "Observation.otdbID"])

class Slurm(object):
  """
    Thin wrapper around the SLURM command-line tools (sbatch, scancel, sacct),
    which are executed over ssh on a head node.
  """

  def __init__(self, headnode="head01.cep4.control.lofar", partition="cpu"):
    """
      :param headnode:  host on which the SLURM commands are run via ssh.
      :param partition: SLURM partition passed to every sbatch invocation.
    """
    self.headnode = headnode

    # TODO: Derive SLURM partition name
    self.partition = partition

  def _runCommand(self, cmdline, input=None):
    """ Run `cmdline' on the head node through ssh, optionally feeding `input' to stdin. """
    cmdline = "ssh %s %s" % (self.headnode, cmdline)
    return runCommand(cmdline, input)

  @staticmethod
  def _parseJobId(sbatch_output):
    """
      Extract the job ID from sbatch output such as "Submitted batch job 3".

      :returns: the ID as a string, or None if no such line is present.
    """
    match = re.search(r"Submitted batch job (\d+)", sbatch_output)
    return match.group(1) if match else None

  def submit(self, jobName, cmdline, sbatch_params=None):
    """
      Submit a batch job named `jobName' whose script runs `cmdline' in bash.

      :param sbatch_params: extra sbatch options, appended after --partition and
                            --job-name.
      :returns: the SLURM job ID as a string, or None if it could not be parsed
                from sbatch's output.
    """
    if sbatch_params is None:
      sbatch_params = []

    # The job script is passed to sbatch on stdin.
    script = "#!/bin/bash -v\n%s\n" % (cmdline,)

    stdout = self._runCommand("sbatch --partition=%s --job-name=%s %s" % (self.partition, jobName, " ".join(sbatch_params)), script)

    return self._parseJobId(stdout)

  def cancel(self, jobName):
    """ Cancel all jobs named `jobName'. """
    self._runCommand("scancel --jobname %s" % (jobName,))

  def isQueuedOrRunning(self, jobName):
    """ Return whether a job named `jobName' is pending, running or otherwise still active. """
    stdout = self._runCommand("sacct --starttime=2016-01-01 --noheader --parsable2 --format=jobid --name=%s --state=PENDING,CONFIGURING,RUNNING,RESIZING,COMPLETING,SUSPENDED" % (jobName,))

    return stdout != ""

class PipelineDependencies(object):
  """
    Looks up task states and predecessor/successor relations in the RADB
    through the resource assignment service (RARPC).

    Call open()/close(), or use the object as a context manager, around queries.
  """

  class TaskNotFoundException(Exception):
    """ Raised when a task cannot be found in the RADB. """
    pass

  def __init__(self, ra_service_busname=DEFAULT_RAS_SERVICE_BUSNAME):
    """ :param ra_service_busname: bus of the resource assignment service. """
    self.rarpc = RARPC(busname=ra_service_busname)

  def open(self):
    self.rarpc.open()

  def close(self):
    self.rarpc.close()

  def __enter__(self):
    self.open()
    return self

  def __exit__(self, type, value, tb):
    self.close()

  def _getTask(self, otdb_id):
    """
      Return the RADB task for `otdb_id'.

      :raises TaskNotFoundException: if the RADB has no task for `otdb_id'.
    """
    radb_task = self.rarpc.getTask(otdb_id=otdb_id)

    if radb_task is None:
      # Qualified through self: the exception class is nested in this class, so
      # the bare name is not defined at module level.
      raise self.TaskNotFoundException("otdb_id %s not found in RADB" % (otdb_id,))

    return radb_task

  def getState(self, otdb_id):
    """
      Return the status of a single `otdb_id'.

      :raises TaskNotFoundException: if `otdb_id' is not in the RADB.
    """
    return self._getTask(otdb_id)["status"]

  def getPredecessorStates(self, otdb_id):
    """
      Return a dict of {"sasid":"status"} pairs of all the predecessors of `otdb_id'.

      :raises TaskNotFoundException: if `otdb_id' is not in the RADB.
    """
    radb_task = self._getTask(otdb_id)

    predecessor_radb_ids = radb_task['predecessor_ids']
    # Only query the RADB when there is something to ask for.
    predecessor_tasks = self.rarpc.getTasks(task_ids=predecessor_radb_ids) if predecessor_radb_ids else []
    predecessor_states = {t["otdb_id"]: t["status"] for t in predecessor_tasks}

    logger.debug("getPredecessorStates(%s) = %s", otdb_id, predecessor_states)

    return predecessor_states

  def getSuccessorIds(self, otdb_id):
    """
      Return a list of the otdb ids of all the successors of `otdb_id'.

      :raises TaskNotFoundException: if `otdb_id' is not in the RADB.
    """
    radb_task = self._getTask(otdb_id)

    successor_radb_ids = radb_task['successor_ids']
    successor_tasks = self.rarpc.getTasks(task_ids=successor_radb_ids) if successor_radb_ids else []
    successor_otdb_ids = [t["otdb_id"] for t in successor_tasks]

    logger.debug("getSuccessorIds(%s) = %s", otdb_id, successor_otdb_ids)

    return successor_otdb_ids

  def canStart(self, otdbId):
    """
      Return whether `otdbId' can start: its own status must be "scheduled" and
      every predecessor must be "finished". Returns False (and logs an error)
      if the task or its predecessors cannot be found in the RADB.
    """

    try:
      myState = self.getState(otdbId)
      predecessorStates = self.getPredecessorStates(otdbId)
    except self.TaskNotFoundException as e:
      logger.error("canStart(%s): Error obtaining task states, not starting pipeline: %s", otdbId, e)
      return False

    logger.debug("canStart(%s)? state = %s, predecessors = %s", otdbId, myState, predecessorStates)

    return (
      myState == "scheduled" and
      all(x == "finished" for x in predecessorStates.values())
    )

class PipelineControl(OTDBBusListener):
  """
    Reacts to OTDB status-change notifications by submitting SLURM jobs for
    pipelines (on "scheduled", or when a predecessor is "finished") and by
    cancelling them (on "aborted" and the other statuses aliased below).
    Only non-CEP2 pipelines are handled.
  """

  def __init__(self, otdb_notification_busname=DEFAULT_OTDB_NOTIFICATION_BUSNAME, otdb_service_busname=DEFAULT_OTDB_SERVICE_BUSNAME, ra_service_busname=DEFAULT_RAS_SERVICE_BUSNAME, **kwargs):
    """
      :param otdb_notification_busname: bus on which OTDB status changes are received.
      :param otdb_service_busname:      bus of the OTDB service (parsets, status updates).
      :param ra_service_busname:        bus of the resource assignment service (dependencies).
      :param kwargs:                    passed on to OTDBBusListener.
    """
    super(PipelineControl, self).__init__(busname=otdb_notification_busname, **kwargs)

    self.otdb_service_busname = otdb_service_busname
    self.otdbrpc = OTDBRPC(busname=otdb_service_busname)
    self.dependencies = PipelineDependencies(ra_service_busname=ra_service_busname)
    self.slurm = Slurm()

  def _setStatus(self, otdb_id, status):
    """ Set the OTDB status of `otdb_id' to `status'. """
    self.otdbrpc.taskSetStatus(otdb_id=otdb_id, new_status=status)

  def _getParset(self, otdbId):
    """ Fetch the specification of `otdbId' from OTDB as a Parset. """
    return Parset(self.otdbrpc.taskGetSpecification(otdb_id=otdbId)["specification"])

  def start_listening(self, **kwargs):
    """ Open the RPC connections, then start receiving notifications. """
    self.otdbrpc.open()
    self.dependencies.open()

    super(PipelineControl, self).start_listening(**kwargs)

  def stop_listening(self, **kwargs):
    """ Stop receiving notifications, then close the RPC connections. """
    super(PipelineControl, self).stop_listening(**kwargs)

    self.dependencies.close()
    self.otdbrpc.close()

  @staticmethod
  def _shouldHandle(parset):
    """ Return whether `parset' is a pipeline that is not destined for CEP2. """
    if not parset.isPipeline():
      logger.info("Not processing tree: is not a pipeline")
      return False

    if parset.processingCluster() == "CEP2":
      logger.info("Not processing tree: is a CEP2 pipeline")
      return False

    return True

  @staticmethod
  def _jobName(otdbId):
    """ SLURM job name of the pipeline `otdbId'. """
    return str(otdbId)

  def _startPipeline(self, otdbId, parset):
    """
      Submit the SLURM jobs for pipeline `otdbId' and set its status to QUEUED.

      Two jobs are submitted:
        - the pipeline job: sets ACTIVE, pulls and tags the docker image on all
          allocated nodes, runs runPipeline.sh inside the image, then sets
          COMPLETING and (after a delay) FINISHED;
        - an abort-trigger job with --dependency=afternotok on the pipeline job,
          which sets ABORTED.
    """
    jobName = self._jobName(otdbId)

    # Avoid race conditions by checking whether we haven't already sent the job
    # to SLURM. Our QUEUED status update may still be being processed.
    if self.slurm.isQueuedOrRunning(jobName):
      logger.info("Pipeline %s is already queued or running in SLURM.", otdbId)
      return

    # Determine SLURM parameters
    sbatch_params = [
      # Only run job if all nodes are ready
      "--wait-all-nodes=1",

      # Enforce the dependencies, instead of creating lingering jobs
      "--kill-on-invalid-dep=yes",

      # Restart job if a node fails
      "--requeue",

      # Maximum run time for job (31 days)
      "--time=31-0",

      # Lower priority to drop below inspection plots
      "--nice=1000",

      "--partition=%s" % parset.processingPartition(),
      "--nodes=%s" % parset.processingNumberOfTasks(),
      "--cpus-per-task=%s" % parset.processingNumberOfCoresPerTask(),

      # Define better places to write the output
      os.path.expandvars("--output=/data/log/pipeline-%s-%%j.log" % (otdbId,)),
    ]

    def setStatus_cmdline(status):
      """ Shell command that sets this pipeline's OTDB status to `status' via ssh to this host. """
      return (
        "ssh {myhostname} '"
          "source {lofarroot}/lofarinit.sh && "
          "setOTDBTreeStatus -o {obsid} -s {status} -B {status_bus}"
          "'"
        .format(
          myhostname = getfqdn(),
          lofarroot = os.environ.get("LOFARROOT", ""),
          obsid = otdbId,
          status = status,
          status_bus = self.otdb_service_busname,
        ))

    pipeline_script = (
      # notify that we're running
      "{setStatus_active}\n"
      # pull docker image from repository on all nodes
      "srun --nodelist=$SLURM_NODELIST --cpus-per-task=1 --job-name=docker-pull"
        " --kill-on-bad-exit=0 --wait=0"
        " docker pull {repository}/{image}\n"
      # put a local tag on the pulled image
      "srun --nodelist=$SLURM_NODELIST --cpus-per-task=1 --job-name=docker-tag"
        " --kill-on-bad-exit=0 --wait=0"
        " docker tag -f {repository}/{image} {image}\n"
      # call runPipeline.sh in the image on this node
      "docker run --rm"
        " --net=host"
        " -e LOFARENV={lofarenv}"
        " -u $UID"
        " -e USER=$USER"
        " -e HOME=$HOME"
        " -v $HOME/.ssh:$HOME/.ssh:ro"
        " -e SLURM_JOB_ID=$SLURM_JOB_ID"
        " -v /data:/data"
        " {image}"
        " runPipeline.sh -o {obsid} -c /opt/lofar/share/pipeline/pipeline.cfg.{cluster} -P /data/parsets || exit $?\n"
      # notify that we're tearing down
      "{setStatus_completing}\n"
      # wait for MoM to pick up feedback before we set finished status
      "sleep 60\n"
      # if we reached this point, the pipeline ran succesfully
      "{setStatus_finished}\n"
      .format(
        lofarenv = os.environ.get("LOFARENV", ""),
        obsid = otdbId,
        repository = parset.dockerRepository(),
        image = parset.dockerImage(),
        cluster = parset.processingCluster(),

        setStatus_active = setStatus_cmdline("active"),
        setStatus_completing = setStatus_cmdline("completing"),
        setStatus_finished = setStatus_cmdline("finished"),
      )
    )

    # Schedule runPipeline.sh
    logger.info("Scheduling SLURM job for runPipeline.sh")
    slurm_job_id = self.slurm.submit(jobName, pipeline_script, sbatch_params=sbatch_params)
    logger.info("Scheduled SLURM job %s", slurm_job_id)

    if slurm_job_id is None:
      # Without a job ID the abort trigger's dependency would read
      # "afternotok:None", which is not a valid job dependency, so do not submit
      # it, and do not report the pipeline as QUEUED.
      logger.error("Could not determine SLURM job ID for pipeline %s; not scheduling abort trigger and not setting status to QUEUED", otdbId)
      return

    # Schedule pipelineAborted.sh
    logger.info("Scheduling SLURM job for pipelineAborted.sh")
    slurm_cancel_job_id = self.slurm.submit("%s-abort-trigger" % jobName,
      setStatus_cmdline("aborted") + "\n",

      sbatch_params=[
        "--partition=%s" % parset.processingPartition(),
        "--cpus-per-task=1",
        "--ntasks=1",
        "--dependency=afternotok:%s" % slurm_job_id,
        "--kill-on-invalid-dep=yes",
        "--requeue",
        "--output=/data/log/abort-trigger-%s.log" % (otdbId,),
      ]
    )
    logger.info("Scheduled SLURM job %s", slurm_cancel_job_id)

    logger.info("Setting status to QUEUED")
    self._setStatus(otdbId, "queued")

  def _stopPipeline(self, otdbId):
    """
      Cancel the SLURM jobs of pipeline `otdbId' if it is queued or running.

      The abort-trigger job is cancelled before the pipeline job, to avoid
      setting ABORTED as a side effect of the cancellation.
    """
    jobName = self._jobName(otdbId)

    if not self.slurm.isQueuedOrRunning(jobName):
      logger.info("_stopPipeline: Job %s not running", otdbId)
      return

    for name in ("%s-abort-trigger" % jobName, jobName):
      logger.info("Cancelling job %s", name)
      self.slurm.cancel(name)

  def _startSuccessors(self, otdbId):
    """ Start each successor of `otdbId' that we handle and whose dependencies are now met. """
    try:
      successor_ids = self.dependencies.getSuccessorIds(otdbId)
    except PipelineDependencies.TaskNotFoundException as e:
      logger.error("_startSuccessors(%s): Error obtaining task successors, not starting them: %s", otdbId, e)
      return

    for s in successor_ids:
      parset = self._getParset(s)
      if not self._shouldHandle(parset):
        continue

      # Log the successor being considered, not the finished predecessor.
      if self.dependencies.canStart(s):
        logger.info("***** START Otdb ID %s *****", s)
        self._startPipeline(s, parset)
      else:
        logger.info("Job %s still cannot start yet.", s)

  def onObservationScheduled(self, otdbId, modificationTime):
    """ Start the pipeline if it is ours and its predecessors have finished. """
    parset = self._getParset(otdbId)
    if not self._shouldHandle(parset):
      return

    # Maybe the pipeline can start already
    if self.dependencies.canStart(otdbId):
      logger.info("***** START Otdb ID %s *****", otdbId)
      self._startPipeline(otdbId, parset)
    else:
      logger.info("Job %s was set to scheduled, but cannot start yet.", otdbId)

  def onObservationFinished(self, otdbId, modificationTime):
    """ Check if any successors can now start. """

    logger.info("Considering to start successors of %s", otdbId)

    self._startSuccessors(otdbId)

  def onObservationAborted(self, otdbId, modificationTime):
    """ Cancel the SLURM jobs of the pipeline, if it is ours. """
    parset = self._getParset(otdbId)
    if not self._shouldHandle(parset):
      return

    logger.info("***** STOP Otdb ID %s *****", otdbId)
    self._stopPipeline(otdbId)

  # More statuses on which the pipeline is stopped, just as for "aborted".
  onObservationDescribed    = onObservationAborted
  onObservationPrepared     = onObservationAborted
  onObservationApproved     = onObservationAborted
  onObservationPrescheduled = onObservationAborted
  onObservationConflict     = onObservationAborted
  onObservationHold         = onObservationAborted