From 182106d411572f4b3a6ea9279b8f02e0d79bea52 Mon Sep 17 00:00:00 2001
From: mancini <mancini@astron.nl>
Date: Wed, 18 Sep 2019 12:15:33 +0200
Subject: [PATCH] implemented scheduling workflow

---
 lib/slurm_executor/slurm.py | 60 +++++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 16 deletions(-)

diff --git a/lib/slurm_executor/slurm.py b/lib/slurm_executor/slurm.py
index 8bf1b29..8f981a9 100644
--- a/lib/slurm_executor/slurm.py
+++ b/lib/slurm_executor/slurm.py
@@ -1,35 +1,63 @@
 from airflow.plugins_manager import AirflowPlugin
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.state import State
-from slurm_cli.slurm_control import get_jobs_status
-import subprocess
+from slurm_cli.slurm_control import get_jobs_status, run_job
+import logging
+import uuid
+
+logger = logging.getLogger(__name__)
+
+
+def reindex_job_status_by_job_name(job_list):
+    return {job_status.job_name: job_status for job_status in job_list.values()}
+
 
 # Will show up under airflow.executors.slurm.SlurmExecutor
 class SlurmExecutor(BaseExecutor):
     def __init__(self):
         super().__init__()
-        self.commands_to_run = []
+        self.commands_to_check = {}
 
     def execute_async(self, key, command, queue=None, executor_config=None):
         print("execute async called")
-        self.commands_to_run.append((key, command,))
-
-    def trigger_tasks(self, open_slots):
-        print('trigger tasks called', open_slots)
-        super().trigger_tasks(open_slots)
+        unique_id = str(key[0]) + str(uuid.uuid1())
+        queue = queue if queue != 'default' else None
+        logging.debug('submitting job %s on queue %s', key, queue)
+        run_job(cmd=command, queue=queue, task_name=unique_id)
+        self.commands_to_check[unique_id] = key
 
-    def sync(self):
-        for key, command in self.commands_to_run:
-            self.log.info("Executing command with key %s: %s", key, command)
+    def check_state(self):
+        ids = list(self.commands_to_check.keys())
+        statuses = reindex_job_status_by_job_name(get_jobs_status(job_name=ids))
+        logger.debug('statuses found are %s', statuses)
+        logger.debug('commands to check are %s', self.commands_to_check)
 
-            try:
-                subprocess.check_call(command, close_fds=True)
+        completed_jobs = []
+        for unique_id, key in self.commands_to_check.items():
+            status = statuses[unique_id]
+            if status.status_code == 'CD':
                 self.change_state(key, State.SUCCESS)
-            except subprocess.CalledProcessError as e:
+                completed_jobs.append(unique_id)
+            elif status.status_code == 'F':
                 self.change_state(key, State.FAILED)
-                self.log.error("Failed to execute task %s.", str(e))
+                completed_jobs.append(unique_id)
+            elif status.status_code in ('CG', 'R'):
+                self.change_state(key, State.RUNNING)
+            elif status.status_code == 'PD':
+                self.change_state(key, State.SCHEDULED)
+
+        for unique_id in completed_jobs:
+            if unique_id in self.commands_to_check:
+                self.commands_to_check.pop(unique_id)
+            else:
+                logger.error('id %s missing in %s', unique_id, self.commands_to_check)
+
+    def trigger_tasks(self, open_slots):
+        self.check_state()
+        super().trigger_tasks(open_slots)
 
-        self.commands_to_run = []
+    def sync(self):
+        pass
 
     def end(self):
         self.heartbeat()
-- 
GitLab
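
Note: the patch calls run_job() and get_jobs_status() from slurm_cli.slurm_control, which is not part of this diff. The sketch below is a hypothetical reconstruction of that interface, inferred only from the call sites above: run_job(cmd=..., queue=..., task_name=...) submits the task to Slurm, and get_jobs_status(job_name=[...]) returns a mapping of objects exposing job_name and status_code. The JobStatus class and the sbatch/squeue flags are assumptions for illustration, not the actual slurm_cli implementation.

# Hypothetical sketch of the slurm_cli.slurm_control interface assumed above.
import subprocess
from dataclasses import dataclass


@dataclass
class JobStatus:
    job_id: str
    job_name: str
    status_code: str  # Slurm compact state code: PD, R, CG, CD, F, ...


def run_job(cmd, queue=None, task_name=None):
    # Submit the command to Slurm; `cmd` is assumed to be the shell command
    # (string or argument list) that Airflow wants executed for the task.
    wrap = cmd if isinstance(cmd, str) else ' '.join(cmd)
    sbatch_cmd = ['sbatch', '--job-name', task_name, '--wrap', wrap]
    if queue is not None:
        sbatch_cmd += ['--partition', queue]
    subprocess.check_call(sbatch_cmd)


def get_jobs_status(job_name):
    # Return {job_id: JobStatus} for the given job names. A real implementation
    # would likely use sacct (or squeue --states=all) so that completed and
    # failed jobs stay visible after they leave the queue.
    output = subprocess.check_output(
        ['squeue', '--noheader', '--states=all',
         '--name', ','.join(job_name), '--format', '%i|%j|%t'],
        universal_newlines=True)
    statuses = {}
    for line in output.splitlines():
        job_id, name, state = line.strip().split('|')
        statuses[job_id] = JobStatus(job_id=job_id, job_name=name, status_code=state)
    return statuses

The status codes that check_state() branches on (PD, R, CG, CD, F) are Slurm's compact job-state codes for pending, running, completing, completed, and failed, respectively.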